forked from jiuyuan/InfiniTensor
feat: export conv to onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent afed749b74
commit 8a871c3773
@@ -475,7 +475,18 @@ class OnnxStub:
                 for (i, it) in enumerate(op.outputs())
             ]
             if ty == backend.OpType.Conv:
-                raise Exception("TODO")
+                ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op)
+                ctx.push_node(
+                    make_node(
+                        ty.name,
+                        inputs,
+                        outputs,
+                        name,
+                        pads=[ph, pw],
+                        strides=[sh, sw],
+                        dilations=[dh, dw],
+                    )
+                )
             elif ty == backend.OpType.Matmul:
                 ctx.push_node(make_node("MatMul", inputs, outputs, name))
             elif ty == backend.OpType.BatchNorm:
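For reference, here is a minimal standalone sketch (not part of the diff) of the kind of ONNX node the Conv branch above builds, calling onnx.helper.make_node directly. The tensor names, node name, and attribute values are hypothetical placeholders; the keyword attributes mirror the pads/strides/dilations set in the branch.

    # Sketch only: "x", "w", "y", "conv0" and the attribute values are
    # made-up placeholders standing in for the graph's real names.
    from onnx.helper import make_node

    conv_node = make_node(
        "Conv",       # ty.name for backend.OpType.Conv
        ["x", "w"],   # inputs
        ["y"],        # outputs
        "conv0",      # name
        pads=[1, 1],
        strides=[1, 1],
        dilations=[1, 1],
    )
    print(conv_node)

Printing the resulting NodeProto shows pads, strides, and dilations serialized as integer-list attributes, which is what an ONNX consumer reads back.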
@@ -1,6 +1,7 @@
 #include "core/graph_handler.h"
 #include "operators/batch_norm.h"
 #include "operators/concat.h"
+#include "operators/conv.h"
 #include "operators/gather.h"
 #include "operators/reduce_mean.h"
 #include "operators/reshape.h"
@@ -100,6 +101,20 @@ static int tensor_dtype(Tensor t) {
 static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); }
 #endif
 
+static std::tuple<int, int, int, int, int, int> conv_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Conv);
+    auto conv = dynamic_cast<const ConvObj *>(op.get());
+    return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(),
+                           conv->getDw(), conv->getSh(), conv->getSw());
+}
+
+static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::BatchNorm);
+    auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
+    return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
+                           batchnorm->getTraining());
+}
+
 static int concat_axis_of(Operator op) {
     IT_ASSERT(op->getOpType() == OpType::Concat);
     return dynamic_cast<const ConcatObj *>(op.get())->getDim();
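On the Python side, each std::tuple returned by these getters arrives as a plain tuple, so it unpacks positionally in the order given to std::make_tuple. A small usage sketch with hypothetical helper names, assuming backend is the compiled extension module and op is an operator taken from a traversed graph:

    # Hypothetical helpers; `backend` and `op` are assumed to come from the
    # compiled extension and a graph traversal, respectively.
    def conv_onnx_attrs(backend, op):
        # Unpacks in the std::make_tuple order above: ph, pw, dh, dw, sh, sw.
        ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op)
        return {"pads": [ph, pw], "strides": [sh, sw], "dilations": [dh, dw]}

    def batch_norm_onnx_attrs(backend, op):
        momentum, eps, training = backend.batch_norm_attrs_of(op)
        return {"momentum": momentum, "epsilon": eps, "training_mode": int(training)}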
@@ -121,25 +136,19 @@ static Shape reshape_shape_of(Operator op) {
     return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
 }
 
-static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
-    IT_ASSERT(op->getOpType() == OpType::BatchNorm);
-    auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
-    return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
-                           batchnorm->getTraining());
-}
-
 void export_functions(py::module &m) {
 #define FUNCTION(NAME) def(#NAME, &NAME)
     m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
 #ifdef USE_CUDA
         .FUNCTION(cuda_runtime)
 #endif
+        .FUNCTION(conv_attrs_of)
+        .FUNCTION(batch_norm_attrs_of)
         .FUNCTION(tensor_dtype)
         .FUNCTION(reshape_shape_of)
         .FUNCTION(concat_axis_of)
         .FUNCTION(gather_axis_of)
-        .FUNCTION(reduce_mean_axes_of)
-        .FUNCTION(batch_norm_attrs_of);
+        .FUNCTION(reduce_mean_axes_of);
 #undef FUNCTION
 }
 
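Each name passed to the FUNCTION macro expands to def(#NAME, &NAME), i.e. a module-level function of the same name in the pybind11 module. A quick sanity-check sketch, assuming the compiled module is importable under the backend name used by the Python code above (cuda_runtime is left out since it is only registered under USE_CUDA):

    # Names expected on the module after the registration chain above.
    EXPECTED_EXPORTS = [
        "cpu_runtime",
        "conv_attrs_of",
        "batch_norm_attrs_of",
        "tensor_dtype",
        "reshape_shape_of",
        "concat_axis_of",
        "gather_axis_of",
        "reduce_mean_axes_of",
    ]

    def check_exports(backend):
        missing = [n for n in EXPECTED_EXPORTS if not callable(getattr(backend, n, None))]
        assert not missing, f"missing FFI exports: {missing}"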