wip: onnx export (#65)

| Notice | Work in progress |
|--------|------------------|

> Based on #63

## Progress

1. [x] Topologically sort the nodes
2. [x] Traverse the nodes, naming and exporting each node's output tensors (`make_tensor_value_info`)
3. [x] Identify the graph's input tensors, then name and export them (`make_tensor_value_info`)
4. [x] Identify weights and attributes by node type and export each node (`make_node`)
5. [x] `make_graph` -> `check_graph` -> `make_model` -> `check_model` (a sketch of this pipeline follows below)
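The steps above map one-to-one onto the `onnx.helper` / `onnx.checker` API. A minimal sketch of the target pipeline, with made-up tensor names, shapes, and node names (not this repo's code):

```python
import onnx
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
from onnx.checker import check_graph, check_model

# Hypothetical graph y = Relu(MatMul(a, b)), with nodes already in
# topological order (step 1) and named by occurrence count.
a = make_tensor_value_info("a", onnx.TensorProto.FLOAT, [2, 3])  # graph input (step 3)
b = make_tensor_value_info("b", onnx.TensorProto.FLOAT, [3, 4])  # graph input (step 3)
y = make_tensor_value_info("y", onnx.TensorProto.FLOAT, [2, 4])  # graph output (step 2)

nodes = [
    make_node("MatMul", ["a", "b"], ["x"], "MatMul1"),  # step 4; "x" is an internal edge
    make_node("Relu", ["x"], ["y"], "Relu1"),
]

graph = make_graph(nodes, "demo", [a, b], [y])  # step 5
check_graph(graph)
model = make_model(graph)
check_model(model)
```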
Commit dd5d091dbc by Haojie Wang, 2023-03-15 15:22:09 +08:00, committed via GitHub
9 changed files with 492 additions and 128 deletions

View File

@@ -8,13 +8,10 @@ class GraphObj : public Object {
   protected:
     Runtime runtime;
     TensorVec tensors;
-    // TODO: whether to record input and output tensors
-    // TensorVec inputs;
-    // TensorVec outputs;
     OpVec ops;

   public:
-    GraphObj(Runtime runtime) : runtime(runtime){};
+    explicit GraphObj(Runtime runtime) : runtime(runtime), sorted(false){};
     GraphObj(Runtime runtime, OpVec ops_in);
     string toString() const override;
     Runtime getRuntime() const { return runtime; }
@@ -23,10 +20,23 @@ class GraphObj : public Object {
     Tensor addTensor(const Tensor &tensor);
     TensorVec addTensor(const TensorVec &tensors);
     Tensor cloneTensor(const Tensor &tensor) {
-        auto ret = addTensor(tensor->clone(runtime));
-        return ret;
+        return addTensor(tensor->clone(runtime));
     }
+    const TensorVec &getTensors() const { return tensors; }
+    const OpVec &getOperators() const { return ops; }
+    OpVec getComputeOps() const;
+
+    /**
+     * Sort the nodes in topological order.
+     * It returns true if the sorting is successful.
+     * Otherwise false is returned, meaning that there are rings in the
+     * graph, so the topological sorting fails.
+     */
+    bool topo_sort();
+
+    void dataMalloc();

     /**
      * @brief Add an operator and create its outputs. Output tensor arguments
      * should be empty Refs (e.g., nullptr).
@@ -47,25 +57,27 @@ class GraphObj : public Object {
         return op;
     }

-    const TensorVec &getTensors() const { return tensors; }
-    const TensorVec getInputs() const {
+    /**
+     * @brief Gets input tensors of this graph.
+     */
+    inline TensorVec getInputs() const {
         TensorVec ret;
-        for (auto t : tensors)
+        for (const auto &t : tensors)
             if (!t->getOutputOf())
                 ret.emplace_back(t);
         return ret;
     }
-    const TensorVec getOutputs() const {
+
+    /**
+     * @brief Gets output tensors of this graph.
+     */
+    inline TensorVec getOutputs() const {
         TensorVec ret;
-        for (auto t : tensors)
+        for (const auto &t : tensors)
             if (t->getInputOf().empty())
                 ret.emplace_back(t);
         return ret;
     }
-    const OpVec &getOperators() const { return ops; }
-    OpVec getComputeOps() const;
-    void dataMalloc();

   private:
     /**
@@ -73,9 +85,10 @@ class GraphObj : public Object {
      */
     void addOperatorAndConnect(const Operator &op);

-    // TODO: move to another class
-    // bool exportOnnx(const char *path);
-    // bool importOnnx(const char *net);
+    /**
+     * @brief Whether the nodes are sorted in topological order.
+     */
+    bool sorted;
 };

 } // namespace infini

View File

@@ -38,12 +38,14 @@ class GraphHandlerObj {
     Tensor tensor(Shape dims, int dtype);

+    //------ operators
+
+    inline OpVec operators() { return g->getOperators(); }
+
     Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
                 int sh, int sw, int dh, int dw);
     Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
                   Tensor bias, ActType act);
     Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
                      Tensor scale, Tensor bias, float momentum, float eps,
                      bool training);
@@ -77,6 +79,10 @@ class GraphHandlerObj {
     Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
                const optional<vector<int>> &axes);

+    //------ modifiers
+
+    inline bool topo_sort() { return g->topo_sort(); }
+
     //------ runtime
     inline void data_malloc() { g->dataMalloc(); }

View File

@@ -30,6 +30,7 @@ class ReduceMeanObj : public OperatorObj {
     int numOutputs() const override { return 1; }
     bool isReduced(int idx) const;
+    const set<int> &getAxes() const { return axes; }
     bool getKeepDims() const { return keepDims; }

   private:

View File

@@ -28,6 +28,8 @@ class ReshapeObj : public OperatorObj {
     int numInputs() const override { return 1; }
     int numOutputs() const override { return 1; }

+    inline Shape getShape() const { return dims; }
+
   private:
     vector<int> getWorkloadVector() const override;
     vector<int> getOpAttrVector() const override;

View File

@@ -1,17 +1,39 @@
-import onnx, backend
+import backend
+from onnx import (
+    ModelProto,
+    TensorProto,
+    NodeProto,
+    AttributeProto,
+    TensorShapeProto,
+    ValueInfoProto,
+)
+from onnx.helper import (
+    make_node,
+    make_tensor_value_info,
+    make_tensor,
+    make_graph,
+    make_model,
+)
+from onnx.checker import (
+    check_graph,
+    check_model,
+    check_node,
+    check_value_info,
+    check_tensor,
+)
 from onnx.shape_inference import infer_shapes
-from typing import Dict, List, Any
+from typing import Dict, List, Any, Tuple, Sequence
 from functools import reduce

 runtime = backend.cpu_runtime()


-def from_onnx(model: onnx.ModelProto):
+def from_onnx(model: ModelProto) -> backend.GraphHandler:
     model = infer_shapes(model)
-    handler = backend.GraphHandlerObj(runtime)
-    tensors: Dict[str, backend.TensorObj] = dict()
-    data: Dict[str, onnx.TensorProto] = dict()
+    handler = backend.GraphHandler(runtime)
+    tensors: Dict[str, backend.Tensor] = dict()
+    data: Dict[str, TensorProto] = dict()

     for input in model.graph.input:
         dims = _take_shape_dim(input.type.tensor_type.shape)
@@ -303,7 +325,164 @@ def from_onnx(model: onnx.ModelProto):
         raise Exception('Unsupported operator "{}"'.format(node.op_type))


+def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
+    class Context:
+        # saves object names, including tensors and operators
+        names: Dict[Any, str] = dict()
+        # counts the occurrence times of each operator for naming
+        count_op: Dict[backend.OpType, int] = dict()
+        # counts input and output tensors for naming
+        count_in, count_out = 0, 0
+        # saves nodes (operators)
+        nodes: List[NodeProto] = []
+        # saves global input tensors
+        inputs: List[ValueInfoProto] = []
+        # saves global output tensors
+        outputs: List[ValueInfoProto] = []
+        # saves tensor data (initializers)
+        initializers: List[TensorProto] = []
+
+        def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
+            ty = op.op_type()
+            name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
+            self.names[op] = name
+            self.count_op[ty] += 1
+            return ty, name
+
+        def push_output(self, name: str, tensor: backend.Tensor) -> str:
+            self.names[tensor] = name
+            return name
+
+        def push_input(self, tensor: backend.Tensor) -> str:
+            name = self.names.get(tensor)
+            # a tensor that has not been named yet is a global input
+            if name is None:
+                self.count_in += 1
+                name = "input{}".format(self.count_in)
+                self.names[tensor] = name
+                shape = tensor.shape()
+                dtype = backend.tensor_dtype(tensor)
+                value_info = make_tensor_value_info(name, dtype, shape)
+                check_value_info(value_info)
+                self.inputs.append(value_info)
+            return name
+
+        def push_data_input(
+            self,
+            node_name: str,
+            attr_name: str,
+            elem_type: int,
+            shape: Sequence[int],
+            vals: Any,
+        ) -> str:
+            name = "{}_{}".format(node_name, attr_name)
+            value_info = make_tensor_value_info(name, elem_type, shape)
+            tensor = make_tensor(name, elem_type, shape, vals)
+            check_value_info(value_info)
+            check_tensor(tensor)
+            self.inputs.append(value_info)
+            self.initializers.append(tensor)
+            return name
+
+        def push_node(self, node: NodeProto) -> None:
+            check_node(node)
+            self.nodes.append(node)
+
+        def build(self, name: str) -> ModelProto:
+            print()
+            print(ctx.names)
+            print()
+            print(ctx.inputs)
+            print()
+            print(ctx.outputs)
+            print()
+            print(ctx.nodes)
+
+            graph = make_graph(
+                self.nodes, name, self.inputs, self.outputs, self.initializers
+            )
+            check_graph(graph)
+
+            model = make_model(graph)
+            check_model(model)
+
+            return model
+
+    # topological sort
+    if not graph.topo_sort():
+        raise Exception("Sorting fails")
+
+    ops = graph.operators()  # all operators (nodes) in the graph
+    ctx = Context()
+
+    for op in ops:
+        ty, name = ctx.name_op(op)
+        inputs = [ctx.push_input(it) for it in op.inputs()]
+        outputs = [
+            ctx.push_output("{}_{}".format(name, i), it)
+            for (i, it) in enumerate(op.outputs())
+        ]
+        if ty == backend.OpType.Matmul:
+            ctx.push_node(make_node("MatMul", inputs, outputs, name))
+        elif ty == backend.OpType.BatchNorm:
+            raise Exception("TODO")
+        elif ty == backend.OpType.MaxPool:
+            raise Exception("TODO")
+        elif ty == backend.OpType.AvgPool:
+            raise Exception("TODO")
+        elif ty in [
+            backend.OpType.Add,
+            backend.OpType.Sub,
+            backend.OpType.Mul,
+            backend.OpType.Div,
+            backend.OpType.Pow,
+            backend.OpType.Relu,
+            backend.OpType.Sigmoid,
+            backend.OpType.Tanh,
+            backend.OpType.Softmax,
+            backend.OpType.Abs,
+            backend.OpType.Identity,
+        ]:
+            ctx.push_node(make_node(ty.name, inputs, outputs, name))
+        elif ty == backend.OpType.Flatten:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Reshape:
+            shape = backend.reshape_shape_of(op)
+            inputs.append(
+                ctx.push_data_input(
+                    name,
+                    "shape",
+                    TensorProto.INT32,
+                    [len(shape)],
+                    shape,
+                )
+            )
+            ctx.push_node(make_node(ty.name, inputs, outputs, name))
+        elif ty == backend.OpType.Concat:
+            axis = backend.concat_axis_of(op)
+            ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
+        elif ty == backend.OpType.Gather:
+            axis = backend.gather_axis_of(op)
+            ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
+        elif ty == backend.OpType.ReduceMean:
+            axes = backend.reduce_mean_axes_of(op)
+            inputs.append(
+                ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
+            )
+            ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
+        elif ty == backend.OpType.Slice:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Pad:
+            raise Exception("TODO")
+        else:
+            raise Exception("Unsupported OpType {}".format(ty.name))

+    return ctx.build(name)
+
+
-def parse_onnx(model: onnx.ModelProto):
+def parse_onnx(model: ModelProto):
     print()

     for field in [
@@ -339,34 +518,32 @@ def parse_onnx(model: onnx.ModelProto):
         print(" {}".format(node.name))


-def _parse_attribute(
-    node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
-) -> Dict[str, Any]:
+def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
     for attr in node.attribute:
         if attr.name in attrs:
-            if attr.type == onnx.AttributeProto.INT:
+            if attr.type == AttributeProto.INT:
                 attrs[attr.name] = attr.i
-            elif attr.type == onnx.AttributeProto.INTS:
+            elif attr.type == AttributeProto.INTS:
                 attrs[attr.name] = attr.ints
-            elif attr.type == onnx.AttributeProto.FLOAT:
+            elif attr.type == AttributeProto.FLOAT:
                 attrs[attr.name] = attr.f
-            elif attr.type == onnx.AttributeProto.STRING:
+            elif attr.type == AttributeProto.STRING:
                 attrs[attr.name] = attr.s
-            elif attr.type == onnx.AttributeProto.TENSOR:
+            elif attr.type == AttributeProto.TENSOR:
                 attrs[attr.name] = attr.t
             else:
                 assert False, "Unsupported Attribute Type: {}".format(attr.type)
     return attrs


-def _parse_data(tensor: onnx.TensorProto) -> List[int]:
-    if tensor.data_type == onnx.TensorProto.INT32:
+def _parse_data(tensor: TensorProto) -> List[int]:
+    if tensor.data_type == TensorProto.INT32:
         return [int(i) for i in tensor.int32_data]
-    elif tensor.data_type == onnx.TensorProto.INT64:
+    elif tensor.data_type == TensorProto.INT64:
         return [int(i) for i in tensor.int64_data]
     else:
         assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)


-def _take_shape_dim(shape: onnx.TensorShapeProto) -> List[int]:
+def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
     return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
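With this in place, a round trip looks roughly like the following. This is a sketch, assuming `from_onnx` returns the handler as its new signature suggests; `model.onnx` is a placeholder path, and only the operator subset handled in `to_onnx` above will export so far:

```python
import onnx
from pyinfinitensor.onnx import from_onnx, to_onnx

handler = from_onnx(onnx.load("model.onnx"))  # import into an InfiniTensor graph
exported = to_onnx(handler, "model")          # sort, name, emit, and check
onnx.save(exported, "model.exported.onnx")
```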

View File

@@ -8,7 +8,7 @@ from onnx.helper import (
     make_tensor_value_info,
 )
 from onnx.checker import check_model
-from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
+from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx


 def make_and_import_model(graph: onnx.GraphProto):
@@ -293,11 +293,20 @@ class TestStringMethods(unittest.TestCase):
         parse_onnx(model)

     def test_frontend(self):
-        handler = backend.GraphHandlerObj(runtime)
-        i = handler.tensor([1, 2, 3], 12)
-        w = handler.tensor([1, 3, 4], 12)
-        o = handler.tensor([1, 2, 4], 12)
-        handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
+        handler = backend.GraphHandler(runtime)
+        a = handler.tensor([1, 2, 3], 12)
+        b = handler.tensor([1, 2, 3], 12)
+        c = handler.tensor([1, 2, 3], 12)
+        d = handler.tensor([1, 2, 3], 12)
+        e = handler.tensor([1, 2, 3], 12)
+        x = handler.add(
+            handler.add(handler.add(handler.add(a, b, None), c, None), d, None), e, None
+        )
+        y = handler.tensor([3, 2, 1], 12)
+        handler.reshape(x, y, [3, 2, 1])
+        to_onnx(handler, "test_frontend")


 if __name__ == "__main__":

View File

@@ -1,9 +1,11 @@
 #include "core/graph.h"
+#include <algorithm>
 #include <queue>

 namespace infini {

-GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
+GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
+    : runtime(runtime), sorted(false) {
     map<UidBaseType, Tensor> tensorPool;
     // Clone tensors
     for (const auto &op : ops_in) {
@@ -28,6 +30,7 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
 }

 void GraphObj::addOperatorAndConnect(const Operator &op) {
+    sorted = false;
     ops.push_back(op);
     for (auto &input : op->getInputs()) {
         input->addInputOf(op);
@@ -66,6 +69,53 @@ string GraphObj::toString() const {
     return oss.str();
 }

+bool GraphObj::topo_sort() {
+    if (this->sorted)
+        return true;
+
+    // std::unordered_set<Tensor> inputs;
+    std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
+    std::vector<Operator> sorted;
+
+    while (!waiting.empty()) {
+        // Whether any node was moved to `sorted` during this pass.
+        auto modified = false;
+        // Find head nodes.
+        for (auto it = waiting.begin(); it != waiting.end();) {
+            const auto &this_inputs = (*it)->getInputs();
+            // If none of the input tensors is produced by a node in the
+            // waiting list, this node is a head node.
+            const auto is_head = std::all_of(
+                this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
+                    auto src = input->getOutputOf();
+                    return src // If the source node is still in the waiting
+                               // list, this node is not a head node.
+                               ? waiting.find(src) == waiting.end()
+                               // This tensor has no source node,
+                               // so it must be an input tensor.
+                               : (/*inputs.insert(input),*/ true);
+                });
+            // Moves head node to sorted.
+            if (is_head) {
+                modified = true;
+                sorted.emplace_back(std::move(*it));
+                it = waiting.erase(it);
+            } else {
+                ++it;
+            }
+        }
+        // If the waiting list was not modified during a pass, the
+        // remaining nodes form a ring, so sorting fails.
+        if (!modified) {
+            return false;
+        }
+    }
+
+    // Done.
+    this->ops = std::move(sorted);
+    return this->sorted = true;
+}
+
 void GraphObj::dataMalloc() {
     for (auto &tensor : tensors) {
         tensor->dataMalloc();
@@ -73,15 +123,12 @@ void GraphObj::dataMalloc() {
 }

 Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
-    Tensor tensor = make_ref<TensorObj>(dim, dtype, runtime);
-    tensors.emplace_back(tensor);
-    return tensor;
+    return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime));
 }

 Tensor GraphObj::addTensor(const Tensor &tensor) {
     IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
-    tensors.emplace_back(tensor);
-    return tensor;
+    return tensors.emplace_back(tensor);
 }

 TensorVec GraphObj::addTensor(const TensorVec &tensors) {
@@ -98,4 +145,4 @@ OpVec GraphObj::getComputeOps() const {
     return opList;
 };

 } // namespace infini
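`topo_sort` is a fixed-point iteration: each pass moves every current head node (one whose inputs are all produced outside the waiting set) out of the waiting set, and it fails if a pass moves nothing, which only happens when the remaining nodes form a ring. The same algorithm in Python, for illustration only (the `inputs()`/`src()` accessors are stand-ins, not the project's API):

```python
def topo_sort(ops):
    """Return ops in topological order, or None if the graph has a ring."""
    waiting, result = set(ops), []
    while waiting:
        # Head nodes: no input tensor is produced by a node still waiting.
        # src() is None for graph inputs, and None is never in `waiting`.
        heads = [op for op in waiting
                 if all(t.src() not in waiting for t in op.inputs())]
        if not heads:  # nothing moved this pass: there is a ring
            return None
        for op in heads:
            waiting.remove(op)
            result.append(op)
    return result
```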

View File

@@ -1,4 +1,8 @@
 #include "core/graph_handler.h"
+#include "operators/concat.h"
+#include "operators/gather.h"
+#include "operators/reduce_mean.h"
+#include "operators/reshape.h"
 #include <pybind11/stl.h>

 #ifdef USE_CUDA
#ifdef USE_CUDA #ifdef USE_CUDA
@@ -21,88 +25,151 @@ void register_operator_timer(py::module &m) {
 #endif
 }

+void export_values(py::module &m) {
+#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
+
+    py::enum_<ActType>(m, "ActType")
+        .value("Linear", ActType::None) // `None` is Python keyword
+        .VALUE(ActType, Relu)
+        .VALUE(ActType, Sigmoid)
+        .VALUE(ActType, Tanh)
+        .export_values();
+
+    py::enum_<OpType>(m, "OpType")
+        .VALUE(OpType, Unknown)
+        .VALUE(OpType, Conv)
+        .VALUE(OpType, Matmul)
+        .VALUE(OpType, ConvTrans)
+        .VALUE(OpType, G2BMM)
+        .VALUE(OpType, GBMM)
+        .VALUE(OpType, Pad)
+        .VALUE(OpType, Slice)
+        .VALUE(OpType, Concat)
+        .VALUE(OpType, Split)
+        .VALUE(OpType, Transpose)
+        .VALUE(OpType, Extend)
+        .VALUE(OpType, MaxPool)
+        .VALUE(OpType, AvgPool)
+        .VALUE(OpType, Add)
+        .VALUE(OpType, Sub)
+        .VALUE(OpType, Mul)
+        .VALUE(OpType, Div)
+        .VALUE(OpType, Pow)
+        .VALUE(OpType, Gather)
+        .VALUE(OpType, ReduceMean)
+        .VALUE(OpType, Reshape)
+        .VALUE(OpType, Flatten)
+        .VALUE(OpType, Identity)
+        .VALUE(OpType, BatchNorm)
+        .VALUE(OpType, Softmax)
+        .VALUE(OpType, Activation)
+        .VALUE(OpType, Relu)
+        .VALUE(OpType, Sigmoid)
+        .VALUE(OpType, Tanh)
+        .VALUE(OpType, Abs)
+        .VALUE(OpType, Resize)
+        .VALUE(OpType, MemBound)
+        .export_values();
+
+#undef VALUE
+}
+
+static int tensor_dtype(Tensor t) {
+    if (t->getDType() == DataType::Float32)
+        return OnnxDType::FLOAT;
+    if (t->getDType() == DataType::UInt32)
+        return OnnxDType::UINT32;
+    if (t->getDType() == DataType::UInt8)
+        return OnnxDType::UINT8;
+    if (t->getDType() == DataType::Int8)
+        return OnnxDType::INT8;
+    if (t->getDType() == DataType::UInt16)
+        return OnnxDType::UINT16;
+    if (t->getDType() == DataType::Int16)
+        return OnnxDType::INT16;
+    if (t->getDType() == DataType::Int32)
+        return OnnxDType::INT32;
+    if (t->getDType() == DataType::Int64)
+        return OnnxDType::INT64;
+    IT_ASSERT(false, "Unsupported data type");
+}
+
+static int concat_axis_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Concat);
+    return dynamic_cast<const ConcatObj *>(op.get())->getDim();
+}
+
+static int gather_axis_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Gather);
+    return dynamic_cast<const GatherObj *>(op.get())->getAxis();
+}
+
+static vector<int> reduce_mean_axes_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::ReduceMean);
+    auto &set = dynamic_cast<const ReduceMeanObj *>(op.get())->getAxes();
+    return vector(set.begin(), set.end());
+}
+
+static Shape reshape_shape_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Reshape);
+    return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
+}
+
+void export_functions(py::module &m) {
+#define FUNCTION(NAME) def(#NAME, &NAME)
+    m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
+        .FUNCTION(tensor_dtype)
+        .FUNCTION(reshape_shape_of)
+        .FUNCTION(concat_axis_of)
+        .FUNCTION(gather_axis_of)
+        .FUNCTION(reduce_mean_axes_of);
+#undef FUNCTION
+}
+
 void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;

-    m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
-    py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
+    py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
-        m, "CpuRuntimeObj");
-    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
-    py::enum_<ActType>(m, "ActType")
-        .value("Linear", ActType::None) // `None` is Python keyword
-        .value("Relu", ActType::Relu)
-        .value("Sigmoid", ActType::Sigmoid)
-        .value("Tanh", ActType::Tanh)
-        .export_values();
-    py::class_<Handler>(m, "GraphHandlerObj")
+        m, "CpuRuntime");
+    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
+        .def("shape", &TensorObj::getDims, policy::move)
+        .def("src", &TensorObj::getOutputOf, policy::move);
+    py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
+        .def("op_type", &OperatorObj::getOpType, policy::automatic)
+        .def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
+             policy::reference)
+        .def("outputs",
+             py::overload_cast<>(&OperatorObj::getOutputs, py::const_),
+             policy::reference);
+    py::class_<Handler>(m, "GraphHandler")
         .def(py::init<Runtime>())
-        .def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
-             policy::move)
+        .def("tensor", &Handler::tensor, policy::move)
         .def("conv", &Handler::conv, policy::move)
-        .def("matmul",
-             py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
-                               ActType>(&Handler::matmul),
-             policy::move)
-        .def("batchNorm",
-             py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
-                               float, float, bool>(&Handler::batchNorm),
-             policy::move)
-        .def("maxPool",
-             py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
-                               int, int>(&Handler::maxPool),
-             policy::move)
-        .def("avgPool",
-             py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
-                               int, int>(&Handler::avgPool),
-             policy::move)
-        .def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
-             policy::move)
-        .def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
-             policy::move)
-        .def("mul", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::mul),
-             policy::move)
-        .def("div", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::div),
-             policy::move)
-        .def("pow", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::pow),
-             policy::move)
-        .def("relu", py::overload_cast<Tensor, Tensor>(&Handler::relu),
-             policy::move)
-        .def("sigmoid", py::overload_cast<Tensor, Tensor>(&Handler::sigmoid),
-             policy::move)
-        .def("tanh", py::overload_cast<Tensor, Tensor>(&Handler::tanh),
-             policy::move)
-        .def("softmax", py::overload_cast<Tensor, Tensor>(&Handler::softmax),
-             policy::move)
-        .def("abs", py::overload_cast<Tensor, Tensor>(&Handler::abs),
-             policy::move)
-        .def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
-             policy::move)
-        .def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten),
-             policy::move)
-        .def("reshape",
-             py::overload_cast<Tensor, Tensor, Shape>(&Handler::reshape),
-             policy::move)
-        .def("concat",
-             py::overload_cast<TensorVec, Tensor, int>(&Handler::concat),
-             policy::move)
-        .def("gather",
-             py::overload_cast<Tensor, Tensor, Tensor, int>(&Handler::gather),
-             policy::move)
-        .def("reduceMean",
-             py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
-                               bool>(&Handler::reduceMean),
-             policy::move)
-        .def("slice",
-             py::overload_cast<
-                 Tensor, Tensor, const vector<int> &, const vector<int> &,
-                 const optional<vector<int>> &, const optional<vector<int>> &>(
-                 &Handler::slice),
-             policy::move)
-        .def("pad",
-             py::overload_cast<Tensor, Tensor, const vector<int> &,
-                               const optional<vector<int>> &>(&Handler::pad),
-             policy::move)
+        .def("matmul", &Handler::matmul, policy::move)
+        .def("batchNorm", &Handler::batchNorm, policy::move)
+        .def("maxPool", &Handler::maxPool, policy::move)
+        .def("avgPool", &Handler::avgPool, policy::move)
+        .def("add", &Handler::add, policy::move)
+        .def("sub", &Handler::sub, policy::move)
+        .def("mul", &Handler::mul, policy::move)
+        .def("div", &Handler::div, policy::move)
+        .def("pow", &Handler::pow, policy::move)
+        .def("relu", &Handler::relu, policy::move)
+        .def("sigmoid", &Handler::sigmoid, policy::move)
+        .def("tanh", &Handler::tanh, policy::move)
+        .def("softmax", &Handler::softmax, policy::move)
+        .def("abs", &Handler::abs, policy::move)
+        .def("identity", &Handler::identity, policy::move)
+        .def("flatten", &Handler::flatten, policy::move)
+        .def("reshape", &Handler::reshape, policy::move)
+        .def("concat", &Handler::concat, policy::move)
+        .def("gather", &Handler::gather, policy::move)
+        .def("reduceMean", &Handler::reduceMean, policy::move)
+        .def("slice", &Handler::slice, policy::move)
+        .def("pad", &Handler::pad, policy::move)
+        .def("topo_sort", &Handler::topo_sort, policy::automatic)
+        .def("operators", &Handler::operators, policy::move)
        .def("data_malloc", &Handler::data_malloc, policy::automatic)
        .def("run", &Handler::run, policy::automatic);
 }
@@ -111,5 +178,7 @@ void init_graph_builder(py::module &m) {
 PYBIND11_MODULE(backend, m) {
     infini::register_operator_timer(m);
+    infini::export_values(m);
+    infini::export_functions(m);
     infini::init_graph_builder(m);
 }

View File

@@ -1,6 +1,7 @@
 #include "core/blob.h"
 #include "core/graph.h"
 #include "core/runtime.h"
+#include "operators/element_wise.h"
 #include "operators/matmul.h"
 #include "operators/unary.h"
 #include "test.h"
@@ -36,6 +37,45 @@ TEST(Graph, build_and_run) {
     EXPECT_TRUE(o0->equalData(ans));
 }

+TEST(Graph, topological) {
+    Runtime runtime = CpuRuntimeObj::getInstance();
+    Graph g = make_ref<GraphObj>(runtime);
+    Tensor a = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor b = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor ab = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor c = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor abc = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor d = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor abcd = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor e = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor abcde = g->addTensor({1, 2, 3}, DataType::UInt32);
+
+    auto ops = std::vector{
+        g->addOpWithOutputs<AddObj>(abcd, e, abcde),
+        g->addOpWithOutputs<AddObj>(abc, d, abcd),
+        g->addOpWithOutputs<AddObj>(ab, c, abc),
+        g->addOpWithOutputs<AddObj>(a, b, ab),
+    };
+
+    {
+        auto p = ops.begin();
+        auto q = g->getOperators().begin();
+        while (p != ops.end()) {
+            EXPECT_EQ(*p++, *q++);
+        }
+    }
+
+    EXPECT_TRUE(g->topo_sort());
+
+    {
+        auto p = ops.rbegin();
+        auto q = g->getOperators().begin();
+        while (p != ops.rend()) {
+            EXPECT_EQ(*p++, *q++);
+        }
+    }
+}
+
 TEST(Graph, perf_engine) {
     Runtime runtime = CpuRuntimeObj::getInstance();
     Graph g = make_ref<GraphObj>(runtime);