Add Neg operator and kernel (#152)

* Add Neg operator and kernel

* handle neg in to_onnx

---------

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
PanZezhong1725 2023-10-10 10:54:56 +08:00 committed by GitHub
parent 7a9fcd93b2
commit 7600fe688c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 85 additions and 43 deletions

View File

@ -51,6 +51,7 @@ class GraphHandlerObj {
Tensor softmax(Tensor x, Tensor y, int axis); Tensor softmax(Tensor x, Tensor y, int axis);
Tensor abs(Tensor x, Tensor y); Tensor abs(Tensor x, Tensor y);
Tensor sqrt(Tensor x, Tensor y); Tensor sqrt(Tensor x, Tensor y);
Tensor neg(Tensor x, Tensor y);
Tensor shape(Tensor x, Tensor y); Tensor shape(Tensor x, Tensor y);
Tensor identity(Tensor x, Tensor y); Tensor identity(Tensor x, Tensor y);
Tensor flatten(Tensor s, Tensor y, int axis); Tensor flatten(Tensor s, Tensor y, int axis);

View File

@ -3,14 +3,14 @@
#include "operators/unary.h" #include "operators/unary.h"
namespace infini { namespace infini {
// TODO(constroy): num should be size_t. void softmax_kernel(float *input, float *output, size_t num);
void softmax_kernel(float *input, float *output, int num); void relu_kernel(float *input, float *output, size_t num);
void relu_kernel(float *input, float *output, int num); void sigmoid_kernel(float *input, float *output, size_t num);
void sigmoid_kernel(float *input, float *output, int num); void tanh_kernel(float *input, float *output, size_t num);
void tanh_kernel(float *input, float *output, int num); void abs_kernel(float *input, float *output, size_t num);
void abs_kernel(float *input, float *output, int num); void sqrt_kernel(float *input, float *output, size_t num);
void sqrt_kernel(float *input, float *output, int num); void neg_kernel(float *input, float *output, size_t num);
void erf_kernel(float *input, float *output, int num); void erf_kernel(float *input, float *output, size_t num);
void unary_kernel(const Operator &_op) { void unary_kernel(const Operator &_op) {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
@ -30,6 +30,8 @@ void unary_kernel(const Operator &_op) {
abs_kernel(inputData, outputData, num); abs_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Sqrt) else if (op->getOpType() == OpType::Sqrt)
sqrt_kernel(inputData, outputData, num); sqrt_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Neg)
neg_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Erf) else if (op->getOpType() == OpType::Erf)
erf_kernel(inputData, outputData, num); erf_kernel(inputData, outputData, num);
else else

View File

@ -403,6 +403,11 @@ class OnnxStub:
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
) )
elif node.op_type == "Neg":
tensors[node.output[0]] = self.handler.neg(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Shape": elif node.op_type == "Shape":
tensors[node.output[0]] = self.handler.shape( tensors[node.output[0]] = self.handler.shape(
tensors[node.input[0]], tensors[node.input[0]],
@ -916,6 +921,7 @@ class OnnxStub:
backend.OpTypeId.PRelu, backend.OpTypeId.PRelu,
backend.OpTypeId.Sqrt, backend.OpTypeId.Sqrt,
backend.OpTypeId.Erf, backend.OpTypeId.Erf,
backend.OpTypeId.Neg,
]: ]:
ctx.push_node(make_node(ty.name, inputs, outputs, name)) ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpTypeId.Flatten: elif ty == backend.OpTypeId.Flatten:

View File

@ -244,6 +244,12 @@ class TestStringMethods(unittest.TestCase):
abs = make_node("Abs", ["x"], ["y"], name="abs") abs = make_node("Abs", ["x"], ["y"], name="abs")
make_and_import_model(make_graph([abs], "abs", [x], [y])) make_and_import_model(make_graph([abs], "abs", [x], [y]))
def test_neg(self):
    # Build a one-node ONNX graph computing y = Neg(x) on a float32
    # tensor of shape [1, 3, 5, 7], then verify it round-trips through
    # the backend handler without raising.
    dims = [1, 3, 5, 7]
    in_info = make_tensor_value_info("x", TensorProto.FLOAT, dims)
    out_info = make_tensor_value_info("y", TensorProto.FLOAT, dims)
    node = make_node("Neg", ["x"], ["y"], name="neg")
    graph = make_graph([node], "neg", [in_info], [out_info])
    make_and_import_model(graph)
def test_identity(self): def test_identity(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])

View File

@ -159,6 +159,7 @@ DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
DEFINE_UNARY_METHOD(tanh, Tanh) DEFINE_UNARY_METHOD(tanh, Tanh)
DEFINE_UNARY_METHOD(abs, Abs) DEFINE_UNARY_METHOD(abs, Abs)
DEFINE_UNARY_METHOD(sqrt, Sqrt) DEFINE_UNARY_METHOD(sqrt, Sqrt)
DEFINE_UNARY_METHOD(neg, Neg)
DEFINE_UNARY_METHOD(shape, Shape) DEFINE_UNARY_METHOD(shape, Shape)
DEFINE_UNARY_METHOD(erf, Erf) DEFINE_UNARY_METHOD(erf, Erf)

View File

@ -100,6 +100,7 @@ void export_values(py::module &m) {
.VALUE(OpType, Dropout) .VALUE(OpType, Dropout)
.VALUE(OpType, Cast) .VALUE(OpType, Cast)
.VALUE(OpType, Sqrt) .VALUE(OpType, Sqrt)
.VALUE(OpType, Neg)
.VALUE(OpType, Expand) .VALUE(OpType, Expand)
.VALUE(OpType, Erf) .VALUE(OpType, Erf)
.VALUE(OpType, Where) .VALUE(OpType, Where)
@ -444,6 +445,7 @@ void init_graph_builder(py::module &m) {
.def("softmax", &Handler::softmax, policy::move) .def("softmax", &Handler::softmax, policy::move)
.def("abs", &Handler::abs, policy::move) .def("abs", &Handler::abs, policy::move)
.def("sqrt", &Handler::sqrt, policy::move) .def("sqrt", &Handler::sqrt, policy::move)
.def("neg", &Handler::neg, policy::move)
.def("shape", &Handler::shape, policy::move) .def("shape", &Handler::shape, policy::move)
.def("identity", &Handler::identity, policy::move) .def("identity", &Handler::identity, policy::move)
.def("flatten", &Handler::flatten, policy::move) .def("flatten", &Handler::flatten, policy::move)

View File

@ -64,6 +64,10 @@ template <typename T> class NaiveErf : public NativeUnary<T> {
T doCompute(T val) const override { return std::erf(val); } T doCompute(T val) const override { return std::erf(val); }
}; };
// Element-wise negation for the CPU backend: maps each value x to -x.
// NOTE(review): per-element application appears to be driven by the
// NativeUnary<T> base (not visible here) -- confirm the base iterates
// over the tensor and calls doCompute once per element.
template <typename T> class NaiveNeg : public NativeUnary<T> {
// Negate a single scalar element.
T doCompute(T val) const override { return -val; }
};
template <typename T> class Clip : public CpuKernelWithoutConfig { template <typename T> class Clip : public CpuKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *context) const override { const RuntimeObj *context) const override {
@ -103,6 +107,8 @@ REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>,
"sqrtNaive_CPU_float32"); "sqrtNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Erf, DataType::Float32, NaiveErf<float>, REGISTER_KERNEL(Device::CPU, OpType::Erf, DataType::Float32, NaiveErf<float>,
"erfNaive_CPU_float32"); "erfNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Neg, DataType::Float32, NaiveNeg<float>,
"negNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32"); NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32, REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,

View File

@ -140,6 +140,8 @@ REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
"Abs_CUDA_Float32"); "Abs_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda, REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda,
"Sqrt_CUDA_Float32"); "Sqrt_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Neg, DataType::Float32, UnaryCuda,
"Neg_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Erf, DataType::Float32, UnaryCuda, REGISTER_KERNEL(Device::CUDA, OpType::Erf, DataType::Float32, UnaryCuda,
"Erf_CUDA_Float32"); "Erf_CUDA_Float32");

