forked from jiuyuan/InfiniTensor
add: onnx ok.
This commit is contained in:
parent
3046dd5901
commit
2aadcb6e9c
|
@ -18,6 +18,7 @@
|
|||
#include "operators/reshape.h"
|
||||
#include "operators/slice.h"
|
||||
#include "operators/split.h"
|
||||
#include "operators/transpose.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -104,6 +105,7 @@ class GraphBuilderObj {
|
|||
const vector<int> &ratio);
|
||||
Operator split(Tensor input, int dim, const vector<int> &ratio);
|
||||
// transpose op
|
||||
Operator transpose(Tensor input, Tensor output, const Shape &perm);
|
||||
// TODO
|
||||
// extend op
|
||||
Operator extend(Tensor input, Tensor output, int dim, int num);
|
||||
|
@ -155,6 +157,7 @@ class GraphBuilderObj {
|
|||
Operator abs(Tensor input, Tensor output);
|
||||
Operator abs(Tensor input);
|
||||
Operator reduceMean(Tensor input, Tensor Output, int axis);
|
||||
Operator erf(Tensor input, Tensor output);
|
||||
// resize op
|
||||
// TODO
|
||||
// membound op
|
||||
|
|
|
@ -38,6 +38,7 @@ enum class OpType {
|
|||
Tanh,
|
||||
Abs,
|
||||
Resize,
|
||||
Erf,
|
||||
//
|
||||
MemBound = 300,
|
||||
};
|
||||
|
|
|
@ -28,4 +28,5 @@ DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
|
|||
DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
|
||||
DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
|
||||
DEFINE_UNARY_OBJ(Abs, OpType::Abs)
|
||||
DEFINE_UNARY_OBJ(Erf, OpType::Erf)
|
||||
}; // namespace infini
|
||||
|
|
|
@ -169,7 +169,7 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
"dilations": [1, 1],
|
||||
"pads": [0, 0, 0, 0],
|
||||
"strides": [1, 1]})
|
||||
assert len(node.input) == 2 # bias is not implemented yet
|
||||
# assert len(node.input) == 2 # bias is not implemented yet
|
||||
assert len(node.output) == 1
|
||||
assert attrs["auto_pad"] == "NOTSET"
|
||||
assert len(attrs["pads"]) == 4
|
||||
|
@ -181,7 +181,7 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
attrs["pads"][0], attrs["pads"][1],
|
||||
attrs["strides"][0], attrs["strides"][1],
|
||||
attrs["dilations"][0], attrs["dilations"][1],
|
||||
None if len(node.input) == 2 else ts[node.input[2]])
|
||||
None)
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul
|
||||
elif node.op_type == 'MatMul':
|
||||
|
@ -205,7 +205,8 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
tmpI0 = ts[node.input[0]]
|
||||
else:
|
||||
tmpI0 = gf.tensor([batch, dimA[-2], dimA[-1]], "FLOAT")
|
||||
gf.reshape(ts[node.input[0]], tmpI0, ds[node.input[0]])
|
||||
gf.reshape(ts[node.input[0]], tmpI0, [
|
||||
batch, dimA[-2], dimA[-1]])
|
||||
|
||||
if len(dimB) == 3:
|
||||
tmpI1 = ts[node.input[1]]
|
||||
|
@ -219,7 +220,7 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
gf.matmul(tmpI0, tmpI1, tmpO, False, False)
|
||||
else:
|
||||
tmpO = gf.tensor([batch, dimO[-2], dimO[-1]], "FLOAT")
|
||||
gf.matmul(tmpI0, tmpI1, tmpO, False, False, None)
|
||||
gf.matmul(tmpI0, tmpI1, tmpO, False, False)
|
||||
gf.reshape(tmpO, ts[node.output[0]], ds[node.output[0]])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose
|
||||
|
@ -493,25 +494,26 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
consts[node.output[0]] = c
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm
|
||||
# elif node.op_type == 'Gemm':
|
||||
# attrs = _parse_attribute(node.attribute, {
|
||||
# "alpha": 1.0,
|
||||
# "beta": 1.0,
|
||||
# "transA": 0,
|
||||
# "transB": 0})
|
||||
# assert len(node.input) == 2 or len(node.input) == 3
|
||||
# assert len(node.output) == 1
|
||||
# assert attrs["alpha"] == 1.0
|
||||
# assert attrs["beta"] == 1.0 or len(node.input) == 2
|
||||
# tmpI0 = g.tensor([1] + list(ds[node.input[0]]), "FLOAT")
|
||||
# tmpI1 = g.tensor([1] + list(ds[node.input[1]]), "FLOAT")
|
||||
# tmpO = g.tensor([1] + list(ds[node.output[0]]), "FLOAT")
|
||||
# g.transpose(ts[node.input[0]], tmpI0, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
|
||||
# g.transpose(ts[node.input[1]], tmpI1, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
|
||||
# g.matmul(tmpI0, tmpI1, tmpO,
|
||||
# attrs["transA"], attrs["transB"],
|
||||
# None if len(node.input) == 2 else ts[node.input[2]])
|
||||
# g.transpose(tmpO, ts[node.output[0]], -1, Perm([PermItem([0, 1]), PermItem(2)]), 0)
|
||||
elif node.op_type == 'Gemm':
|
||||
attrs = _parse_attribute(node.attribute, {
|
||||
"alpha": 1.0,
|
||||
"beta": 1.0,
|
||||
"transA": 0,
|
||||
"transB": 0})
|
||||
assert len(node.input) == 2 or len(node.input) == 3
|
||||
assert len(node.output) == 1
|
||||
assert attrs["alpha"] == 1.0
|
||||
assert attrs["beta"] == 1.0 or len(node.input) == 2
|
||||
i0 = gf.tensor([1] + list(ds[node.input[0]]), "FLOAT")
|
||||
i1 = gf.tensor([1] + list(ds[node.input[1]]), "FLOAT")
|
||||
o0 = gf.tensor([1] + list(ds[node.output[0]]), "FLOAT")
|
||||
gf.reshape(ts[node.input[0]], i0, [1] + list(ds[node.input[0]]))
|
||||
gf.reshape(ts[node.input[1]], i1, [1] + list(ds[node.input[1]]))
|
||||
gf.matmul(i0, i1, o0, attrs["transA"], attrs["transB"])
|
||||
o1 = gf.tensor(ds[node.output[0]], "FLOAT")
|
||||
a0 = gf.tensor(ds[node.output[0]], "FLOAT")
|
||||
gf.reshape(o0, o1, ds[node.output[0]])
|
||||
gf.add(o1, a0, ts[node.output[0]])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool
|
||||
# elif node.op_type == 'GlobalAveragePool':
|
||||
|
@ -542,19 +544,24 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Transpose
|
||||
elif node.op_type == 'Transpose':
|
||||
# attrs = _parse_attribute(node.attribute, {})
|
||||
# assert len(node.input) == 1
|
||||
# assert len(node.output) == 1
|
||||
# assert "perm" in attrs
|
||||
# gf.transpose(ts[node.input[0]], ts[node.output[0]], -1,
|
||||
# Perm([PermItem(x) for x in attrs["perm"]]), 0)
|
||||
attrs = _parse_attribute(node.attribute, {})
|
||||
assert len(node.input) == 1
|
||||
assert len(node.output) == 1
|
||||
assert "perm" in attrs
|
||||
gf.transpose(ts[node.input[0]], ts[node.output[0]], attrs["perm"])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unsqueeze
|
||||
elif node.op_type == 'Unsqueeze':
|
||||
assert len(node.input) == 2
|
||||
assert len(node.output) == 1
|
||||
gf.reshape(ts[node.input[0]],
|
||||
ts[node.output[0]], ts[node.input[1]])
|
||||
ts[node.output[0]], ts[node.output[0]])
|
||||
|
||||
# TODO
|
||||
elif node.op_type == 'Erf':
|
||||
assert len(node.input) == 1
|
||||
assert len(node.output) == 1
|
||||
gf.erf(ts[node.input[0]], ts[node.output[0]])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#BatchNormalization
|
||||
# elif node.op_type == "BatchNormalization":
|
||||
|
|
|
@ -7,3 +7,10 @@ class Test_ImportOnnx:
|
|||
runtime = CpuRuntimeObj.getInstance()
|
||||
graphBuilder = GraphBuilderObj(runtime)
|
||||
import_onnx(graphBuilder, '/home/mazx/git/pf-models/bert.bs1.onnx')
|
||||
|
||||
|
||||
class Test_SARDRN:
|
||||
def test_Netname(self):
|
||||
runtime = CpuRuntimeObj.getInstance()
|
||||
graphBuilder = GraphBuilderObj(runtime)
|
||||
import_onnx(graphBuilder, '/home/mazx/git/pf-models/sardrn.bs1.onnx')
|
||||
|
|
|
@ -255,6 +255,14 @@ Operator GraphBuilderObj::split(Tensor input, int dim,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphBuilderObj::transpose(Tensor input, Tensor output,
|
||||
const Shape &perm) {
|
||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o = g->addTensor(output->getDims(), output->getDType());
|
||||
auto op = g->addOpWithOutputs<TransposeObj>(i, o, perm);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphBuilderObj::extend(Tensor input, Tensor output, int dim,
|
||||
int num) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
|
@ -507,6 +515,13 @@ Operator GraphBuilderObj::reduceMean(Tensor input, Tensor output, int axis) {
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphBuilderObj::erf(Tensor input, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
auto op = g->addOpWithOutputs<ErfObj>(i0, o0);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphBuilderObj::memBound(const TensorVec &inputs,
|
||||
const TensorVec &outputs,
|
||||
const std::vector<nnet::Tensor> &nnetInputs,
|
||||
|
|
|
@ -47,6 +47,8 @@ void init_graph_builder(py::module &m) {
|
|||
py::class_<ConcatObj, std::shared_ptr<ConcatObj>, OperatorObj>(m,
|
||||
"ConcatObj");
|
||||
py::class_<SplitObj, std::shared_ptr<SplitObj>, OperatorObj>(m, "SplitObj");
|
||||
py::class_<TransposeObj, std::shared_ptr<TransposeObj>, OperatorObj>(
|
||||
m, "TransposeObj");
|
||||
py::class_<ExtendObj, std::shared_ptr<ExtendObj>, OperatorObj>(m,
|
||||
"ExtendObj");
|
||||
py::class_<MaxPoolObj, std::shared_ptr<MaxPoolObj>, OperatorObj>(
|
||||
|
@ -126,6 +128,10 @@ void init_graph_builder(py::module &m) {
|
|||
py::overload_cast<Tensor, std::optional<TensorVec>, int, int>(
|
||||
&GraphBuilderObj::split),
|
||||
policy::reference_internal)
|
||||
.def("transpose",
|
||||
py::overload_cast<Tensor, Tensor, const vector<int> &>(
|
||||
&GraphBuilderObj::transpose),
|
||||
policy::reference_internal)
|
||||
.def("extend",
|
||||
py::overload_cast<Tensor, Tensor, int, int>(
|
||||
&GraphBuilderObj::extend),
|
||||
|
@ -183,6 +189,8 @@ void init_graph_builder(py::module &m) {
|
|||
py::overload_cast<Tensor, Tensor, int>(
|
||||
&GraphBuilderObj::reduceMean),
|
||||
policy::reference_internal)
|
||||
.def("erf", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::erf),
|
||||
policy::reference_internal)
|
||||
.def("memBound",
|
||||
py::overload_cast<const TensorVec &, const TensorVec &,
|
||||
const std::vector<nnet::Tensor> &, nnet::Expr,
|
||||
|
|
|
@ -4,15 +4,6 @@ namespace infini {
|
|||
ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0,
|
||||
Tensor input1, Tensor output)
|
||||
: OperatorObj(type, {input0, input1}, {output}) {
|
||||
std::cout << "Element: " << int(type) << std::endl;
|
||||
for (auto x : input0->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
for (auto x : input1->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
|
@ -22,15 +13,6 @@ ElementWiseObj::inferShape(const TensorVec &inputs) const {
|
|||
// in the opt layer.
|
||||
std::cout << std::endl;
|
||||
const auto A = inputs[0], B = inputs[1];
|
||||
std::cout << "InferShape" << std::endl;
|
||||
for (auto x : A->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
for (auto x : B->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
if (A->getDims().size() != B->getDims().size() ||
|
||||
A->getDims() != B->getDims())
|
||||
return {};
|
||||
|
|
|
@ -5,15 +5,6 @@ namespace infini {
|
|||
ExtendObj::ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
|
||||
int num)
|
||||
: OperatorObj(OpType::Extend, {input}, {output}), dim(dim), num(num) {
|
||||
std::cout << "Extend" << std::endl;
|
||||
for (auto x : input->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
for (auto x : output->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
|
|
|
@ -9,16 +9,20 @@ TransposeObj::TransposeObj(GraphObj *graph, Tensor input, Tensor output,
|
|||
|
||||
optional<vector<Shape>>
|
||||
TransposeObj::inferShape(const TensorVec &inputs) const {
|
||||
Shape dimsIn = inputs[0]->getDims();
|
||||
const Shape &dimsIn = inputs[0]->getDims();
|
||||
Shape dimsOut;
|
||||
std::unordered_set<size_t> dimSet;
|
||||
for (size_t i = 0; i < perm.size(); ++i) {
|
||||
if (size_t(perm[i]) >= dimsIn.size() ||
|
||||
dimSet.find(perm[i]) != dimSet.end()) {
|
||||
std::cout << i << " " << perm[i] << " "
|
||||
<< int(dimSet.find(perm[i]) != dimSet.end()) << std::endl;
|
||||
return {};
|
||||
}
|
||||
dimsOut.emplace_back(dimsIn[perm[i]]);
|
||||
dimSet.emplace(perm[i]);
|
||||
}
|
||||
std::cout << "transpose Ok" << std::endl;
|
||||
return {{dimsOut}};
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue