add: onnx ok.

mazx 2022-10-30 23:08:16 +08:00
parent 3046dd5901
commit 2aadcb6e9c
10 changed files with 77 additions and 58 deletions


@@ -18,6 +18,7 @@
#include "operators/reshape.h"
#include "operators/slice.h"
#include "operators/split.h"
#include "operators/transpose.h"
#include "operators/unary.h"
namespace infini {
@@ -104,6 +105,7 @@ class GraphBuilderObj {
const vector<int> &ratio);
Operator split(Tensor input, int dim, const vector<int> &ratio);
// transpose op
+Operator transpose(Tensor input, Tensor output, const Shape &perm);
// TODO
// extend op
Operator extend(Tensor input, Tensor output, int dim, int num);
@@ -155,6 +157,7 @@ class GraphBuilderObj {
Operator abs(Tensor input, Tensor output);
Operator abs(Tensor input);
Operator reduceMean(Tensor input, Tensor Output, int axis);
+Operator erf(Tensor input, Tensor output);
// resize op
// TODO
// membound op


@@ -38,6 +38,7 @@ enum class OpType {
Tanh,
Abs,
Resize,
+Erf,
//
MemBound = 300,
};


@@ -28,4 +28,5 @@ DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
DEFINE_UNARY_OBJ(Abs, OpType::Abs)
+DEFINE_UNARY_OBJ(Erf, OpType::Erf)
}; // namespace infini


@@ -169,7 +169,7 @@ def import_onnx(gf: GraphBuilder, net: str):
"dilations": [1, 1],
"pads": [0, 0, 0, 0],
"strides": [1, 1]})
-assert len(node.input) == 2 # bias is not implemented yet
+# assert len(node.input) == 2 # bias is not implemented yet
assert len(node.output) == 1
assert attrs["auto_pad"] == "NOTSET"
assert len(attrs["pads"]) == 4
@@ -181,7 +181,7 @@ def import_onnx(gf: GraphBuilder, net: str):
attrs["pads"][0], attrs["pads"][1],
attrs["strides"][0], attrs["strides"][1],
attrs["dilations"][0], attrs["dilations"][1],
-None if len(node.input) == 2 else ts[node.input[2]])
+None)
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul
elif node.op_type == 'MatMul':
@@ -205,7 +205,8 @@ def import_onnx(gf: GraphBuilder, net: str):
tmpI0 = ts[node.input[0]]
else:
tmpI0 = gf.tensor([batch, dimA[-2], dimA[-1]], "FLOAT")
-gf.reshape(ts[node.input[0]], tmpI0, ds[node.input[0]])
+gf.reshape(ts[node.input[0]], tmpI0, [
+batch, dimA[-2], dimA[-1]])
if len(dimB) == 3:
tmpI1 = ts[node.input[1]]
@@ -219,7 +220,7 @@ def import_onnx(gf: GraphBuilder, net: str):
gf.matmul(tmpI0, tmpI1, tmpO, False, False)
else:
tmpO = gf.tensor([batch, dimO[-2], dimO[-1]], "FLOAT")
-gf.matmul(tmpI0, tmpI1, tmpO, False, False, None)
+gf.matmul(tmpI0, tmpI1, tmpO, False, False)
gf.reshape(tmpO, ts[node.output[0]], ds[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose
@@ -493,25 +494,26 @@ def import_onnx(gf: GraphBuilder, net: str):
consts[node.output[0]] = c
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm
-# elif node.op_type == 'Gemm':
-# attrs = _parse_attribute(node.attribute, {
-# "alpha": 1.0,
-# "beta": 1.0,
-# "transA": 0,
-# "transB": 0})
-# assert len(node.input) == 2 or len(node.input) == 3
-# assert len(node.output) == 1
-# assert attrs["alpha"] == 1.0
-# assert attrs["beta"] == 1.0 or len(node.input) == 2
-# tmpI0 = g.tensor([1] + list(ds[node.input[0]]), "FLOAT")
-# tmpI1 = g.tensor([1] + list(ds[node.input[1]]), "FLOAT")
-# tmpO = g.tensor([1] + list(ds[node.output[0]]), "FLOAT")
-# g.transpose(ts[node.input[0]], tmpI0, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
-# g.transpose(ts[node.input[1]], tmpI1, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
-# g.matmul(tmpI0, tmpI1, tmpO,
-# attrs["transA"], attrs["transB"],
-# None if len(node.input) == 2 else ts[node.input[2]])
-# g.transpose(tmpO, ts[node.output[0]], -1, Perm([PermItem([0, 1]), PermItem(2)]), 0)
+elif node.op_type == 'Gemm':
+attrs = _parse_attribute(node.attribute, {
+"alpha": 1.0,
+"beta": 1.0,
+"transA": 0,
+"transB": 0})
+assert len(node.input) == 2 or len(node.input) == 3
+assert len(node.output) == 1
+assert attrs["alpha"] == 1.0
+assert attrs["beta"] == 1.0 or len(node.input) == 2
+i0 = gf.tensor([1] + list(ds[node.input[0]]), "FLOAT")
+i1 = gf.tensor([1] + list(ds[node.input[1]]), "FLOAT")
+o0 = gf.tensor([1] + list(ds[node.output[0]]), "FLOAT")
+gf.reshape(ts[node.input[0]], i0, [1] + list(ds[node.input[0]]))
+gf.reshape(ts[node.input[1]], i1, [1] + list(ds[node.input[1]]))
+gf.matmul(i0, i1, o0, attrs["transA"], attrs["transB"])
+o1 = gf.tensor(ds[node.output[0]], "FLOAT")
+a0 = gf.tensor(ds[node.output[0]], "FLOAT")
+gf.reshape(o0, o1, ds[node.output[0]])
+gf.add(o1, a0, ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool
# elif node.op_type == 'GlobalAveragePool':
@@ -542,19 +544,24 @@ def import_onnx(gf: GraphBuilder, net: str):
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Transpose
elif node.op_type == 'Transpose':
-# attrs = _parse_attribute(node.attribute, {})
-# assert len(node.input) == 1
-# assert len(node.output) == 1
-# assert "perm" in attrs
-# gf.transpose(ts[node.input[0]], ts[node.output[0]], -1,
-# Perm([PermItem(x) for x in attrs["perm"]]), 0)
+attrs = _parse_attribute(node.attribute, {})
+assert len(node.input) == 1
+assert len(node.output) == 1
+assert "perm" in attrs
+gf.transpose(ts[node.input[0]], ts[node.output[0]], attrs["perm"])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unsqueeze
elif node.op_type == 'Unsqueeze':
assert len(node.input) == 2
assert len(node.output) == 1
gf.reshape(ts[node.input[0]],
-ts[node.output[0]], ts[node.input[1]])
+ts[node.output[0]], ts[node.output[0]])
# TODO
+elif node.op_type == 'Erf':
+assert len(node.input) == 1
+assert len(node.output) == 1
+gf.erf(ts[node.input[0]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#BatchNormalization
# elif node.op_type == "BatchNormalization":


@@ -7,3 +7,10 @@ class Test_ImportOnnx:
runtime = CpuRuntimeObj.getInstance()
graphBuilder = GraphBuilderObj(runtime)
import_onnx(graphBuilder, '/home/mazx/git/pf-models/bert.bs1.onnx')
+class Test_SARDRN:
+def test_Netname(self):
+runtime = CpuRuntimeObj.getInstance()
+graphBuilder = GraphBuilderObj(runtime)
+import_onnx(graphBuilder, '/home/mazx/git/pf-models/sardrn.bs1.onnx')


@@ -255,6 +255,14 @@ Operator GraphBuilderObj::split(Tensor input, int dim,
return op;
}
+Operator GraphBuilderObj::transpose(Tensor input, Tensor output,
+const Shape &perm) {
+Tensor i = g->addTensor(input->getDims(), input->getDType());
+Tensor o = g->addTensor(output->getDims(), output->getDType());
+auto op = g->addOpWithOutputs<TransposeObj>(i, o, perm);
+return op;
+}
Operator GraphBuilderObj::extend(Tensor input, Tensor output, int dim,
int num) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@@ -507,6 +515,13 @@ Operator GraphBuilderObj::reduceMean(Tensor input, Tensor output, int axis) {
return op;
}
+Operator GraphBuilderObj::erf(Tensor input, Tensor output) {
+Tensor i0 = g->addTensor(input->getDims(), input->getDType());
+Tensor o0 = g->addTensor(output->getDims(), output->getDType());
+auto op = g->addOpWithOutputs<ErfObj>(i0, o0);
+return op;
+}
Operator GraphBuilderObj::memBound(const TensorVec &inputs,
const TensorVec &outputs,
const std::vector<nnet::Tensor> &nnetInputs,


@@ -47,6 +47,8 @@ void init_graph_builder(py::module &m) {
py::class_<ConcatObj, std::shared_ptr<ConcatObj>, OperatorObj>(m,
"ConcatObj");
py::class_<SplitObj, std::shared_ptr<SplitObj>, OperatorObj>(m, "SplitObj");
py::class_<TransposeObj, std::shared_ptr<TransposeObj>, OperatorObj>(
m, "TransposeObj");
py::class_<ExtendObj, std::shared_ptr<ExtendObj>, OperatorObj>(m,
"ExtendObj");
py::class_<MaxPoolObj, std::shared_ptr<MaxPoolObj>, OperatorObj>(
@@ -126,6 +128,10 @@ void init_graph_builder(py::module &m) {
py::overload_cast<Tensor, std::optional<TensorVec>, int, int>(
&GraphBuilderObj::split),
policy::reference_internal)
.def("transpose",
py::overload_cast<Tensor, Tensor, const vector<int> &>(
&GraphBuilderObj::transpose),
policy::reference_internal)
.def("extend",
py::overload_cast<Tensor, Tensor, int, int>(
&GraphBuilderObj::extend),
@@ -183,6 +189,8 @@ void init_graph_builder(py::module &m) {
py::overload_cast<Tensor, Tensor, int>(
&GraphBuilderObj::reduceMean),
policy::reference_internal)
.def("erf", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::erf),
policy::reference_internal)
.def("memBound",
py::overload_cast<const TensorVec &, const TensorVec &,
const std::vector<nnet::Tensor> &, nnet::Expr,


@@ -4,15 +4,6 @@ namespace infini {
ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0,
Tensor input1, Tensor output)
: OperatorObj(type, {input0, input1}, {output}) {
std::cout << "Element: " << int(type) << std::endl;
for (auto x : input0->getDims()) {
std::cout << x << " ";
}
std::cout << std::endl;
for (auto x : input1->getDims()) {
std::cout << x << " ";
}
std::cout << std::endl;
IT_ASSERT(checkValid(graph));
}
@@ -22,15 +13,6 @@ ElementWiseObj::inferShape(const TensorVec &inputs) const {
// in the opt layer.
std::cout << std::endl;
const auto A = inputs[0], B = inputs[1];
std::cout << "InferShape" << std::endl;
for (auto x : A->getDims()) {
std::cout << x << " ";
}
std::cout << std::endl;
for (auto x : B->getDims()) {
std::cout << x << " ";
}
std::cout << std::endl;
if (A->getDims().size() != B->getDims().size() ||
A->getDims() != B->getDims())
return {};


@@ -5,15 +5,6 @@ namespace infini {
ExtendObj::ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
int num)
: OperatorObj(OpType::Extend, {input}, {output}), dim(dim), num(num) {
std::cout << "Extend" << std::endl;
for (auto x : input->getDims()) {
std::cout << x << " ";
}
std::cout << std::endl;
for (auto x : output->getDims()) {
std::cout << x << " ";
}
IT_ASSERT(checkValid(graph));
}


@@ -9,16 +9,20 @@ TransposeObj::TransposeObj(GraphObj *graph, Tensor input, Tensor output,
optional<vector<Shape>>
TransposeObj::inferShape(const TensorVec &inputs) const {
-Shape dimsIn = inputs[0]->getDims();
+const Shape &dimsIn = inputs[0]->getDims();
Shape dimsOut;
std::unordered_set<size_t> dimSet;
for (size_t i = 0; i < perm.size(); ++i) {
if (size_t(perm[i]) >= dimsIn.size() ||
dimSet.find(perm[i]) != dimSet.end()) {
std::cout << i << " " << perm[i] << " "
<< int(dimSet.find(perm[i]) != dimSet.end()) << std::endl;
return {};
}
dimsOut.emplace_back(dimsIn[perm[i]]);
dimSet.emplace(perm[i]);
}
std::cout << "transpose Ok" << std::endl;
return {{dimsOut}};
}
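
The inferShape above accepts perm only if every entry is an in-range, previously unused axis, and then maps output dimension i to input dimension perm[i]. The same logic rendered in Python for clarity, slightly stricter in that it also requires perm to be a full permutation (illustrative, not part of the repository):

def infer_transpose_shape(dims_in, perm):
    # Reject out-of-range or duplicate axes, as the C++ loop does;
    # requiring a full permutation also catches a too-short perm.
    if sorted(perm) != list(range(len(dims_in))):
        return None
    # Output dimension i takes the input dimension selected by perm[i].
    return [dims_in[p] for p in perm]

assert infer_transpose_shape([1, 2, 3], [2, 0, 1]) == [3, 1, 2]
assert infer_transpose_shape([1, 2, 3], [0, 0, 1]) is None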