forked from jiuyuan/InfiniTensor
feat: 导出 conv 到 onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
afed749b74
commit
8a871c3773
|
@ -475,7 +475,18 @@ class OnnxStub:
|
||||||
for (i, it) in enumerate(op.outputs())
|
for (i, it) in enumerate(op.outputs())
|
||||||
]
|
]
|
||||||
if ty == backend.OpType.Conv:
|
if ty == backend.OpType.Conv:
|
||||||
raise Exception("TODO")
|
ph, pw, sh, sw, dh, dw = backend.conv_attrs_of(op)
|
||||||
|
ctx.push_node(
|
||||||
|
make_node(
|
||||||
|
ty.name,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
name,
|
||||||
|
pads=[ph, pw],
|
||||||
|
strides=[sh, sw],
|
||||||
|
dilations=[dh, dw],
|
||||||
|
)
|
||||||
|
)
|
||||||
elif ty == backend.OpType.Matmul:
|
elif ty == backend.OpType.Matmul:
|
||||||
ctx.push_node(make_node("MatMul", inputs, outputs, name))
|
ctx.push_node(make_node("MatMul", inputs, outputs, name))
|
||||||
elif ty == backend.OpType.BatchNorm:
|
elif ty == backend.OpType.BatchNorm:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#include "core/graph_handler.h"
|
#include "core/graph_handler.h"
|
||||||
#include "operators/batch_norm.h"
|
#include "operators/batch_norm.h"
|
||||||
#include "operators/concat.h"
|
#include "operators/concat.h"
|
||||||
|
#include "operators/conv.h"
|
||||||
#include "operators/gather.h"
|
#include "operators/gather.h"
|
||||||
#include "operators/reduce_mean.h"
|
#include "operators/reduce_mean.h"
|
||||||
#include "operators/reshape.h"
|
#include "operators/reshape.h"
|
||||||
|
@ -100,6 +101,20 @@ static int tensor_dtype(Tensor t) {
|
||||||
static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); }
|
static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); }
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
static std::tuple<int, int, int, int, int, int> conv_attrs_of(Operator op) {
|
||||||
|
IT_ASSERT(op->getOpType() == OpType::Conv);
|
||||||
|
auto conv = dynamic_cast<const ConvObj *>(op.get());
|
||||||
|
return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(),
|
||||||
|
conv->getDw(), conv->getSh(), conv->getSw());
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
|
||||||
|
IT_ASSERT(op->getOpType() == OpType::BatchNorm);
|
||||||
|
auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
|
||||||
|
return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
|
||||||
|
batchnorm->getTraining());
|
||||||
|
}
|
||||||
|
|
||||||
static int concat_axis_of(Operator op) {
|
static int concat_axis_of(Operator op) {
|
||||||
IT_ASSERT(op->getOpType() == OpType::Concat);
|
IT_ASSERT(op->getOpType() == OpType::Concat);
|
||||||
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
|
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
|
||||||
|
@ -121,25 +136,19 @@ static Shape reshape_shape_of(Operator op) {
|
||||||
return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
|
return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
|
|
||||||
IT_ASSERT(op->getOpType() == OpType::BatchNorm);
|
|
||||||
auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
|
|
||||||
return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
|
|
||||||
batchnorm->getTraining());
|
|
||||||
}
|
|
||||||
|
|
||||||
void export_functions(py::module &m) {
|
void export_functions(py::module &m) {
|
||||||
#define FUNCTION(NAME) def(#NAME, &NAME)
|
#define FUNCTION(NAME) def(#NAME, &NAME)
|
||||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||||
#ifdef USE_CUDA
|
#ifdef USE_CUDA
|
||||||
.FUNCTION(cuda_runtime)
|
.FUNCTION(cuda_runtime)
|
||||||
#endif
|
#endif
|
||||||
|
.FUNCTION(conv_attrs_of)
|
||||||
|
.FUNCTION(batch_norm_attrs_of)
|
||||||
.FUNCTION(tensor_dtype)
|
.FUNCTION(tensor_dtype)
|
||||||
.FUNCTION(reshape_shape_of)
|
.FUNCTION(reshape_shape_of)
|
||||||
.FUNCTION(concat_axis_of)
|
.FUNCTION(concat_axis_of)
|
||||||
.FUNCTION(gather_axis_of)
|
.FUNCTION(gather_axis_of)
|
||||||
.FUNCTION(reduce_mean_axes_of)
|
.FUNCTION(reduce_mean_axes_of);
|
||||||
.FUNCTION(batch_norm_attrs_of);
|
|
||||||
#undef FUNCTION
|
#undef FUNCTION
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue