diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index c93a355f..1c79f51d 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -51,6 +51,7 @@ class GraphHandlerObj {
     Tensor softmax(Tensor x, Tensor y, int axis);
     Tensor abs(Tensor x, Tensor y);
     Tensor sqrt(Tensor x, Tensor y);
+    Tensor neg(Tensor x, Tensor y);
     Tensor shape(Tensor x, Tensor y);
     Tensor identity(Tensor x, Tensor y);
     Tensor flatten(Tensor s, Tensor y, int axis);
diff --git a/include/cuda/cuda_unary.h b/include/cuda/cuda_unary.h
index 0f26c2e3..c538682a 100644
--- a/include/cuda/cuda_unary.h
+++ b/include/cuda/cuda_unary.h
@@ -3,14 +3,14 @@
 #include "operators/unary.h"
 
 namespace infini {
-// TODO(constroy): num should be size_t.
-void softmax_kernel(float *input, float *output, int num);
-void relu_kernel(float *input, float *output, int num);
-void sigmoid_kernel(float *input, float *output, int num);
-void tanh_kernel(float *input, float *output, int num);
-void abs_kernel(float *input, float *output, int num);
-void sqrt_kernel(float *input, float *output, int num);
-void erf_kernel(float *input, float *output, int num);
+void softmax_kernel(float *input, float *output, size_t num);
+void relu_kernel(float *input, float *output, size_t num);
+void sigmoid_kernel(float *input, float *output, size_t num);
+void tanh_kernel(float *input, float *output, size_t num);
+void abs_kernel(float *input, float *output, size_t num);
+void sqrt_kernel(float *input, float *output, size_t num);
+void neg_kernel(float *input, float *output, size_t num);
+void erf_kernel(float *input, float *output, size_t num);
 
 void unary_kernel(const Operator &_op) {
     auto op = as<UnaryObj>(_op);
@@ -30,6 +30,8 @@ void unary_kernel(const Operator &_op) {
         abs_kernel(inputData, outputData, num);
     else if (op->getOpType() == OpType::Sqrt)
         sqrt_kernel(inputData, outputData, num);
+    else if (op->getOpType() == OpType::Neg)
+        neg_kernel(inputData, outputData, num);
     else if (op->getOpType() == OpType::Erf)
         erf_kernel(inputData, outputData, num);
     else
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 149c284f..96d4778d 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -403,6 +403,11 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "Neg":
+                tensors[node.output[0]] = self.handler.neg(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Shape":
                 tensors[node.output[0]] = self.handler.shape(
                     tensors[node.input[0]],
@@ -916,6 +921,7 @@ class OnnxStub:
                 backend.OpTypeId.PRelu,
                 backend.OpTypeId.Sqrt,
                 backend.OpTypeId.Erf,
+                backend.OpTypeId.Neg,
             ]:
                 ctx.push_node(make_node(ty.name, inputs, outputs, name))
             elif ty == backend.OpTypeId.Flatten:
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 6d041ed2..2d614b48 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -243,6 +243,12 @@ class TestStringMethods(unittest.TestCase):
         y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
         abs = make_node("Abs", ["x"], ["y"], name="abs")
         make_and_import_model(make_graph([abs], "abs", [x], [y]))
+
+    def test_neg(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
+        neg = make_node("Neg", ["x"], ["y"], name="neg")
+        make_and_import_model(make_graph([neg], "neg", [x], [y]))
 
     def test_identity(self):
         x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 7267fddf..bbd73e10 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -159,6 +159,7 @@
 DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
 DEFINE_UNARY_METHOD(tanh, Tanh)
 DEFINE_UNARY_METHOD(abs, Abs)
 DEFINE_UNARY_METHOD(sqrt, Sqrt)
+DEFINE_UNARY_METHOD(neg, Neg)
 DEFINE_UNARY_METHOD(shape, Shape)
 DEFINE_UNARY_METHOD(erf, Erf)
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 92feba2a..b515164f 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -100,6 +100,7 @@ void export_values(py::module &m) {
         .VALUE(OpType, Dropout)
         .VALUE(OpType, Cast)
         .VALUE(OpType, Sqrt)
+        .VALUE(OpType, Neg)
         .VALUE(OpType, Expand)
         .VALUE(OpType, Erf)
         .VALUE(OpType, Where)
@@ -444,6 +445,7 @@ void init_graph_builder(py::module &m) {
         .def("softmax", &Handler::softmax, policy::move)
         .def("abs", &Handler::abs, policy::move)
         .def("sqrt", &Handler::sqrt, policy::move)
+        .def("neg", &Handler::neg, policy::move)
         .def("shape", &Handler::shape, policy::move)
         .def("identity", &Handler::identity, policy::move)
         .def("flatten", &Handler::flatten, policy::move)
diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc
index e559c909..15025115 100644
--- a/src/kernels/cpu/unary.cc
+++ b/src/kernels/cpu/unary.cc
@@ -64,6 +64,10 @@
 template <typename T> class NaiveErf : public NativeUnary<T> {
     T doCompute(T val) const override { return std::erf(val); }
 };
 
+template <typename T> class NaiveNeg : public NativeUnary<T> {
+    T doCompute(T val) const override { return -val; }
+};
+
 template <typename T> class Clip : public CpuKernelWithoutConfig {
     void compute(const Operator &_op, const RuntimeObj *context) const override {
@@ -103,6 +107,8 @@
 REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>,
                 "sqrtNaive_CPU_float32");
 REGISTER_KERNEL(Device::CPU, OpType::Erf, DataType::Float32, NaiveErf<float>,
                 "erfNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::Neg, DataType::Float32, NaiveNeg<float>,
+                "negNaive_CPU_float32");
 REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
                 NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
 REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc
index 897e2c77..cb53bd80 100644
--- a/src/kernels/cuda/unary.cc
+++ b/src/kernels/cuda/unary.cc
@@ -140,6 +140,8 @@
 REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
                 "Abs_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda,
                 "Sqrt_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Neg, DataType::Float32, UnaryCuda,
+                "Neg_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Erf, DataType::Float32, UnaryCuda,
                 "Erf_CUDA_Float32");
diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu
index 695762b4..061ac63d 100644
--- a/src/kernels/cuda/unary.cu
+++ b/src/kernels/cuda/unary.cu
@@ -8,7 +8,7 @@ constexpr unsigned int num_threads() { return 32 * 4; }
 constexpr int thread_work_size() { return 4; }
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
 
-__global__ void _softmax_kernel1(float *input, float *output, int n) {
+__global__ void _softmax_kernel1(float *input, float *output, size_t n) {
     float sum = 0.0f;
     for (size_t i = 0; i < n; ++i) {
         sum += pow(E_CONSTANT, input[i]);
@@ -16,106 +16,121 @@ __global__ void _softmax_kernel1(float *input, float *output, size_t n) {
     *output = sum;
 }
 
-__global__ void _softmax_kernel2(float *input, float *output, int n) {
+__global__ void _softmax_kernel2(float *input, float *output, size_t n) {
     float sum = *output;
-    int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    for (int i = index; i < n; i += stride) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
         output[i] = pow(E_CONSTANT, input[i]) / sum;
     }
 }
 
-__global__ void _relu_kernel(float *input, float *output, int n) {
-    int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    for (int i = index; i < n; i += stride) {
+__global__ void _relu_kernel(float *input, float *output, size_t n) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
         output[i] = max(input[i], float(0));
     }
 }
 
-__global__ void _sigmoid_kernel(float *input, float *output, int n) {
-    int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    for (int i = index; i < n; i += stride) {
+__global__ void _sigmoid_kernel(float *input, float *output, size_t n) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
         output[i] = 1 / (1 + pow(E_CONSTANT, -input[i]));
     }
 }
 
-__global__ void _tanh_kernel(float *input, float *output, int n) {
-    int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    for (int i = index; i < n; i += stride) {
+__global__ void _tanh_kernel(float *input, float *output, size_t n) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
         output[i] = (pow(E_CONSTANT, input[i]) - pow(E_CONSTANT, -input[i])) /
                     (pow(E_CONSTANT, input[i]) + pow(E_CONSTANT, -input[i]));
     }
 }
 
-__global__ void _abs_kernel(float *input, float *output, int n) {
-    int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    for (int i = index; i < n; i += stride) {
+__global__ void _abs_kernel(float *input, float *output, size_t n) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
         output[i] = input[i] < 0 ? -input[i] : input[i];
     }
 }
 
-__global__ void _sqrt_kernel(float *input, float *output, int n) {
-    int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    for (int i = index; i < n; i += stride) {
+__global__ void _sqrt_kernel(float *input, float *output, size_t n) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
         output[i] = sqrt(input[i]);
     }
 }
 
-__global__ void _erf_kernel(float *input, float *output, int n) {
-    int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    for (int i = index; i < n; i += stride) {
+__global__ void _erf_kernel(float *input, float *output, size_t n) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
         output[i] = erf(input[i]);
     }
 }
 
+template <typename T>
+__global__ void _neg_kernel(T *input, T *output, size_t n) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
+        output[i] = -input[i];
+    }
+}
+
 namespace infini {
-void softmax_kernel(float *input, float *output, int num) {
+void softmax_kernel(float *input, float *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _softmax_kernel1<<<1, 1>>>(input, output, num);
     _softmax_kernel2<<<gridsize, blocksize>>>(input, output, num);
 }
-void relu_kernel(float *input, float *output, int num) {
+void relu_kernel(float *input, float *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _relu_kernel<<<gridsize, blocksize>>>(input, output, num);
 }
-void sigmoid_kernel(float *input, float *output, int num) {
+void sigmoid_kernel(float *input, float *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
 }
-void tanh_kernel(float *input, float *output, int num) {
+void tanh_kernel(float *input, float *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _tanh_kernel<<<gridsize, blocksize>>>(input, output, num);
 }
-void abs_kernel(float *input, float *output, int num) {
+void abs_kernel(float *input, float *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _abs_kernel<<<gridsize, blocksize>>>(input, output, num);
 }
-void sqrt_kernel(float *input, float *output, int num) {
+void sqrt_kernel(float *input, float *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _sqrt_kernel<<<gridsize, blocksize>>>(input, output, num);
 }
-void erf_kernel(float *input, float *output, int num) {
+void erf_kernel(float *input, float *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _erf_kernel<<<gridsize, blocksize>>>(input, output, num);
 }
+void neg_kernel(float *input, float *output, size_t num) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _neg_kernel<<<gridsize, blocksize>>>(input, output, num);
+}
 }; // namespace infini
diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc
index 5d9f24ec..2f828fab 100644
--- a/test/kernels/cuda/test_cuda_unary.cc
+++ b/test/kernels/cuda/test_cuda_unary.cc
@@ -46,6 +46,7 @@ TEST(cuDNN_Unary, run) {
     testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testUnary<NegObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     // more shapes
     testUnary<AbsObj>(IncrementalGenerator(), Shape{13});
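
Note for reviewers (not part of the patch): a minimal end-to-end smoke test for the new `Neg` operator, sketched against the `OnnxStub` pattern used in `test_onnx.py`. The `backend.cpu_runtime()`, `copyin_float`, and `copyout_float` bindings are assumptions based on the existing pyinfinitensor FFI; adjust the names if they differ in your checkout.

```python
# Hedged sketch: imports an ONNX graph containing Neg and runs the new
# CPU kernel. Assumes backend.cpu_runtime(), Tensor.copyin_float(), and
# Tensor.copyout_float() exist in the pyinfinitensor bindings.
import numpy as np
from onnx import TensorProto
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
from pyinfinitensor.onnx import OnnxStub, backend

x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
model = make_model(make_graph([make_node("Neg", ["x"], ["y"])], "neg", [x], [y]))

stub = OnnxStub(model, backend.cpu_runtime())
data = np.random.rand(1, 3, 5, 7).astype(np.float32)
next(iter(stub.inputs.values())).copyin_float(data.flatten().tolist())
stub.run()
out = np.array(next(iter(stub.outputs.values())).copyout_float(), dtype=np.float32)
assert np.allclose(out.reshape(1, 3, 5, 7), -data)  # Neg(x) == -x elementwise
```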