forked from jiuyuan/InfiniTensor
wip: onnx 导出 (#65)
| Notice | Work in progress |-|- > based on #63 ## Progress 1. [x] 对节点拓扑排序 2. [x] 遍历节点,命名并导出其输出张量(`make_tensor_value_info`) 3. [x] 识别图的输入张量,命名并导出(`make_tensor_value_info`) 4. [x] 根据节点类型,识别权重及属性,导出节点(`make_node`) 5. [x] `make_graph` -> `check_graph` -> `make_model` -> `check_model`
This commit is contained in:
commit
dd5d091dbc
|
@ -8,13 +8,10 @@ class GraphObj : public Object {
|
|||
protected:
|
||||
Runtime runtime;
|
||||
TensorVec tensors;
|
||||
// TODO: whether to record input and output tensors
|
||||
// TensorVec inputs;
|
||||
// TensorVec outputs;
|
||||
OpVec ops;
|
||||
|
||||
public:
|
||||
GraphObj(Runtime runtime) : runtime(runtime){};
|
||||
explicit GraphObj(Runtime runtime) : runtime(runtime), sorted(false){};
|
||||
GraphObj(Runtime runtime, OpVec ops_in);
|
||||
string toString() const override;
|
||||
Runtime getRuntime() const { return runtime; }
|
||||
|
@ -23,10 +20,23 @@ class GraphObj : public Object {
|
|||
Tensor addTensor(const Tensor &tensor);
|
||||
TensorVec addTensor(const TensorVec &tensors);
|
||||
Tensor cloneTensor(const Tensor &tensor) {
|
||||
auto ret = addTensor(tensor->clone(runtime));
|
||||
return ret;
|
||||
return addTensor(tensor->clone(runtime));
|
||||
}
|
||||
|
||||
const TensorVec &getTensors() const { return tensors; }
|
||||
const OpVec &getOperators() const { return ops; }
|
||||
OpVec getComputeOps() const;
|
||||
|
||||
/**
|
||||
* Sort the nodes in topological order.
|
||||
* It returns true if the sorting is successful.
|
||||
* Otherwise false is returned, means that there are rings in the graph,
|
||||
* so the topological sorting fails.
|
||||
*/
|
||||
bool topo_sort();
|
||||
|
||||
void dataMalloc();
|
||||
|
||||
/**
|
||||
* @brief Add an operator and create its outputs. Output tensor arguments
|
||||
* should be empty Refs (e.g., nullptr).
|
||||
|
@ -47,25 +57,27 @@ class GraphObj : public Object {
|
|||
return op;
|
||||
}
|
||||
|
||||
const TensorVec &getTensors() const { return tensors; }
|
||||
const TensorVec getInputs() const {
|
||||
/**
|
||||
* @brief Gets input tensors of this graph.
|
||||
*/
|
||||
inline TensorVec getInputs() const {
|
||||
TensorVec ret;
|
||||
for (auto t : tensors)
|
||||
for (const auto &t : tensors)
|
||||
if (!t->getOutputOf())
|
||||
ret.emplace_back(t);
|
||||
return ret;
|
||||
}
|
||||
const TensorVec getOutputs() const {
|
||||
|
||||
/**
|
||||
* @brief Gets output tensors of this graph.
|
||||
*/
|
||||
inline TensorVec getOutputs() const {
|
||||
TensorVec ret;
|
||||
for (auto t : tensors)
|
||||
for (const auto &t : tensors)
|
||||
if (t->getInputOf().empty())
|
||||
ret.emplace_back(t);
|
||||
return ret;
|
||||
}
|
||||
const OpVec &getOperators() const { return ops; }
|
||||
OpVec getComputeOps() const;
|
||||
|
||||
void dataMalloc();
|
||||
|
||||
private:
|
||||
/**
|
||||
|
@ -73,9 +85,10 @@ class GraphObj : public Object {
|
|||
*/
|
||||
void addOperatorAndConnect(const Operator &op);
|
||||
|
||||
// TODO: move to another class
|
||||
// bool exportOnnx(const char *path);
|
||||
// bool importOnnx(const char *net);
|
||||
/**
|
||||
* @brief If the nodes is sorted in topological order.
|
||||
*/
|
||||
bool sorted;
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -38,12 +38,14 @@ class GraphHandlerObj {
|
|||
|
||||
Tensor tensor(Shape dims, int dtype);
|
||||
|
||||
//------ operators
|
||||
|
||||
inline OpVec operators() { return g->getOperators(); }
|
||||
|
||||
Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
|
||||
int sh, int sw, int dh, int dw);
|
||||
|
||||
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
||||
Tensor bias, ActType act);
|
||||
|
||||
Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
|
||||
Tensor scale, Tensor bias, float momentum, float eps,
|
||||
bool training);
|
||||
|
@ -77,6 +79,10 @@ class GraphHandlerObj {
|
|||
Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
|
||||
const optional<vector<int>> &axes);
|
||||
|
||||
//------ modifiers
|
||||
|
||||
inline bool topo_sort() { return g->topo_sort(); }
|
||||
|
||||
//------ runtime
|
||||
|
||||
inline void data_malloc() { g->dataMalloc(); }
|
||||
|
|
|
@ -30,6 +30,7 @@ class ReduceMeanObj : public OperatorObj {
|
|||
int numOutputs() const override { return 1; }
|
||||
|
||||
bool isReduced(int idx) const;
|
||||
const set<int> &getAxes() const { return axes; }
|
||||
bool getKeepDims() const { return keepDims; }
|
||||
|
||||
private:
|
||||
|
|
|
@ -28,6 +28,8 @@ class ReshapeObj : public OperatorObj {
|
|||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
inline Shape getShape() const { return dims; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
|
|
|
@ -1,17 +1,39 @@
|
|||
import onnx, backend
|
||||
import backend
|
||||
from onnx import (
|
||||
ModelProto,
|
||||
TensorProto,
|
||||
NodeProto,
|
||||
AttributeProto,
|
||||
TensorShapeProto,
|
||||
ValueInfoProto,
|
||||
)
|
||||
from onnx.helper import (
|
||||
make_node,
|
||||
make_tensor_value_info,
|
||||
make_tensor,
|
||||
make_graph,
|
||||
make_model,
|
||||
)
|
||||
from onnx.checker import (
|
||||
check_graph,
|
||||
check_model,
|
||||
check_node,
|
||||
check_value_info,
|
||||
check_tensor,
|
||||
)
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from typing import Dict, List, Any
|
||||
from typing import Dict, List, Any, Tuple, Sequence
|
||||
from functools import reduce
|
||||
|
||||
runtime = backend.cpu_runtime()
|
||||
|
||||
|
||||
def from_onnx(model: onnx.ModelProto):
|
||||
def from_onnx(model: ModelProto) -> backend.GraphHandler:
|
||||
model = infer_shapes(model)
|
||||
handler = backend.GraphHandlerObj(runtime)
|
||||
handler = backend.GraphHandler(runtime)
|
||||
|
||||
tensors: Dict[str, backend.TensorObj] = dict()
|
||||
data: Dict[str, onnx.TensorProto] = dict()
|
||||
tensors: Dict[str, backend.Tensor] = dict()
|
||||
data: Dict[str, TensorProto] = dict()
|
||||
|
||||
for input in model.graph.input:
|
||||
dims = _take_shape_dim(input.type.tensor_type.shape)
|
||||
|
@ -303,7 +325,164 @@ def from_onnx(model: onnx.ModelProto):
|
|||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
|
||||
def parse_onnx(model: onnx.ModelProto):
|
||||
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||
class Context:
|
||||
# saves object names, including tensors and operators
|
||||
names: Dict[Any, str] = dict()
|
||||
# counts the occurrence times of each operator for naming
|
||||
count_op: Dict[backend.OpType, int] = dict()
|
||||
# counts input and output tensors for naming
|
||||
count_in, count_out = 0, 0
|
||||
# saves nodes (operators)
|
||||
nodes: List[NodeProto] = []
|
||||
# saves global input tensors
|
||||
inputs: List[ValueInfoProto] = []
|
||||
# saves global output tensors
|
||||
outputs: List[ValueInfoProto] = []
|
||||
# saves global input tensors
|
||||
initializers: List[TensorProto] = []
|
||||
|
||||
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
|
||||
ty = op.op_type()
|
||||
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
|
||||
self.names[op] = name
|
||||
self.count_op[ty] += 1
|
||||
return ty, name
|
||||
|
||||
def push_output(self, name: str, tensor: backend.Tensor) -> str:
|
||||
self.names[tensor] = name
|
||||
return name
|
||||
|
||||
def push_input(self, tensor: backend.Tensor) -> str:
|
||||
name = self.names.get(tensor)
|
||||
# means that this input is a global input
|
||||
if name is None:
|
||||
self.count_in += 1
|
||||
name = "input{}".format(self.count_in)
|
||||
self.names[tensor] = name
|
||||
shape = tensor.shape()
|
||||
dtype = backend.tensor_dtype(tensor)
|
||||
value_info = make_tensor_value_info(name, dtype, shape)
|
||||
check_value_info(value_info)
|
||||
self.inputs.append(value_info)
|
||||
|
||||
return name
|
||||
|
||||
def push_data_input(
|
||||
self,
|
||||
node_name: str,
|
||||
attr_name: str,
|
||||
elem_type: int,
|
||||
shape: Sequence[int],
|
||||
vals: Any,
|
||||
) -> str:
|
||||
name = "{}_{}".format(node_name, attr_name)
|
||||
value_info = make_tensor_value_info(name, elem_type, shape)
|
||||
tensor = make_tensor(name, elem_type, shape, vals)
|
||||
check_value_info(value_info)
|
||||
check_tensor(tensor)
|
||||
self.inputs.append(value_info)
|
||||
self.initializers.append(tensor)
|
||||
return name
|
||||
|
||||
def push_node(self, node: NodeProto) -> None:
|
||||
check_node(node)
|
||||
self.nodes.append(node)
|
||||
|
||||
def build(self, name: str) -> ModelProto:
|
||||
print()
|
||||
print(ctx.names)
|
||||
print()
|
||||
print(ctx.inputs)
|
||||
print()
|
||||
print(ctx.outputs)
|
||||
print()
|
||||
print(ctx.nodes)
|
||||
|
||||
graph = make_graph(
|
||||
self.nodes, name, self.inputs, self.outputs, self.initializers
|
||||
)
|
||||
check_graph(graph)
|
||||
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
|
||||
return model
|
||||
|
||||
# 拓扑排序
|
||||
if not graph.topo_sort():
|
||||
raise Exception("Sorting fails")
|
||||
|
||||
ops = graph.operators() # 图中所有算子(节点)
|
||||
|
||||
ctx = Context()
|
||||
|
||||
for op in ops:
|
||||
ty, name = ctx.name_op(op)
|
||||
inputs = [ctx.push_input(it) for it in op.inputs()]
|
||||
outputs = [
|
||||
ctx.push_output("{}_{}".format(name, i), it)
|
||||
for (i, it) in enumerate(op.outputs())
|
||||
]
|
||||
if ty == backend.OpType.Matmul:
|
||||
ctx.push_node(make_node("MatMul", inputs, outputs, name))
|
||||
elif ty == backend.OpType.BatchNorm:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.MaxPool:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.AvgPool:
|
||||
raise Exception("TODO")
|
||||
elif ty in [
|
||||
backend.OpType.Add,
|
||||
backend.OpType.Sub,
|
||||
backend.OpType.Mul,
|
||||
backend.OpType.Div,
|
||||
backend.OpType.Pow,
|
||||
backend.OpType.Relu,
|
||||
backend.OpType.Sigmoid,
|
||||
backend.OpType.Tanh,
|
||||
backend.OpType.Softmax,
|
||||
backend.OpType.Abs,
|
||||
backend.OpType.Identity,
|
||||
]:
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpType.Flatten:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Reshape:
|
||||
shape = backend.reshape_shape_of(op)
|
||||
inputs.append(
|
||||
ctx.push_data_input(
|
||||
name,
|
||||
"shape",
|
||||
TensorProto.INT32,
|
||||
[len(shape)],
|
||||
shape,
|
||||
)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpType.Concat:
|
||||
axis = backend.concat_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
elif ty == backend.OpType.Gather:
|
||||
axis = backend.gather_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
elif ty == backend.OpType.ReduceMean:
|
||||
axes = backend.reduce_mean_axes_of(op)
|
||||
inputs.append(
|
||||
ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
|
||||
elif ty == backend.OpType.Slice:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Pad:
|
||||
raise Exception("TODO")
|
||||
else:
|
||||
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||
|
||||
return ctx.build(name)
|
||||
|
||||
|
||||
def parse_onnx(model: ModelProto):
|
||||
print()
|
||||
|
||||
for field in [
|
||||
|
@ -339,34 +518,32 @@ def parse_onnx(model: onnx.ModelProto):
|
|||
print(" {}".format(node.name))
|
||||
|
||||
|
||||
def _parse_attribute(
|
||||
node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
|
||||
) -> Dict[str, Any]:
|
||||
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
|
||||
for attr in node.attribute:
|
||||
if attr.name in attrs:
|
||||
if attr.type == onnx.AttributeProto.INT:
|
||||
if attr.type == AttributeProto.INT:
|
||||
attrs[attr.name] = attr.i
|
||||
elif attr.type == onnx.AttributeProto.INTS:
|
||||
elif attr.type == AttributeProto.INTS:
|
||||
attrs[attr.name] = attr.ints
|
||||
elif attr.type == onnx.AttributeProto.FLOAT:
|
||||
elif attr.type == AttributeProto.FLOAT:
|
||||
attrs[attr.name] = attr.f
|
||||
elif attr.type == onnx.AttributeProto.STRING:
|
||||
elif attr.type == AttributeProto.STRING:
|
||||
attrs[attr.name] = attr.s
|
||||
elif attr.type == onnx.AttributeProto.TENSOR:
|
||||
elif attr.type == AttributeProto.TENSOR:
|
||||
attrs[attr.name] = attr.t
|
||||
else:
|
||||
assert False, "Unsupported Attribute Type: {}".format(attr.type)
|
||||
return attrs
|
||||
|
||||
|
||||
def _parse_data(tensor: onnx.TensorProto) -> List[int]:
|
||||
if tensor.data_type == onnx.TensorProto.INT32:
|
||||
def _parse_data(tensor: TensorProto) -> List[int]:
|
||||
if tensor.data_type == TensorProto.INT32:
|
||||
return [int(i) for i in tensor.int32_data]
|
||||
elif tensor.data_type == onnx.TensorProto.INT64:
|
||||
elif tensor.data_type == TensorProto.INT64:
|
||||
return [int(i) for i in tensor.int64_data]
|
||||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
||||
|
||||
def _take_shape_dim(shape: onnx.TensorShapeProto) -> List[int]:
|
||||
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
|
||||
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
|
||||
|
|
|
@ -8,7 +8,7 @@ from onnx.helper import (
|
|||
make_tensor_value_info,
|
||||
)
|
||||
from onnx.checker import check_model
|
||||
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
|
||||
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx
|
||||
|
||||
|
||||
def make_and_import_model(graph: onnx.GraphProto):
|
||||
|
@ -293,11 +293,20 @@ class TestStringMethods(unittest.TestCase):
|
|||
parse_onnx(model)
|
||||
|
||||
def test_frontend(self):
|
||||
handler = backend.GraphHandlerObj(runtime)
|
||||
i = handler.tensor([1, 2, 3], 12)
|
||||
w = handler.tensor([1, 3, 4], 12)
|
||||
o = handler.tensor([1, 2, 4], 12)
|
||||
handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
|
||||
handler = backend.GraphHandler(runtime)
|
||||
a = handler.tensor([1, 2, 3], 12)
|
||||
b = handler.tensor([1, 2, 3], 12)
|
||||
c = handler.tensor([1, 2, 3], 12)
|
||||
d = handler.tensor([1, 2, 3], 12)
|
||||
e = handler.tensor([1, 2, 3], 12)
|
||||
|
||||
x = handler.add(
|
||||
handler.add(handler.add(handler.add(a, b, None), c, None), d, None), e, None
|
||||
)
|
||||
y = handler.tensor([3, 2, 1], 12)
|
||||
handler.reshape(x, y, [3, 2, 1])
|
||||
|
||||
to_onnx(handler, "test_frontend")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
#include "core/graph.h"
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
|
||||
namespace infini {
|
||||
|
||||
GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
|
||||
GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
|
||||
: runtime(runtime), sorted(false) {
|
||||
map<UidBaseType, Tensor> tensorPool;
|
||||
// Clone tensors
|
||||
for (const auto &op : ops_in) {
|
||||
|
@ -28,6 +30,7 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
|
|||
}
|
||||
|
||||
void GraphObj::addOperatorAndConnect(const Operator &op) {
|
||||
sorted = false;
|
||||
ops.push_back(op);
|
||||
for (auto &input : op->getInputs()) {
|
||||
input->addInputOf(op);
|
||||
|
@ -66,6 +69,53 @@ string GraphObj::toString() const {
|
|||
return oss.str();
|
||||
}
|
||||
|
||||
bool GraphObj::topo_sort() {
|
||||
if (this->sorted)
|
||||
return true;
|
||||
|
||||
// std::unordered_set<Tensor> inputs;
|
||||
std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
|
||||
std::vector<Operator> sorted;
|
||||
|
||||
while (!waiting.empty()) {
|
||||
// Any node is move to sorted in this loop.
|
||||
auto modified = false;
|
||||
// Find head nodes.
|
||||
for (auto it = waiting.begin(); it != waiting.end();) {
|
||||
const auto &this_inputs = (*it)->getInputs();
|
||||
// If none of the input tensors is in waiting list,
|
||||
// this node is a head node.
|
||||
const auto is_head = std::all_of(
|
||||
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
|
||||
auto src = input->getOutputOf();
|
||||
return src // If the source node is in the waiting list,
|
||||
// means that this node is not the head node.
|
||||
? waiting.find(src) == waiting.end()
|
||||
// This tensor has no source node,
|
||||
// it must be a input tensor.
|
||||
: (/*inputs.insert(input),*/ true);
|
||||
});
|
||||
// Moves head node to sorted.
|
||||
if (is_head) {
|
||||
modified = true;
|
||||
sorted.emplace_back(std::move(*it));
|
||||
it = waiting.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
// Waiting list never modifies during a pass,
|
||||
// sorting fails.
|
||||
if (!modified) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Done.
|
||||
this->ops = std::move(sorted);
|
||||
return this->sorted = true;
|
||||
}
|
||||
|
||||
void GraphObj::dataMalloc() {
|
||||
for (auto &tensor : tensors) {
|
||||
tensor->dataMalloc();
|
||||
|
@ -73,15 +123,12 @@ void GraphObj::dataMalloc() {
|
|||
}
|
||||
|
||||
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
|
||||
Tensor tensor = make_ref<TensorObj>(dim, dtype, runtime);
|
||||
tensors.emplace_back(tensor);
|
||||
return tensor;
|
||||
return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime));
|
||||
}
|
||||
|
||||
Tensor GraphObj::addTensor(const Tensor &tensor) {
|
||||
IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
|
||||
tensors.emplace_back(tensor);
|
||||
return tensor;
|
||||
return tensors.emplace_back(tensor);
|
||||
}
|
||||
|
||||
TensorVec GraphObj::addTensor(const TensorVec &tensors) {
|
||||
|
@ -98,4 +145,4 @@ OpVec GraphObj::getComputeOps() const {
|
|||
return opList;
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
} // namespace infini
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "operators/concat.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reshape.h"
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
@ -21,88 +25,151 @@ void register_operator_timer(py::module &m) {
|
|||
#endif
|
||||
}
|
||||
|
||||
void export_values(py::module &m) {
|
||||
#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
|
||||
|
||||
py::enum_<ActType>(m, "ActType")
|
||||
.value("Linear", ActType::None) // `None` is Python keyword
|
||||
.VALUE(ActType, Relu)
|
||||
.VALUE(ActType, Sigmoid)
|
||||
.VALUE(ActType, Tanh)
|
||||
.export_values();
|
||||
|
||||
py::enum_<OpType>(m, "OpType")
|
||||
.VALUE(OpType, Unknown)
|
||||
.VALUE(OpType, Conv)
|
||||
.VALUE(OpType, Matmul)
|
||||
.VALUE(OpType, ConvTrans)
|
||||
.VALUE(OpType, G2BMM)
|
||||
.VALUE(OpType, GBMM)
|
||||
.VALUE(OpType, Pad)
|
||||
.VALUE(OpType, Slice)
|
||||
.VALUE(OpType, Concat)
|
||||
.VALUE(OpType, Split)
|
||||
.VALUE(OpType, Transpose)
|
||||
.VALUE(OpType, Extend)
|
||||
.VALUE(OpType, MaxPool)
|
||||
.VALUE(OpType, AvgPool)
|
||||
.VALUE(OpType, Add)
|
||||
.VALUE(OpType, Sub)
|
||||
.VALUE(OpType, Mul)
|
||||
.VALUE(OpType, Div)
|
||||
.VALUE(OpType, Pow)
|
||||
.VALUE(OpType, Gather)
|
||||
.VALUE(OpType, ReduceMean)
|
||||
.VALUE(OpType, Reshape)
|
||||
.VALUE(OpType, Flatten)
|
||||
.VALUE(OpType, Identity)
|
||||
.VALUE(OpType, BatchNorm)
|
||||
.VALUE(OpType, Softmax)
|
||||
.VALUE(OpType, Activation)
|
||||
.VALUE(OpType, Relu)
|
||||
.VALUE(OpType, Sigmoid)
|
||||
.VALUE(OpType, Tanh)
|
||||
.VALUE(OpType, Abs)
|
||||
.VALUE(OpType, Resize)
|
||||
.VALUE(OpType, MemBound)
|
||||
.export_values();
|
||||
|
||||
#undef VALUE
|
||||
}
|
||||
|
||||
static int tensor_dtype(Tensor t) {
|
||||
if (t->getDType() == DataType::Float32)
|
||||
return OnnxDType::FLOAT;
|
||||
if (t->getDType() == DataType::UInt32)
|
||||
return OnnxDType::UINT32;
|
||||
if (t->getDType() == DataType::UInt8)
|
||||
return OnnxDType::UINT8;
|
||||
if (t->getDType() == DataType::Int8)
|
||||
return OnnxDType::INT8;
|
||||
if (t->getDType() == DataType::UInt16)
|
||||
return OnnxDType::UINT16;
|
||||
if (t->getDType() == DataType::Int16)
|
||||
return OnnxDType::INT16;
|
||||
if (t->getDType() == DataType::Int32)
|
||||
return OnnxDType::INT32;
|
||||
if (t->getDType() == DataType::Int64)
|
||||
return OnnxDType::INT64;
|
||||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
||||
static int concat_axis_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Concat);
|
||||
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
|
||||
}
|
||||
|
||||
static int gather_axis_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Gather);
|
||||
return dynamic_cast<const GatherObj *>(op.get())->getAxis();
|
||||
}
|
||||
|
||||
static vector<int> reduce_mean_axes_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::ReduceMean);
|
||||
auto &set = dynamic_cast<const ReduceMeanObj *>(op.get())->getAxes();
|
||||
return vector(set.begin(), set.end());
|
||||
}
|
||||
|
||||
static Shape reshape_shape_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Reshape);
|
||||
return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
|
||||
}
|
||||
|
||||
void export_functions(py::module &m) {
|
||||
#define FUNCTION(NAME) def(#NAME, &NAME)
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||
.FUNCTION(tensor_dtype)
|
||||
.FUNCTION(reshape_shape_of)
|
||||
.FUNCTION(concat_axis_of)
|
||||
.FUNCTION(gather_axis_of)
|
||||
.FUNCTION(reduce_mean_axes_of);
|
||||
#undef FUNCTION
|
||||
}
|
||||
|
||||
void init_graph_builder(py::module &m) {
|
||||
using Handler = GraphHandlerObj;
|
||||
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
|
||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
|
||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||
m, "CpuRuntimeObj");
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
|
||||
py::enum_<ActType>(m, "ActType")
|
||||
.value("Linear", ActType::None) // `None` is Python keyword
|
||||
.value("Relu", ActType::Relu)
|
||||
.value("Sigmoid", ActType::Sigmoid)
|
||||
.value("Tanh", ActType::Tanh)
|
||||
.export_values();
|
||||
py::class_<Handler>(m, "GraphHandlerObj")
|
||||
m, "CpuRuntime");
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||
.def("shape", &TensorObj::getDims, policy::move)
|
||||
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
||||
.def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
|
||||
policy::reference)
|
||||
.def("outputs",
|
||||
py::overload_cast<>(&OperatorObj::getOutputs, py::const_),
|
||||
policy::reference);
|
||||
py::class_<Handler>(m, "GraphHandler")
|
||||
.def(py::init<Runtime>())
|
||||
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
|
||||
policy::move)
|
||||
.def("tensor", &Handler::tensor, policy::move)
|
||||
.def("conv", &Handler::conv, policy::move)
|
||||
.def("matmul",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||
ActType>(&Handler::matmul),
|
||||
policy::move)
|
||||
.def("batchNorm",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
|
||||
float, float, bool>(&Handler::batchNorm),
|
||||
policy::move)
|
||||
.def("maxPool",
|
||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
||||
int, int>(&Handler::maxPool),
|
||||
policy::move)
|
||||
.def("avgPool",
|
||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
||||
int, int>(&Handler::avgPool),
|
||||
policy::move)
|
||||
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
|
||||
policy::move)
|
||||
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
|
||||
policy::move)
|
||||
.def("mul", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::mul),
|
||||
policy::move)
|
||||
.def("div", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::div),
|
||||
policy::move)
|
||||
.def("pow", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::pow),
|
||||
policy::move)
|
||||
.def("relu", py::overload_cast<Tensor, Tensor>(&Handler::relu),
|
||||
policy::move)
|
||||
.def("sigmoid", py::overload_cast<Tensor, Tensor>(&Handler::sigmoid),
|
||||
policy::move)
|
||||
.def("tanh", py::overload_cast<Tensor, Tensor>(&Handler::tanh),
|
||||
policy::move)
|
||||
.def("softmax", py::overload_cast<Tensor, Tensor>(&Handler::softmax),
|
||||
policy::move)
|
||||
.def("abs", py::overload_cast<Tensor, Tensor>(&Handler::abs),
|
||||
policy::move)
|
||||
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
|
||||
policy::move)
|
||||
.def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten),
|
||||
policy::move)
|
||||
.def("reshape",
|
||||
py::overload_cast<Tensor, Tensor, Shape>(&Handler::reshape),
|
||||
policy::move)
|
||||
.def("concat",
|
||||
py::overload_cast<TensorVec, Tensor, int>(&Handler::concat),
|
||||
policy::move)
|
||||
.def("gather",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, int>(&Handler::gather),
|
||||
policy::move)
|
||||
.def("reduceMean",
|
||||
py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
|
||||
bool>(&Handler::reduceMean),
|
||||
policy::move)
|
||||
.def("slice",
|
||||
py::overload_cast<
|
||||
Tensor, Tensor, const vector<int> &, const vector<int> &,
|
||||
const optional<vector<int>> &, const optional<vector<int>> &>(
|
||||
&Handler::slice),
|
||||
policy::move)
|
||||
.def("pad",
|
||||
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
||||
const optional<vector<int>> &>(&Handler::pad),
|
||||
policy::move)
|
||||
.def("matmul", &Handler::matmul, policy::move)
|
||||
.def("batchNorm", &Handler::batchNorm, policy::move)
|
||||
.def("maxPool", &Handler::maxPool, policy::move)
|
||||
.def("avgPool", &Handler::avgPool, policy::move)
|
||||
.def("add", &Handler::add, policy::move)
|
||||
.def("sub", &Handler::sub, policy::move)
|
||||
.def("mul", &Handler::mul, policy::move)
|
||||
.def("div", &Handler::div, policy::move)
|
||||
.def("pow", &Handler::pow, policy::move)
|
||||
.def("relu", &Handler::relu, policy::move)
|
||||
.def("sigmoid", &Handler::sigmoid, policy::move)
|
||||
.def("tanh", &Handler::tanh, policy::move)
|
||||
.def("softmax", &Handler::softmax, policy::move)
|
||||
.def("abs", &Handler::abs, policy::move)
|
||||
.def("identity", &Handler::identity, policy::move)
|
||||
.def("flatten", &Handler::flatten, policy::move)
|
||||
.def("reshape", &Handler::reshape, policy::move)
|
||||
.def("concat", &Handler::concat, policy::move)
|
||||
.def("gather", &Handler::gather, policy::move)
|
||||
.def("reduceMean", &Handler::reduceMean, policy::move)
|
||||
.def("slice", &Handler::slice, policy::move)
|
||||
.def("pad", &Handler::pad, policy::move)
|
||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||
.def("operators", &Handler::operators, policy::move)
|
||||
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
||||
.def("run", &Handler::run, policy::automatic);
|
||||
}
|
||||
|
@ -111,5 +178,7 @@ void init_graph_builder(py::module &m) {
|
|||
|
||||
PYBIND11_MODULE(backend, m) {
|
||||
infini::register_operator_timer(m);
|
||||
infini::export_values(m);
|
||||
infini::export_functions(m);
|
||||
infini::init_graph_builder(m);
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "core/blob.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/element_wise.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/unary.h"
|
||||
#include "test.h"
|
||||
|
@ -36,6 +37,45 @@ TEST(Graph, build_and_run) {
|
|||
EXPECT_TRUE(o0->equalData(ans));
|
||||
}
|
||||
|
||||
TEST(Graph, topological) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor a = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor b = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor ab = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor c = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor abc = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor d = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor abcd = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor e = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor abcde = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
|
||||
auto ops = std::vector{
|
||||
g->addOpWithOutputs<AddObj>(abcd, e, abcde),
|
||||
g->addOpWithOutputs<AddObj>(abc, d, abcd),
|
||||
g->addOpWithOutputs<AddObj>(ab, c, abc),
|
||||
g->addOpWithOutputs<AddObj>(a, b, ab),
|
||||
};
|
||||
|
||||
{
|
||||
auto p = ops.begin();
|
||||
auto q = g->getOperators().begin();
|
||||
while (p != ops.end()) {
|
||||
EXPECT_EQ(*p++, *q++);
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_TRUE(g->topo_sort());
|
||||
|
||||
{
|
||||
auto p = ops.rbegin();
|
||||
auto q = g->getOperators().begin();
|
||||
while (p != ops.rend()) {
|
||||
EXPECT_EQ(*p++, *q++);
|
||||
}
|
||||
}
|
||||
} // namespace infini
|
||||
|
||||
TEST(Graph, perf_engine) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
|
Loading…
Reference in New Issue