diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index d20b4ca7..9af10f6e 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -475,7 +475,18 @@ class OnnxStub: for (i, it) in enumerate(op.outputs()) ] if ty == backend.OpType.Conv: - raise Exception("TODO") + ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op) + ctx.push_node( + make_node( + ty.name, + inputs, + outputs, + name, + pads=[ph, pw, ph, pw], + strides=[sh, sw], + dilations=[dh, dw], + ) + ) elif ty == backend.OpType.Matmul: ctx.push_node(make_node("MatMul", inputs, outputs, name)) elif ty == backend.OpType.BatchNorm: diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index d7230f42..49164e07 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -1,6 +1,7 @@ #include "core/graph_handler.h" #include "operators/batch_norm.h" #include "operators/concat.h" +#include "operators/conv.h" #include "operators/gather.h" #include "operators/reduce_mean.h" #include "operators/reshape.h" @@ -100,6 +101,20 @@ static int tensor_dtype(Tensor t) { static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); } #endif +static std::tuple<int, int, int, int, int, int> conv_attrs_of(Operator op) { + IT_ASSERT(op->getOpType() == OpType::Conv); + auto conv = dynamic_cast<const ConvObj *>(op.get()); + return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(), + conv->getDw(), conv->getSh(), conv->getSw()); +} + +static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) { + IT_ASSERT(op->getOpType() == OpType::BatchNorm); + auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get()); + return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(), + batchnorm->getTraining()); +} + static int concat_axis_of(Operator op) { IT_ASSERT(op->getOpType() == OpType::Concat); return dynamic_cast<const ConcatObj *>(op.get())->getDim(); @@ -121,25 +136,19 @@ static Shape reshape_shape_of(Operator op) { return dynamic_cast<const ReshapeObj *>(op.get())->getShape(); } -static std::tuple<float, float, bool> batch_norm_attrs_of(Operator
op) { - IT_ASSERT(op->getOpType() == OpType::BatchNorm); - auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get()); - return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(), - batchnorm->getTraining()); -} - void export_functions(py::module &m) { #define FUNCTION(NAME) def(#NAME, &NAME) m.def("cpu_runtime", &CpuRuntimeObj::getInstance) #ifdef USE_CUDA .FUNCTION(cuda_runtime) #endif + .FUNCTION(conv_attrs_of) + .FUNCTION(batch_norm_attrs_of) .FUNCTION(tensor_dtype) .FUNCTION(reshape_shape_of) .FUNCTION(concat_axis_of) .FUNCTION(gather_axis_of) - .FUNCTION(reduce_mean_axes_of) - .FUNCTION(batch_norm_attrs_of); + .FUNCTION(reduce_mean_axes_of); #undef FUNCTION }