forked from jiuyuan/InfiniTensor
wip: onnx 导出 (#65)
| Notice | Work in progress |
|-|-|

> based on #63

## Progress

1. [x] 对节点拓扑排序
2. [x] 遍历节点,命名并导出其输出张量(`make_tensor_value_info`)
3. [x] 识别图的输入张量,命名并导出(`make_tensor_value_info`)
4. [x] 根据节点类型,识别权重及属性,导出节点(`make_node`)
5. [x] `make_graph` -> `check_graph` -> `make_model` -> `check_model`
This commit is contained in:
commit
dd5d091dbc
|
@ -8,13 +8,10 @@ class GraphObj : public Object {
|
||||||
protected:
|
protected:
|
||||||
Runtime runtime;
|
Runtime runtime;
|
||||||
TensorVec tensors;
|
TensorVec tensors;
|
||||||
// TODO: whether to record input and output tensors
|
|
||||||
// TensorVec inputs;
|
|
||||||
// TensorVec outputs;
|
|
||||||
OpVec ops;
|
OpVec ops;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
GraphObj(Runtime runtime) : runtime(runtime){};
|
explicit GraphObj(Runtime runtime) : runtime(runtime), sorted(false){};
|
||||||
GraphObj(Runtime runtime, OpVec ops_in);
|
GraphObj(Runtime runtime, OpVec ops_in);
|
||||||
string toString() const override;
|
string toString() const override;
|
||||||
Runtime getRuntime() const { return runtime; }
|
Runtime getRuntime() const { return runtime; }
|
||||||
|
@ -23,10 +20,23 @@ class GraphObj : public Object {
|
||||||
Tensor addTensor(const Tensor &tensor);
|
Tensor addTensor(const Tensor &tensor);
|
||||||
TensorVec addTensor(const TensorVec &tensors);
|
TensorVec addTensor(const TensorVec &tensors);
|
||||||
Tensor cloneTensor(const Tensor &tensor) {
|
Tensor cloneTensor(const Tensor &tensor) {
|
||||||
auto ret = addTensor(tensor->clone(runtime));
|
return addTensor(tensor->clone(runtime));
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const TensorVec &getTensors() const { return tensors; }
|
||||||
|
const OpVec &getOperators() const { return ops; }
|
||||||
|
OpVec getComputeOps() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sort the nodes in topological order.
|
||||||
|
* It returns true if the sorting is successful.
|
||||||
|
* Otherwise false is returned, meaning that there are cycles in the graph,
|
||||||
|
* so the topological sorting fails.
|
||||||
|
*/
|
||||||
|
bool topo_sort();
|
||||||
|
|
||||||
|
void dataMalloc();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Add an operator and create its outputs. Output tensor arguments
|
* @brief Add an operator and create its outputs. Output tensor arguments
|
||||||
* should be empty Refs (e.g., nullptr).
|
* should be empty Refs (e.g., nullptr).
|
||||||
|
@ -47,25 +57,27 @@ class GraphObj : public Object {
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
const TensorVec &getTensors() const { return tensors; }
|
/**
|
||||||
const TensorVec getInputs() const {
|
* @brief Gets input tensors of this graph.
|
||||||
|
*/
|
||||||
|
inline TensorVec getInputs() const {
|
||||||
TensorVec ret;
|
TensorVec ret;
|
||||||
for (auto t : tensors)
|
for (const auto &t : tensors)
|
||||||
if (!t->getOutputOf())
|
if (!t->getOutputOf())
|
||||||
ret.emplace_back(t);
|
ret.emplace_back(t);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
const TensorVec getOutputs() const {
|
|
||||||
|
/**
|
||||||
|
* @brief Gets output tensors of this graph.
|
||||||
|
*/
|
||||||
|
inline TensorVec getOutputs() const {
|
||||||
TensorVec ret;
|
TensorVec ret;
|
||||||
for (auto t : tensors)
|
for (const auto &t : tensors)
|
||||||
if (t->getInputOf().empty())
|
if (t->getInputOf().empty())
|
||||||
ret.emplace_back(t);
|
ret.emplace_back(t);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
const OpVec &getOperators() const { return ops; }
|
|
||||||
OpVec getComputeOps() const;
|
|
||||||
|
|
||||||
void dataMalloc();
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
|
@ -73,9 +85,10 @@ class GraphObj : public Object {
|
||||||
*/
|
*/
|
||||||
void addOperatorAndConnect(const Operator &op);
|
void addOperatorAndConnect(const Operator &op);
|
||||||
|
|
||||||
// TODO: move to another class
|
/**
|
||||||
// bool exportOnnx(const char *path);
|
* @brief Whether the nodes are sorted in topological order.
|
||||||
// bool importOnnx(const char *net);
|
*/
|
||||||
|
bool sorted;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -38,12 +38,14 @@ class GraphHandlerObj {
|
||||||
|
|
||||||
Tensor tensor(Shape dims, int dtype);
|
Tensor tensor(Shape dims, int dtype);
|
||||||
|
|
||||||
|
//------ operators
|
||||||
|
|
||||||
|
inline OpVec operators() { return g->getOperators(); }
|
||||||
|
|
||||||
Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
|
Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
|
||||||
int sh, int sw, int dh, int dw);
|
int sh, int sw, int dh, int dw);
|
||||||
|
|
||||||
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
||||||
Tensor bias, ActType act);
|
Tensor bias, ActType act);
|
||||||
|
|
||||||
Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
|
Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
|
||||||
Tensor scale, Tensor bias, float momentum, float eps,
|
Tensor scale, Tensor bias, float momentum, float eps,
|
||||||
bool training);
|
bool training);
|
||||||
|
@ -77,6 +79,10 @@ class GraphHandlerObj {
|
||||||
Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
|
Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
|
||||||
const optional<vector<int>> &axes);
|
const optional<vector<int>> &axes);
|
||||||
|
|
||||||
|
//------ modifiers
|
||||||
|
|
||||||
|
inline bool topo_sort() { return g->topo_sort(); }
|
||||||
|
|
||||||
//------ runtime
|
//------ runtime
|
||||||
|
|
||||||
inline void data_malloc() { g->dataMalloc(); }
|
inline void data_malloc() { g->dataMalloc(); }
|
||||||
|
|
|
@ -30,6 +30,7 @@ class ReduceMeanObj : public OperatorObj {
|
||||||
int numOutputs() const override { return 1; }
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
bool isReduced(int idx) const;
|
bool isReduced(int idx) const;
|
||||||
|
const set<int> &getAxes() const { return axes; }
|
||||||
bool getKeepDims() const { return keepDims; }
|
bool getKeepDims() const { return keepDims; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -28,6 +28,8 @@ class ReshapeObj : public OperatorObj {
|
||||||
int numInputs() const override { return 1; }
|
int numInputs() const override { return 1; }
|
||||||
int numOutputs() const override { return 1; }
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
|
inline Shape getShape() const { return dims; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<int> getWorkloadVector() const override;
|
vector<int> getWorkloadVector() const override;
|
||||||
vector<int> getOpAttrVector() const override;
|
vector<int> getOpAttrVector() const override;
|
||||||
|
|
|
@ -1,17 +1,39 @@
|
||||||
import onnx, backend
|
import backend
|
||||||
|
from onnx import (
|
||||||
|
ModelProto,
|
||||||
|
TensorProto,
|
||||||
|
NodeProto,
|
||||||
|
AttributeProto,
|
||||||
|
TensorShapeProto,
|
||||||
|
ValueInfoProto,
|
||||||
|
)
|
||||||
|
from onnx.helper import (
|
||||||
|
make_node,
|
||||||
|
make_tensor_value_info,
|
||||||
|
make_tensor,
|
||||||
|
make_graph,
|
||||||
|
make_model,
|
||||||
|
)
|
||||||
|
from onnx.checker import (
|
||||||
|
check_graph,
|
||||||
|
check_model,
|
||||||
|
check_node,
|
||||||
|
check_value_info,
|
||||||
|
check_tensor,
|
||||||
|
)
|
||||||
from onnx.shape_inference import infer_shapes
|
from onnx.shape_inference import infer_shapes
|
||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any, Tuple, Sequence
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
runtime = backend.cpu_runtime()
|
runtime = backend.cpu_runtime()
|
||||||
|
|
||||||
|
|
||||||
def from_onnx(model: onnx.ModelProto):
|
def from_onnx(model: ModelProto) -> backend.GraphHandler:
|
||||||
model = infer_shapes(model)
|
model = infer_shapes(model)
|
||||||
handler = backend.GraphHandlerObj(runtime)
|
handler = backend.GraphHandler(runtime)
|
||||||
|
|
||||||
tensors: Dict[str, backend.TensorObj] = dict()
|
tensors: Dict[str, backend.Tensor] = dict()
|
||||||
data: Dict[str, onnx.TensorProto] = dict()
|
data: Dict[str, TensorProto] = dict()
|
||||||
|
|
||||||
for input in model.graph.input:
|
for input in model.graph.input:
|
||||||
dims = _take_shape_dim(input.type.tensor_type.shape)
|
dims = _take_shape_dim(input.type.tensor_type.shape)
|
||||||
|
@ -303,7 +325,164 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
|
|
||||||
|
|
||||||
def parse_onnx(model: onnx.ModelProto):
|
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||||
|
class Context:
|
||||||
|
# saves object names, including tensors and operators
|
||||||
|
names: Dict[Any, str] = dict()
|
||||||
|
# counts the occurrence times of each operator for naming
|
||||||
|
count_op: Dict[backend.OpType, int] = dict()
|
||||||
|
# counts input and output tensors for naming
|
||||||
|
count_in, count_out = 0, 0
|
||||||
|
# saves nodes (operators)
|
||||||
|
nodes: List[NodeProto] = []
|
||||||
|
# saves global input tensors
|
||||||
|
inputs: List[ValueInfoProto] = []
|
||||||
|
# saves global output tensors
|
||||||
|
outputs: List[ValueInfoProto] = []
|
||||||
|
# saves initializer tensors (constant data inputs)
|
||||||
|
initializers: List[TensorProto] = []
|
||||||
|
|
||||||
|
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
|
||||||
|
ty = op.op_type()
|
||||||
|
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
|
||||||
|
self.names[op] = name
|
||||||
|
self.count_op[ty] += 1
|
||||||
|
return ty, name
|
||||||
|
|
||||||
|
def push_output(self, name: str, tensor: backend.Tensor) -> str:
|
||||||
|
self.names[tensor] = name
|
||||||
|
return name
|
||||||
|
|
||||||
|
def push_input(self, tensor: backend.Tensor) -> str:
|
||||||
|
name = self.names.get(tensor)
|
||||||
|
# means that this input is a global input
|
||||||
|
if name is None:
|
||||||
|
self.count_in += 1
|
||||||
|
name = "input{}".format(self.count_in)
|
||||||
|
self.names[tensor] = name
|
||||||
|
shape = tensor.shape()
|
||||||
|
dtype = backend.tensor_dtype(tensor)
|
||||||
|
value_info = make_tensor_value_info(name, dtype, shape)
|
||||||
|
check_value_info(value_info)
|
||||||
|
self.inputs.append(value_info)
|
||||||
|
|
||||||
|
return name
|
||||||
|
|
||||||
|
def push_data_input(
|
||||||
|
self,
|
||||||
|
node_name: str,
|
||||||
|
attr_name: str,
|
||||||
|
elem_type: int,
|
||||||
|
shape: Sequence[int],
|
||||||
|
vals: Any,
|
||||||
|
) -> str:
|
||||||
|
name = "{}_{}".format(node_name, attr_name)
|
||||||
|
value_info = make_tensor_value_info(name, elem_type, shape)
|
||||||
|
tensor = make_tensor(name, elem_type, shape, vals)
|
||||||
|
check_value_info(value_info)
|
||||||
|
check_tensor(tensor)
|
||||||
|
self.inputs.append(value_info)
|
||||||
|
self.initializers.append(tensor)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def push_node(self, node: NodeProto) -> None:
|
||||||
|
check_node(node)
|
||||||
|
self.nodes.append(node)
|
||||||
|
|
||||||
|
def build(self, name: str) -> ModelProto:
|
||||||
|
print()
|
||||||
|
print(ctx.names)
|
||||||
|
print()
|
||||||
|
print(ctx.inputs)
|
||||||
|
print()
|
||||||
|
print(ctx.outputs)
|
||||||
|
print()
|
||||||
|
print(ctx.nodes)
|
||||||
|
|
||||||
|
graph = make_graph(
|
||||||
|
self.nodes, name, self.inputs, self.outputs, self.initializers
|
||||||
|
)
|
||||||
|
check_graph(graph)
|
||||||
|
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
# topological sort
|
||||||
|
if not graph.topo_sort():
|
||||||
|
raise Exception("Sorting fails")
|
||||||
|
|
||||||
|
ops = graph.operators() # 图中所有算子(节点)
|
||||||
|
|
||||||
|
ctx = Context()
|
||||||
|
|
||||||
|
for op in ops:
|
||||||
|
ty, name = ctx.name_op(op)
|
||||||
|
inputs = [ctx.push_input(it) for it in op.inputs()]
|
||||||
|
outputs = [
|
||||||
|
ctx.push_output("{}_{}".format(name, i), it)
|
||||||
|
for (i, it) in enumerate(op.outputs())
|
||||||
|
]
|
||||||
|
if ty == backend.OpType.Matmul:
|
||||||
|
ctx.push_node(make_node("MatMul", inputs, outputs, name))
|
||||||
|
elif ty == backend.OpType.BatchNorm:
|
||||||
|
raise Exception("TODO")
|
||||||
|
elif ty == backend.OpType.MaxPool:
|
||||||
|
raise Exception("TODO")
|
||||||
|
elif ty == backend.OpType.AvgPool:
|
||||||
|
raise Exception("TODO")
|
||||||
|
elif ty in [
|
||||||
|
backend.OpType.Add,
|
||||||
|
backend.OpType.Sub,
|
||||||
|
backend.OpType.Mul,
|
||||||
|
backend.OpType.Div,
|
||||||
|
backend.OpType.Pow,
|
||||||
|
backend.OpType.Relu,
|
||||||
|
backend.OpType.Sigmoid,
|
||||||
|
backend.OpType.Tanh,
|
||||||
|
backend.OpType.Softmax,
|
||||||
|
backend.OpType.Abs,
|
||||||
|
backend.OpType.Identity,
|
||||||
|
]:
|
||||||
|
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||||
|
elif ty == backend.OpType.Flatten:
|
||||||
|
raise Exception("TODO")
|
||||||
|
elif ty == backend.OpType.Reshape:
|
||||||
|
shape = backend.reshape_shape_of(op)
|
||||||
|
inputs.append(
|
||||||
|
ctx.push_data_input(
|
||||||
|
name,
|
||||||
|
"shape",
|
||||||
|
TensorProto.INT32,
|
||||||
|
[len(shape)],
|
||||||
|
shape,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||||
|
elif ty == backend.OpType.Concat:
|
||||||
|
axis = backend.concat_axis_of(op)
|
||||||
|
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||||
|
elif ty == backend.OpType.Gather:
|
||||||
|
axis = backend.gather_axis_of(op)
|
||||||
|
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||||
|
elif ty == backend.OpType.ReduceMean:
|
||||||
|
axes = backend.reduce_mean_axes_of(op)
|
||||||
|
inputs.append(
|
||||||
|
ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
|
||||||
|
)
|
||||||
|
ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
|
||||||
|
elif ty == backend.OpType.Slice:
|
||||||
|
raise Exception("TODO")
|
||||||
|
elif ty == backend.OpType.Pad:
|
||||||
|
raise Exception("TODO")
|
||||||
|
else:
|
||||||
|
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||||
|
|
||||||
|
return ctx.build(name)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_onnx(model: ModelProto):
|
||||||
print()
|
print()
|
||||||
|
|
||||||
for field in [
|
for field in [
|
||||||
|
@ -339,34 +518,32 @@ def parse_onnx(model: onnx.ModelProto):
|
||||||
print(" {}".format(node.name))
|
print(" {}".format(node.name))
|
||||||
|
|
||||||
|
|
||||||
def _parse_attribute(
|
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
|
||||||
node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
for attr in node.attribute:
|
for attr in node.attribute:
|
||||||
if attr.name in attrs:
|
if attr.name in attrs:
|
||||||
if attr.type == onnx.AttributeProto.INT:
|
if attr.type == AttributeProto.INT:
|
||||||
attrs[attr.name] = attr.i
|
attrs[attr.name] = attr.i
|
||||||
elif attr.type == onnx.AttributeProto.INTS:
|
elif attr.type == AttributeProto.INTS:
|
||||||
attrs[attr.name] = attr.ints
|
attrs[attr.name] = attr.ints
|
||||||
elif attr.type == onnx.AttributeProto.FLOAT:
|
elif attr.type == AttributeProto.FLOAT:
|
||||||
attrs[attr.name] = attr.f
|
attrs[attr.name] = attr.f
|
||||||
elif attr.type == onnx.AttributeProto.STRING:
|
elif attr.type == AttributeProto.STRING:
|
||||||
attrs[attr.name] = attr.s
|
attrs[attr.name] = attr.s
|
||||||
elif attr.type == onnx.AttributeProto.TENSOR:
|
elif attr.type == AttributeProto.TENSOR:
|
||||||
attrs[attr.name] = attr.t
|
attrs[attr.name] = attr.t
|
||||||
else:
|
else:
|
||||||
assert False, "Unsupported Attribute Type: {}".format(attr.type)
|
assert False, "Unsupported Attribute Type: {}".format(attr.type)
|
||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
|
|
||||||
def _parse_data(tensor: onnx.TensorProto) -> List[int]:
|
def _parse_data(tensor: TensorProto) -> List[int]:
|
||||||
if tensor.data_type == onnx.TensorProto.INT32:
|
if tensor.data_type == TensorProto.INT32:
|
||||||
return [int(i) for i in tensor.int32_data]
|
return [int(i) for i in tensor.int32_data]
|
||||||
elif tensor.data_type == onnx.TensorProto.INT64:
|
elif tensor.data_type == TensorProto.INT64:
|
||||||
return [int(i) for i in tensor.int64_data]
|
return [int(i) for i in tensor.int64_data]
|
||||||
else:
|
else:
|
||||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||||
|
|
||||||
|
|
||||||
def _take_shape_dim(shape: onnx.TensorShapeProto) -> List[int]:
|
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
|
||||||
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
|
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
|
||||||
|
|
|
@ -8,7 +8,7 @@ from onnx.helper import (
|
||||||
make_tensor_value_info,
|
make_tensor_value_info,
|
||||||
)
|
)
|
||||||
from onnx.checker import check_model
|
from onnx.checker import check_model
|
||||||
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
|
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx
|
||||||
|
|
||||||
|
|
||||||
def make_and_import_model(graph: onnx.GraphProto):
|
def make_and_import_model(graph: onnx.GraphProto):
|
||||||
|
@ -293,11 +293,20 @@ class TestStringMethods(unittest.TestCase):
|
||||||
parse_onnx(model)
|
parse_onnx(model)
|
||||||
|
|
||||||
def test_frontend(self):
|
def test_frontend(self):
|
||||||
handler = backend.GraphHandlerObj(runtime)
|
handler = backend.GraphHandler(runtime)
|
||||||
i = handler.tensor([1, 2, 3], 12)
|
a = handler.tensor([1, 2, 3], 12)
|
||||||
w = handler.tensor([1, 3, 4], 12)
|
b = handler.tensor([1, 2, 3], 12)
|
||||||
o = handler.tensor([1, 2, 4], 12)
|
c = handler.tensor([1, 2, 3], 12)
|
||||||
handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
|
d = handler.tensor([1, 2, 3], 12)
|
||||||
|
e = handler.tensor([1, 2, 3], 12)
|
||||||
|
|
||||||
|
x = handler.add(
|
||||||
|
handler.add(handler.add(handler.add(a, b, None), c, None), d, None), e, None
|
||||||
|
)
|
||||||
|
y = handler.tensor([3, 2, 1], 12)
|
||||||
|
handler.reshape(x, y, [3, 2, 1])
|
||||||
|
|
||||||
|
to_onnx(handler, "test_frontend")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
#include "core/graph.h"
|
#include "core/graph.h"
|
||||||
|
#include <algorithm>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
|
GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
|
||||||
|
: runtime(runtime), sorted(false) {
|
||||||
map<UidBaseType, Tensor> tensorPool;
|
map<UidBaseType, Tensor> tensorPool;
|
||||||
// Clone tensors
|
// Clone tensors
|
||||||
for (const auto &op : ops_in) {
|
for (const auto &op : ops_in) {
|
||||||
|
@ -28,6 +30,7 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void GraphObj::addOperatorAndConnect(const Operator &op) {
|
void GraphObj::addOperatorAndConnect(const Operator &op) {
|
||||||
|
sorted = false;
|
||||||
ops.push_back(op);
|
ops.push_back(op);
|
||||||
for (auto &input : op->getInputs()) {
|
for (auto &input : op->getInputs()) {
|
||||||
input->addInputOf(op);
|
input->addInputOf(op);
|
||||||
|
@ -66,6 +69,53 @@ string GraphObj::toString() const {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GraphObj::topo_sort() {
|
||||||
|
if (this->sorted)
|
||||||
|
return true;
|
||||||
|
|
||||||
|
// std::unordered_set<Tensor> inputs;
|
||||||
|
std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
|
||||||
|
std::vector<Operator> sorted;
|
||||||
|
|
||||||
|
while (!waiting.empty()) {
|
||||||
|
// Whether any node is moved to sorted in this pass.
|
||||||
|
auto modified = false;
|
||||||
|
// Find head nodes.
|
||||||
|
for (auto it = waiting.begin(); it != waiting.end();) {
|
||||||
|
const auto &this_inputs = (*it)->getInputs();
|
||||||
|
// If none of the input tensors is produced by a node in the waiting list,
|
||||||
|
// this node is a head node.
|
||||||
|
const auto is_head = std::all_of(
|
||||||
|
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
|
||||||
|
auto src = input->getOutputOf();
|
||||||
|
return src // If the source node is in the waiting list,
|
||||||
|
// means that this node is not the head node.
|
||||||
|
? waiting.find(src) == waiting.end()
|
||||||
|
// This tensor has no source node,
|
||||||
|
// it must be an input tensor.
|
||||||
|
: (/*inputs.insert(input),*/ true);
|
||||||
|
});
|
||||||
|
// Moves head node to sorted.
|
||||||
|
if (is_head) {
|
||||||
|
modified = true;
|
||||||
|
sorted.emplace_back(std::move(*it));
|
||||||
|
it = waiting.erase(it);
|
||||||
|
} else {
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If the waiting list is not modified during a pass,
|
||||||
|
// sorting fails.
|
||||||
|
if (!modified) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Done.
|
||||||
|
this->ops = std::move(sorted);
|
||||||
|
return this->sorted = true;
|
||||||
|
}
|
||||||
|
|
||||||
void GraphObj::dataMalloc() {
|
void GraphObj::dataMalloc() {
|
||||||
for (auto &tensor : tensors) {
|
for (auto &tensor : tensors) {
|
||||||
tensor->dataMalloc();
|
tensor->dataMalloc();
|
||||||
|
@ -73,15 +123,12 @@ void GraphObj::dataMalloc() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
|
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
|
||||||
Tensor tensor = make_ref<TensorObj>(dim, dtype, runtime);
|
return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime));
|
||||||
tensors.emplace_back(tensor);
|
|
||||||
return tensor;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor GraphObj::addTensor(const Tensor &tensor) {
|
Tensor GraphObj::addTensor(const Tensor &tensor) {
|
||||||
IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
|
IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
|
||||||
tensors.emplace_back(tensor);
|
return tensors.emplace_back(tensor);
|
||||||
return tensor;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorVec GraphObj::addTensor(const TensorVec &tensors) {
|
TensorVec GraphObj::addTensor(const TensorVec &tensors) {
|
||||||
|
@ -98,4 +145,4 @@ OpVec GraphObj::getComputeOps() const {
|
||||||
return opList;
|
return opList;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -1,4 +1,8 @@
|
||||||
#include "core/graph_handler.h"
|
#include "core/graph_handler.h"
|
||||||
|
#include "operators/concat.h"
|
||||||
|
#include "operators/gather.h"
|
||||||
|
#include "operators/reduce_mean.h"
|
||||||
|
#include "operators/reshape.h"
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
#ifdef USE_CUDA
|
#ifdef USE_CUDA
|
||||||
|
@ -21,88 +25,151 @@ void register_operator_timer(py::module &m) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void export_values(py::module &m) {
|
||||||
|
#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
|
||||||
|
|
||||||
|
py::enum_<ActType>(m, "ActType")
|
||||||
|
.value("Linear", ActType::None) // `None` is Python keyword
|
||||||
|
.VALUE(ActType, Relu)
|
||||||
|
.VALUE(ActType, Sigmoid)
|
||||||
|
.VALUE(ActType, Tanh)
|
||||||
|
.export_values();
|
||||||
|
|
||||||
|
py::enum_<OpType>(m, "OpType")
|
||||||
|
.VALUE(OpType, Unknown)
|
||||||
|
.VALUE(OpType, Conv)
|
||||||
|
.VALUE(OpType, Matmul)
|
||||||
|
.VALUE(OpType, ConvTrans)
|
||||||
|
.VALUE(OpType, G2BMM)
|
||||||
|
.VALUE(OpType, GBMM)
|
||||||
|
.VALUE(OpType, Pad)
|
||||||
|
.VALUE(OpType, Slice)
|
||||||
|
.VALUE(OpType, Concat)
|
||||||
|
.VALUE(OpType, Split)
|
||||||
|
.VALUE(OpType, Transpose)
|
||||||
|
.VALUE(OpType, Extend)
|
||||||
|
.VALUE(OpType, MaxPool)
|
||||||
|
.VALUE(OpType, AvgPool)
|
||||||
|
.VALUE(OpType, Add)
|
||||||
|
.VALUE(OpType, Sub)
|
||||||
|
.VALUE(OpType, Mul)
|
||||||
|
.VALUE(OpType, Div)
|
||||||
|
.VALUE(OpType, Pow)
|
||||||
|
.VALUE(OpType, Gather)
|
||||||
|
.VALUE(OpType, ReduceMean)
|
||||||
|
.VALUE(OpType, Reshape)
|
||||||
|
.VALUE(OpType, Flatten)
|
||||||
|
.VALUE(OpType, Identity)
|
||||||
|
.VALUE(OpType, BatchNorm)
|
||||||
|
.VALUE(OpType, Softmax)
|
||||||
|
.VALUE(OpType, Activation)
|
||||||
|
.VALUE(OpType, Relu)
|
||||||
|
.VALUE(OpType, Sigmoid)
|
||||||
|
.VALUE(OpType, Tanh)
|
||||||
|
.VALUE(OpType, Abs)
|
||||||
|
.VALUE(OpType, Resize)
|
||||||
|
.VALUE(OpType, MemBound)
|
||||||
|
.export_values();
|
||||||
|
|
||||||
|
#undef VALUE
|
||||||
|
}
|
||||||
|
|
||||||
|
static int tensor_dtype(Tensor t) {
|
||||||
|
if (t->getDType() == DataType::Float32)
|
||||||
|
return OnnxDType::FLOAT;
|
||||||
|
if (t->getDType() == DataType::UInt32)
|
||||||
|
return OnnxDType::UINT32;
|
||||||
|
if (t->getDType() == DataType::UInt8)
|
||||||
|
return OnnxDType::UINT8;
|
||||||
|
if (t->getDType() == DataType::Int8)
|
||||||
|
return OnnxDType::INT8;
|
||||||
|
if (t->getDType() == DataType::UInt16)
|
||||||
|
return OnnxDType::UINT16;
|
||||||
|
if (t->getDType() == DataType::Int16)
|
||||||
|
return OnnxDType::INT16;
|
||||||
|
if (t->getDType() == DataType::Int32)
|
||||||
|
return OnnxDType::INT32;
|
||||||
|
if (t->getDType() == DataType::Int64)
|
||||||
|
return OnnxDType::INT64;
|
||||||
|
IT_ASSERT(false, "Unsupported data type");
|
||||||
|
}
|
||||||
|
|
||||||
|
static int concat_axis_of(Operator op) {
|
||||||
|
IT_ASSERT(op->getOpType() == OpType::Concat);
|
||||||
|
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
|
||||||
|
}
|
||||||
|
|
||||||
|
static int gather_axis_of(Operator op) {
|
||||||
|
IT_ASSERT(op->getOpType() == OpType::Gather);
|
||||||
|
return dynamic_cast<const GatherObj *>(op.get())->getAxis();
|
||||||
|
}
|
||||||
|
|
||||||
|
static vector<int> reduce_mean_axes_of(Operator op) {
|
||||||
|
IT_ASSERT(op->getOpType() == OpType::ReduceMean);
|
||||||
|
auto &set = dynamic_cast<const ReduceMeanObj *>(op.get())->getAxes();
|
||||||
|
return vector(set.begin(), set.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
static Shape reshape_shape_of(Operator op) {
|
||||||
|
IT_ASSERT(op->getOpType() == OpType::Reshape);
|
||||||
|
return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_functions(py::module &m) {
|
||||||
|
#define FUNCTION(NAME) def(#NAME, &NAME)
|
||||||
|
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||||
|
.FUNCTION(tensor_dtype)
|
||||||
|
.FUNCTION(reshape_shape_of)
|
||||||
|
.FUNCTION(concat_axis_of)
|
||||||
|
.FUNCTION(gather_axis_of)
|
||||||
|
.FUNCTION(reduce_mean_axes_of);
|
||||||
|
#undef FUNCTION
|
||||||
|
}
|
||||||
|
|
||||||
void init_graph_builder(py::module &m) {
|
void init_graph_builder(py::module &m) {
|
||||||
using Handler = GraphHandlerObj;
|
using Handler = GraphHandlerObj;
|
||||||
|
|
||||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
|
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
|
|
||||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||||
m, "CpuRuntimeObj");
|
m, "CpuRuntime");
|
||||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
|
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||||
py::enum_<ActType>(m, "ActType")
|
.def("shape", &TensorObj::getDims, policy::move)
|
||||||
.value("Linear", ActType::None) // `None` is Python keyword
|
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||||
.value("Relu", ActType::Relu)
|
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||||
.value("Sigmoid", ActType::Sigmoid)
|
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
||||||
.value("Tanh", ActType::Tanh)
|
.def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
|
||||||
.export_values();
|
policy::reference)
|
||||||
py::class_<Handler>(m, "GraphHandlerObj")
|
.def("outputs",
|
||||||
|
py::overload_cast<>(&OperatorObj::getOutputs, py::const_),
|
||||||
|
policy::reference);
|
||||||
|
py::class_<Handler>(m, "GraphHandler")
|
||||||
.def(py::init<Runtime>())
|
.def(py::init<Runtime>())
|
||||||
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
|
.def("tensor", &Handler::tensor, policy::move)
|
||||||
policy::move)
|
|
||||||
.def("conv", &Handler::conv, policy::move)
|
.def("conv", &Handler::conv, policy::move)
|
||||||
.def("matmul",
|
.def("matmul", &Handler::matmul, policy::move)
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
.def("batchNorm", &Handler::batchNorm, policy::move)
|
||||||
ActType>(&Handler::matmul),
|
.def("maxPool", &Handler::maxPool, policy::move)
|
||||||
policy::move)
|
.def("avgPool", &Handler::avgPool, policy::move)
|
||||||
.def("batchNorm",
|
.def("add", &Handler::add, policy::move)
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
|
.def("sub", &Handler::sub, policy::move)
|
||||||
float, float, bool>(&Handler::batchNorm),
|
.def("mul", &Handler::mul, policy::move)
|
||||||
policy::move)
|
.def("div", &Handler::div, policy::move)
|
||||||
.def("maxPool",
|
.def("pow", &Handler::pow, policy::move)
|
||||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
.def("relu", &Handler::relu, policy::move)
|
||||||
int, int>(&Handler::maxPool),
|
.def("sigmoid", &Handler::sigmoid, policy::move)
|
||||||
policy::move)
|
.def("tanh", &Handler::tanh, policy::move)
|
||||||
.def("avgPool",
|
.def("softmax", &Handler::softmax, policy::move)
|
||||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
.def("abs", &Handler::abs, policy::move)
|
||||||
int, int>(&Handler::avgPool),
|
.def("identity", &Handler::identity, policy::move)
|
||||||
policy::move)
|
.def("flatten", &Handler::flatten, policy::move)
|
||||||
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
|
.def("reshape", &Handler::reshape, policy::move)
|
||||||
policy::move)
|
.def("concat", &Handler::concat, policy::move)
|
||||||
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
|
.def("gather", &Handler::gather, policy::move)
|
||||||
policy::move)
|
.def("reduceMean", &Handler::reduceMean, policy::move)
|
||||||
.def("mul", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::mul),
|
.def("slice", &Handler::slice, policy::move)
|
||||||
policy::move)
|
.def("pad", &Handler::pad, policy::move)
|
||||||
.def("div", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::div),
|
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||||
policy::move)
|
.def("operators", &Handler::operators, policy::move)
|
||||||
.def("pow", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::pow),
|
|
||||||
policy::move)
|
|
||||||
.def("relu", py::overload_cast<Tensor, Tensor>(&Handler::relu),
|
|
||||||
policy::move)
|
|
||||||
.def("sigmoid", py::overload_cast<Tensor, Tensor>(&Handler::sigmoid),
|
|
||||||
policy::move)
|
|
||||||
.def("tanh", py::overload_cast<Tensor, Tensor>(&Handler::tanh),
|
|
||||||
policy::move)
|
|
||||||
.def("softmax", py::overload_cast<Tensor, Tensor>(&Handler::softmax),
|
|
||||||
policy::move)
|
|
||||||
.def("abs", py::overload_cast<Tensor, Tensor>(&Handler::abs),
|
|
||||||
policy::move)
|
|
||||||
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
|
|
||||||
policy::move)
|
|
||||||
.def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten),
|
|
||||||
policy::move)
|
|
||||||
.def("reshape",
|
|
||||||
py::overload_cast<Tensor, Tensor, Shape>(&Handler::reshape),
|
|
||||||
policy::move)
|
|
||||||
.def("concat",
|
|
||||||
py::overload_cast<TensorVec, Tensor, int>(&Handler::concat),
|
|
||||||
policy::move)
|
|
||||||
.def("gather",
|
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, int>(&Handler::gather),
|
|
||||||
policy::move)
|
|
||||||
.def("reduceMean",
|
|
||||||
py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
|
|
||||||
bool>(&Handler::reduceMean),
|
|
||||||
policy::move)
|
|
||||||
.def("slice",
|
|
||||||
py::overload_cast<
|
|
||||||
Tensor, Tensor, const vector<int> &, const vector<int> &,
|
|
||||||
const optional<vector<int>> &, const optional<vector<int>> &>(
|
|
||||||
&Handler::slice),
|
|
||||||
policy::move)
|
|
||||||
.def("pad",
|
|
||||||
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
|
||||||
const optional<vector<int>> &>(&Handler::pad),
|
|
||||||
policy::move)
|
|
||||||
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
||||||
.def("run", &Handler::run, policy::automatic);
|
.def("run", &Handler::run, policy::automatic);
|
||||||
}
|
}
|
||||||
|
@ -111,5 +178,7 @@ void init_graph_builder(py::module &m) {
|
||||||
|
|
||||||
// Defines the pybind11 extension module named `backend`. The module object
// `m` is passed to each infini:: helper called in the body below.
PYBIND11_MODULE(backend, m) {
|
PYBIND11_MODULE(backend, m) {
|
||||||
infini::register_operator_timer(m);
|
infini::register_operator_timer(m);
|
||||||
|
infini::export_values(m);
|
||||||
|
infini::export_functions(m);
|
||||||
infini::init_graph_builder(m);
|
infini::init_graph_builder(m);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#include "core/blob.h"
|
#include "core/blob.h"
|
||||||
#include "core/graph.h"
|
#include "core/graph.h"
|
||||||
#include "core/runtime.h"
|
#include "core/runtime.h"
|
||||||
|
#include "operators/element_wise.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
#include "operators/unary.h"
|
#include "operators/unary.h"
|
||||||
#include "test.h"
|
#include "test.h"
|
||||||
|
@ -36,6 +37,45 @@ TEST(Graph, build_and_run) {
|
||||||
EXPECT_TRUE(o0->equalData(ans));
|
EXPECT_TRUE(o0->equalData(ans));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies GraphObj::topo_sort(): operators are added so that each one
// consumes a tensor written by an operator added after it, then the test
// checks that sorting succeeds and that getOperators() afterwards yields the
// operators in exactly the reverse of their insertion order.
TEST(Graph, topological) {
|
||||||
|
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
Tensor a = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor b = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor ab = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor c = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor abc = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor d = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor abcd = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor e = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor abcde = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
|
||||||
|
// Ops are deliberately inserted last-stage-first: the first op reads `abcd`,
// which is only written by the second op, and so on down to (a, b) -> ab.
// NOTE(review): assumes the third argument of addOpWithOutputs is the output
// tensor, as the tensor names suggest -- confirm against its declaration.
auto ops = std::vector{
|
||||||
|
g->addOpWithOutputs<AddObj>(abcd, e, abcde),
|
||||||
|
g->addOpWithOutputs<AddObj>(abc, d, abcd),
|
||||||
|
g->addOpWithOutputs<AddObj>(ab, c, abc),
|
||||||
|
g->addOpWithOutputs<AddObj>(a, b, ab),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Before sorting: the graph's operator list must match insertion order.
{
|
||||||
|
auto p = ops.begin();
|
||||||
|
auto q = g->getOperators().begin();
|
||||||
|
while (p != ops.end()) {
|
||||||
|
EXPECT_EQ(*p++, *q++);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sorting is expected to report success for this graph.
EXPECT_TRUE(g->topo_sort());
|
||||||
|
|
||||||
|
// After sorting: walk `ops` backwards (rbegin/rend) and expect the graph's
// operator list to match, i.e. the insertion order has been fully reversed.
{
|
||||||
|
auto p = ops.rbegin();
|
||||||
|
auto q = g->getOperators().begin();
|
||||||
|
while (p != ops.rend()) {
|
||||||
|
EXPECT_EQ(*p++, *q++);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // end of TEST(Graph, topological): this brace closes the test body
|
||||||
|
|
||||||
TEST(Graph, perf_engine) {
|
TEST(Graph, perf_engine) {
|
||||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||||
Graph g = make_ref<GraphObj>(runtime);
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
|
Loading…
Reference in New Issue