diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index fa704442..db29d1c6 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -41,6 +41,8 @@ class GraphHandlerObj {
     Tensor mul(Tensor a, Tensor b, Tensor c);
     Tensor div(Tensor a, Tensor b, Tensor c);
     Tensor pow(Tensor a, Tensor b, Tensor c);
+    Tensor min(Tensor a, Tensor b, Tensor c);
+    Tensor max(Tensor a, Tensor b, Tensor c);
 
     Tensor relu(Tensor x, Tensor y);
     Tensor sigmoid(Tensor x, Tensor y);
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 1a52a95b..ced54ae7 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -348,6 +348,18 @@ class OnnxStub:
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "Min":
+                tensors[node.output[0]] = self.handler.min(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Max":
+                tensors[node.output[0]] = self.handler.max(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Relu":
                 tensors[node.output[0]] = self.handler.relu(
                     tensors[node.input[0]],
@@ -942,6 +954,17 @@ def _search_shape(model: ModelProto, name: str) -> List[int]:
             ),
             None,
         )
+        or next(
+            (
+                [
+                    (d.dim_value if d.dim_value > 0 else 1)
+                    for d in tensor.type.tensor_type.shape.dim
+                ]
+                for tensor in model.graph.output
+                if tensor.name == name
+            ),
+            None,
+        )
         or next(
             [int(d) for d in tensor.dims]
             for tensor in model.graph.initializer
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index be7d5578..40fb42df 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -135,6 +135,8 @@ DEFINE_ELEMENT_WISE_METHOD(sub, Sub)
 DEFINE_ELEMENT_WISE_METHOD(mul, Mul)
 DEFINE_ELEMENT_WISE_METHOD(div, Div)
 DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
+DEFINE_ELEMENT_WISE_METHOD(min, Minimum)
+DEFINE_ELEMENT_WISE_METHOD(max, Maximum)
 
 // see operators/unary.h
 #define DEFINE_UNARY_METHOD(name, obj)                                         \
diff --git a/src/kernels/cuda/element_wise.cc b/src/kernels/cuda/element_wise.cc
index dc57e8d1..99b586fb 100644
--- a/src/kernels/cuda/element_wise.cc
+++ b/src/kernels/cuda/element_wise.cc
@@ -98,6 +98,14 @@ class MulCudnn : public ElementWiseCudnn {
     cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_MUL; }
 };
 
+class MinCudnn : public ElementWiseCudnn {
+    cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_MIN; }
+};
+
+class MaxCudnn : public ElementWiseCudnn {
+    cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_MAX; }
+};
+
 class ElementWiseCuda : public CudaKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
@@ -137,6 +145,10 @@ REGISTER_KERNEL(Device::CUDA, OpType::Sub, DataType::Float32, SubCudnn,
                 "Sub_cuDNN_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Mul, DataType::Float32, MulCudnn,
                 "Mul_cuDNN_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Min, DataType::Float32, MinCudnn,
+                "Min_cuDNN_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Max, DataType::Float32, MaxCudnn,
+                "Max_cuDNN_CUDA_Float32");
 
 REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda,
                 "Div_CUDA_Float32");
diff --git a/test/kernels/cuda/test_cuda_element_wise.cc b/test/kernels/cuda/test_cuda_element_wise.cc
index 05872388..7b3711cb 100644
--- a/test/kernels/cuda/test_cuda_element_wise.cc
+++ b/test/kernels/cuda/test_cuda_element_wise.cc
@@ -58,7 +58,12 @@ TEST(cuDNN_ElementWise, run) {
     testElementWiseCudnn<DivObj>(
         OneGenerator(), Shape{1, 2, 2, 3},
         ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
-
+    testElementWiseCudnn<MinimumObj>(
+        IncrementalGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+    testElementWiseCudnn<MaximumObj>(
+        IncrementalGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
     testElementWiseCudnn<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 1},
                                  ExpectOutput{1, 1, 4, 27});
 }
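
Note (not part of the patch): a minimal sketch of the kind of ONNX graph the new importer branches are meant to handle, using only the standard onnx.helper API. The tensor names "a"/"b"/"lo"/"hi", the graph name, and the 1x2x2x3 shape are made up for illustration and are not taken from this change.

from onnx import TensorProto, checker, helper

# Build a two-input Min/Max graph. OnnxStub's node loop (diffed above) would
# dispatch these nodes to handler.min and handler.max respectively.
a = helper.make_tensor_value_info("a", TensorProto.FLOAT, [1, 2, 2, 3])
b = helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 2, 3])
lo = helper.make_tensor_value_info("lo", TensorProto.FLOAT, [1, 2, 2, 3])
hi = helper.make_tensor_value_info("hi", TensorProto.FLOAT, [1, 2, 2, 3])

graph = helper.make_graph(
    [
        helper.make_node("Min", ["a", "b"], ["lo"]),
        helper.make_node("Max", ["a", "b"], ["hi"]),
    ],
    "min_max_example",
    [a, b],
    [lo, hi],
)
model = helper.make_model(graph)
checker.check_model(model)  # sanity-check the hand-built model

Two remarks on scope: the importer reads only node.input[0] and node.input[1], so it covers the two-input form of Min/Max, while the ONNX operators accept a variadic input list. In the cuDNN tests, both operands of each new case are filled by the same IncrementalGenerator, so the elementwise minimum and maximum both reproduce the input tensor, which is why both ExpectOutput lists are 0 through 11.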