feat: export conv to onnx

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster 2023-02-27 09:15:45 +08:00
parent afed749b74
commit 8a871c3773
2 changed files with 30 additions and 10 deletions
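
This commit replaces the raise Exception("TODO") placeholder in OnnxStub's Conv branch with a real export path: a new conv_attrs_of helper is bound on the C++ side to expose the operator's padding, dilation, and stride attributes as a tuple, and the Python side turns them into the pads, strides, and dilations attributes of an ONNX Conv node. The existing batch_norm_attrs_of helper is moved up next to it so the attribute accessors sit together.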


@@ -475,7 +475,18 @@ class OnnxStub:
                 for (i, it) in enumerate(op.outputs())
             ]
             if ty == backend.OpType.Conv:
-                raise Exception("TODO")
+                ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op)
+                ctx.push_node(
+                    make_node(
+                        ty.name,
+                        inputs,
+                        outputs,
+                        name,
+                        pads=[ph, pw, ph, pw],
+                        strides=[sh, sw],
+                        dilations=[dh, dw],
+                    )
+                )
             elif ty == backend.OpType.Matmul:
                 ctx.push_node(make_node("MatMul", inputs, outputs, name))
             elif ty == backend.OpType.BatchNorm:

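For reference, a minimal standalone sketch of the kind of node this branch now emits. The attribute values are hypothetical stand-ins for what backend.conv_attrs_of(op) returns (a 2-D convolution with padding 1, stride 1, dilation 1 is assumed), and onnx.checker is used only to show the attributes are well-formed. Note that ONNX expects pads as [begin_h, begin_w, end_h, end_w], which is why each padding value appears twice:

from onnx import checker, helper

# Hypothetical attribute values standing in for backend.conv_attrs_of(op):
# padding 1, dilation 1, stride 1.
ph, pw, dh, dw, sh, sw = 1, 1, 1, 1, 1, 1

node = helper.make_node(
    "Conv",
    ["x", "w"],  # data and weight inputs
    ["y"],       # output
    "conv_example",
    # ONNX pads are [begin_h, begin_w, end_h, end_w], hence each value twice.
    pads=[ph, pw, ph, pw],
    strides=[sh, sw],
    dilations=[dh, dw],
)
checker.check_node(node)  # raises ValidationError if the attributes are malformed
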

@@ -1,6 +1,7 @@
 #include "core/graph_handler.h"
 #include "operators/batch_norm.h"
 #include "operators/concat.h"
+#include "operators/conv.h"
 #include "operators/gather.h"
 #include "operators/reduce_mean.h"
 #include "operators/reshape.h"
@@ -100,6 +101,20 @@ static int tensor_dtype(Tensor t) {
 static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); }
 #endif
 
+static std::tuple<int, int, int, int, int, int> conv_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Conv);
+    auto conv = dynamic_cast<const ConvObj *>(op.get());
+    return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(),
+                           conv->getDw(), conv->getSh(), conv->getSw());
+}
+
+static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::BatchNorm);
+    auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
+    return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
+                           batchnorm->getTraining());
+}
+
 static int concat_axis_of(Operator op) {
     IT_ASSERT(op->getOpType() == OpType::Concat);
     return dynamic_cast<const ConcatObj *>(op.get())->getDim();
@@ -121,25 +136,19 @@ static Shape reshape_shape_of(Operator op) {
     return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
 }
 
-static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
-    IT_ASSERT(op->getOpType() == OpType::BatchNorm);
-    auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
-    return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
-                           batchnorm->getTraining());
-}
-
 void export_functions(py::module &m) {
 #define FUNCTION(NAME) def(#NAME, &NAME)
     m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
 #ifdef USE_CUDA
         .FUNCTION(cuda_runtime)
 #endif
+        .FUNCTION(conv_attrs_of)
+        .FUNCTION(batch_norm_attrs_of)
         .FUNCTION(tensor_dtype)
         .FUNCTION(reshape_shape_of)
         .FUNCTION(concat_axis_of)
         .FUNCTION(gather_axis_of)
-        .FUNCTION(reduce_mean_axes_of)
-        .FUNCTION(batch_norm_attrs_of);
+        .FUNCTION(reduce_mean_axes_of);
 #undef FUNCTION
 }
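
Once the module is rebuilt, the new bindings round-trip as plain tuples; a sketch assuming backend is the compiled pybind11 module and op is an operator handle taken from a graph:

# Attribute order matches the std::make_tuple calls above.
ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op)  # Conv only, else IT_ASSERT fires
momentum, eps, training = backend.batch_norm_attrs_of(op)  # BatchNorm only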