add CUDNN impl for Min and Max (#118)

* add cudnn impl for Min and Max

* fix onnx _search_shape to also consult graph output shapes
constroy Li, 2023-08-22 16:19:29 +08:00 (committed by GitHub)
parent 1438f14a25
commit 1e91979c76
5 changed files with 45 additions and 1 deletion


@@ -41,6 +41,8 @@ class GraphHandlerObj {
     Tensor mul(Tensor a, Tensor b, Tensor c);
     Tensor div(Tensor a, Tensor b, Tensor c);
     Tensor pow(Tensor a, Tensor b, Tensor c);
+    Tensor min(Tensor a, Tensor b, Tensor c);
+    Tensor max(Tensor a, Tensor b, Tensor c);
     Tensor relu(Tensor x, Tensor y);
     Tensor sigmoid(Tensor x, Tensor y);
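
The new declarations follow the existing binary-operator pattern (inputs a and b, optional output c). A hypothetical usage sketch, assuming, as the ONNX frontend below relies on via tensors.get(...), that a null output tensor tells the handler to infer and allocate the result:

    // Hypothetical usage; `handler` (a GraphHandlerObj) and two tensors
    // `a`, `b` of matching shape are assumed to exist.
    Tensor lo = handler.min(a, b, nullptr); // element-wise minimum
    Tensor hi = handler.max(a, b, nullptr); // element-wise maximum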


@@ -348,6 +348,18 @@ class OnnxStub:
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "Min":
+                tensors[node.output[0]] = self.handler.min(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Max":
+                tensors[node.output[0]] = self.handler.max(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Relu":
                 tensors[node.output[0]] = self.handler.relu(
                     tensors[node.input[0]],
@@ -942,6 +954,17 @@ def _search_shape(model: ModelProto, name: str) -> List[int]:
             ),
             None,
         )
+        or next(
+            (
+                [
+                    (d.dim_value if d.dim_value > 0 else 1)
+                    for d in tensor.type.tensor_type.shape.dim
+                ]
+                for tensor in model.graph.output
+                if tensor.name == name
+            ),
+            None,
+        )
         or next(
             [int(d) for d in tensor.dims]
             for tensor in model.graph.initializer
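
This is the _search_shape fix from the commit message: previously the lookup fell through from the first branch (presumably over model.graph.value_info) straight to the initializer branch, so a name whose shape is recorded only on a graph output could not be resolved. The new middle branch scans model.graph.output, substituting 1 for any non-positive (i.e. dynamic) dimension, before that final fallback.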


@@ -135,6 +135,8 @@ DEFINE_ELEMENT_WISE_METHOD(sub, Sub)
 DEFINE_ELEMENT_WISE_METHOD(mul, Mul)
 DEFINE_ELEMENT_WISE_METHOD(div, Div)
 DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
+DEFINE_ELEMENT_WISE_METHOD(min, Minimum)
+DEFINE_ELEMENT_WISE_METHOD(max, Maximum)
 
 // see operators/unary.h
 #define DEFINE_UNARY_METHOD(name, obj)                                         \
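
The macro body is not part of this diff; each DEFINE_ELEMENT_WISE_METHOD line presumably expands to a GraphHandlerObj method wrapping the named operator class. A sketch of the assumed expansion for (min, Minimum), with the operator name inferred from the MinimumObj used in the tests below:

    // Hypothetical expansion of DEFINE_ELEMENT_WISE_METHOD(min, Minimum);
    // the graph helpers `addOp` / `addOpWithOutputs` are assumptions here.
    Tensor GraphHandlerObj::min(Tensor a, Tensor b, Tensor c) {
        if (c) { // caller supplied the output tensor: build the op onto it
            g->addOpWithOutputs<MinimumObj>(std::move(a), std::move(b), c);
            return c;
        } else { // no output given: let the graph infer and allocate it
            return g->addOp<MinimumObj>(std::move(a), std::move(b), c)
                ->getOutput();
        }
    }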


@@ -98,6 +98,14 @@ class MulCudnn : public ElementWiseCudnn {
     cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_MUL; }
 };
 
+class MinCudnn : public ElementWiseCudnn {
+    cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_MIN; }
+};
+
+class MaxCudnn : public ElementWiseCudnn {
+    cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_MAX; }
+};
+
 class ElementWiseCuda : public CudaKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
@@ -137,6 +145,10 @@ REGISTER_KERNEL(Device::CUDA, OpType::Sub, DataType::Float32, SubCudnn,
                 "Sub_cuDNN_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Mul, DataType::Float32, MulCudnn,
                 "Mul_cuDNN_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Min, DataType::Float32, MinCudnn,
+                "Min_cuDNN_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Max, DataType::Float32, MaxCudnn,
+                "Max_cuDNN_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda,
                 "Div_CUDA_Float32");


@@ -58,7 +58,12 @@ TEST(cuDNN_ElementWise, run) {
     testElementWiseCudnn<DivObj>(
         OneGenerator(), Shape{1, 2, 2, 3},
         ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
+    testElementWiseCudnn<MinimumObj>(
+        IncrementalGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+    testElementWiseCudnn<MaximumObj>(
+        IncrementalGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
     testElementWiseCudnn<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 1},
                                  ExpectOutput{1, 1, 4, 27});
 }
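
A note on the expected values: testElementWiseCudnn apparently feeds the same generator output to both operands, so with IncrementalGenerator both inputs are 0..11 and min(x, x) = max(x, x) = x; the expected output is simply the input sequence, just as dividing an all-ones tensor by itself above yields all ones.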