View File

@ -8,7 +8,7 @@ constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; } constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); } constexpr int block_work_size() { return thread_work_size() * num_threads(); }
__global__ void _softmax_kernel1(float *input, float *output, int n) { __global__ void _softmax_kernel1(float *input, float *output, size_t n) {
float sum = 0.0f; float sum = 0.0f;
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
sum += pow(E_CONSTANT, input[i]); sum += pow(E_CONSTANT, input[i]);
@ -16,106 +16,121 @@ __global__ void _softmax_kernel1(float *input, float *output, int n) {
*output = sum; *output = sum;
} }
__global__ void _softmax_kernel2(float *input, float *output, int n) { __global__ void _softmax_kernel2(float *input, float *output, size_t n) {
float sum = *output; float sum = *output;
int index = threadIdx.x + blockIdx.x * blockDim.x; size_t index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x; size_t stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) { for (size_t i = index; i < n; i += stride) {
output[i] = pow(E_CONSTANT, input[i]) / sum; output[i] = pow(E_CONSTANT, input[i]) / sum;
} }
} }
__global__ void _relu_kernel(float *input, float *output, int n) { __global__ void _relu_kernel(float *input, float *output, size_t n) {
int index = threadIdx.x + blockIdx.x * blockDim.x; size_t index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x; size_t stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) { for (size_t i = index; i < n; i += stride) {
output[i] = max(input[i], float(0)); output[i] = max(input[i], float(0));
} }
} }
__global__ void _sigmoid_kernel(float *input, float *output, int n) { __global__ void _sigmoid_kernel(float *input, float *output, size_t n) {
int index = threadIdx.x + blockIdx.x * blockDim.x; size_t index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x; size_t stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) { for (size_t i = index; i < n; i += stride) {
output[i] = 1 / (1 + pow(E_CONSTANT, -input[i])); output[i] = 1 / (1 + pow(E_CONSTANT, -input[i]));
} }
} }
__global__ void _tanh_kernel(float *input, float *output, int n) { __global__ void _tanh_kernel(float *input, float *output, size_t n) {
int index = threadIdx.x + blockIdx.x * blockDim.x; size_t index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x; size_t stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) { for (size_t i = index; i < n; i += stride) {
output[i] = (pow(E_CONSTANT, input[i]) - pow(E_CONSTANT, -input[i])) / output[i] = (pow(E_CONSTANT, input[i]) - pow(E_CONSTANT, -input[i])) /
(pow(E_CONSTANT, input[i]) + pow(E_CONSTANT, -input[i])); (pow(E_CONSTANT, input[i]) + pow(E_CONSTANT, -input[i]));
} }
} }
__global__ void _abs_kernel(float *input, float *output, int n) { __global__ void _abs_kernel(float *input, float *output, size_t n) {
int index = threadIdx.x + blockIdx.x * blockDim.x; size_t index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x; size_t stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) { for (size_t i = index; i < n; i += stride) {
output[i] = input[i] < 0 ? -input[i] : input[i]; output[i] = input[i] < 0 ? -input[i] : input[i];
} }
} }
__global__ void _sqrt_kernel(float *input, float *output, int n) { __global__ void _sqrt_kernel(float *input, float *output, size_t n) {
int index = threadIdx.x + blockIdx.x * blockDim.x; size_t index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x; size_t stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) { for (size_t i = index; i < n; i += stride) {
output[i] = sqrt(input[i]); output[i] = sqrt(input[i]);
} }
} }
__global__ void _erf_kernel(float *input, float *output, int n) { __global__ void _erf_kernel(float *input, float *output, size_t n) {
int index = threadIdx.x + blockIdx.x * blockDim.x; size_t index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x; size_t stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) { for (int i = index; i < n; i += stride) {
output[i] = erf(input[i]); output[i] = erf(input[i]);
} }
} }
// CUDA kernel: element-wise negation, output[i] = -input[i] for i in [0, n).
// Uses a grid-stride loop so any grid size covers all n elements.
template <typename T>
__global__ void _neg_kernel(T *input, T *output, size_t n) {
    size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
    size_t step = blockDim.x * gridDim.x;
    for (size_t pos = tid; pos < n; pos += step)
        output[pos] = -input[pos];
}
namespace infini { namespace infini {
void softmax_kernel(float *input, float *output, int num) { void softmax_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_softmax_kernel1<<<1, 1>>>(input, output, num); _softmax_kernel1<<<1, 1>>>(input, output, num);
_softmax_kernel2<<<gridsize, blocksize>>>(input, output, num); _softmax_kernel2<<<gridsize, blocksize>>>(input, output, num);
} }
void relu_kernel(float *input, float *output, int num) { void relu_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_relu_kernel<<<gridsize, blocksize>>>(input, output, num); _relu_kernel<<<gridsize, blocksize>>>(input, output, num);
} }
void sigmoid_kernel(float *input, float *output, int num) { void sigmoid_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num); _sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
} }
void tanh_kernel(float *input, float *output, int num) { void tanh_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_tanh_kernel<<<gridsize, blocksize>>>(input, output, num); _tanh_kernel<<<gridsize, blocksize>>>(input, output, num);
} }
void abs_kernel(float *input, float *output, int num) { void abs_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_abs_kernel<<<gridsize, blocksize>>>(input, output, num); _abs_kernel<<<gridsize, blocksize>>>(input, output, num);
} }
void sqrt_kernel(float *input, float *output, int num) { void sqrt_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_sqrt_kernel<<<gridsize, blocksize>>>(input, output, num); _sqrt_kernel<<<gridsize, blocksize>>>(input, output, num);
} }
void erf_kernel(float *input, float *output, int num) { void erf_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_erf_kernel<<<gridsize, blocksize>>>(input, output, num); _erf_kernel<<<gridsize, blocksize>>>(input, output, num);
} }
// Host-side launcher for element-wise negation over `num` floats on the
// current CUDA stream. Grid size is the ceiling of num / block_work_size()
// so every element is covered exactly once by the grid-stride kernel.
void neg_kernel(float *input, float *output, size_t num) {
int blocksize = block_work_size();
// NOTE(review): gridsize is an int while num is size_t -- assumes the
// rounded-up block count fits in int (matches the sibling launchers in
// this file); confirm tensors never get large enough to overflow this.
int gridsize = (num + block_work_size() - 1) / block_work_size();
_neg_kernel<<<gridsize, blocksize>>>(input, output, num);
}
}; // namespace infini }; // namespace infini

View File

@ -46,6 +46,7 @@ TEST(cuDNN_Unary, run) {
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<NegObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
// more shapes // more shapes
testUnary<SqrtObj>(IncrementalGenerator(), Shape{13}); testUnary<SqrtObj>(IncrementalGenerator(), Shape{13});