wip: onnx export (#65)

| Notice | Work in progress |
|--------|------------------|

> based on #63 

## Progress

1. [x] Topologically sort the nodes
2. [x] Traverse the nodes, naming and exporting their output tensors (`make_tensor_value_info`)
3. [x] Identify the graph's input tensors, then name and export them (`make_tensor_value_info`)
4. [x] For each node type, identify its weights and attributes and export the node (`make_node`)
5. [x] `make_graph` -> `check_graph` -> `make_model` -> `check_model` (see the sketch below)
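
As a reference for step 5, the `onnx.helper`/`onnx.checker` calls chain roughly as below. This is a minimal standalone sketch with made-up tensor names and shapes, not the project code:

```python
# Minimal sketch of the make_* / check_* pipeline from step 5.
# Tensor names, shapes and the Add node are illustrative only.
from onnx import TensorProto
from onnx.helper import make_tensor_value_info, make_node, make_graph, make_model
from onnx.checker import check_graph, check_model

a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 2, 3])
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 3])
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 2, 3])
add = make_node("Add", ["a", "b"], ["c"], "Add1")

graph = make_graph([add], "demo", [a, b], [c])
check_graph(graph)
model = make_model(graph)
check_model(model)
```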
Haojie Wang 2023-03-15 15:22:09 +08:00 committed by GitHub
commit dd5d091dbc
9 changed files with 492 additions and 128 deletions

View File

@ -8,13 +8,10 @@ class GraphObj : public Object {
protected:
Runtime runtime;
TensorVec tensors;
// TODO: whether to record input and output tensors
// TensorVec inputs;
// TensorVec outputs;
OpVec ops;
public:
GraphObj(Runtime runtime) : runtime(runtime){};
explicit GraphObj(Runtime runtime) : runtime(runtime), sorted(false){};
GraphObj(Runtime runtime, OpVec ops_in);
string toString() const override;
Runtime getRuntime() const { return runtime; }
@ -23,10 +20,23 @@ class GraphObj : public Object {
Tensor addTensor(const Tensor &tensor);
TensorVec addTensor(const TensorVec &tensors);
Tensor cloneTensor(const Tensor &tensor) {
auto ret = addTensor(tensor->clone(runtime));
return ret;
return addTensor(tensor->clone(runtime));
}
const TensorVec &getTensors() const { return tensors; }
const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const;
/**
* Sort the nodes in topological order.
* It returns true if the sorting succeeds.
* Otherwise it returns false, meaning that the graph contains a cycle,
* so the topological sorting fails.
*/
bool topo_sort();
void dataMalloc();
/**
* @brief Add an operator and create its outputs. Output tensor arguments
* should be empty Refs (e.g., nullptr).
@ -47,25 +57,27 @@ class GraphObj : public Object {
return op;
}
const TensorVec &getTensors() const { return tensors; }
const TensorVec getInputs() const {
/**
* @brief Gets input tensors of this graph.
*/
inline TensorVec getInputs() const {
TensorVec ret;
for (auto t : tensors)
for (const auto &t : tensors)
if (!t->getOutputOf())
ret.emplace_back(t);
return ret;
}
const TensorVec getOutputs() const {
/**
* @brief Gets output tensors of this graph.
*/
inline TensorVec getOutputs() const {
TensorVec ret;
for (auto t : tensors)
for (const auto &t : tensors)
if (t->getInputOf().empty())
ret.emplace_back(t);
return ret;
}
const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const;
void dataMalloc();
private:
/**
@ -73,9 +85,10 @@ class GraphObj : public Object {
*/
void addOperatorAndConnect(const Operator &op);
// TODO: move to another class
// bool exportOnnx(const char *path);
// bool importOnnx(const char *net);
/**
* @brief Whether the nodes are sorted in topological order.
*/
bool sorted;
};
} // namespace infini
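
The new `getInputs()`/`getOutputs()` above follow a producer/consumer rule: a tensor with no producing operator (`getOutputOf()` is null) is a graph input, and a tensor that no operator consumes (`getInputOf()` is empty) is a graph output. A hedged Python sketch of the same rule, using hypothetical `producer`/`consumers` tables rather than the project API:

```python
# Hedged sketch of the producer/consumer rule behind getInputs()/getOutputs().
# `producer` and `consumers` are hypothetical lookup tables, not project APIs.
from typing import Dict, List, Optional

def graph_inputs(tensors: List[str], producer: Dict[str, Optional[str]]) -> List[str]:
    # A tensor nobody produces (getOutputOf() is null) is a graph input.
    return [t for t in tensors if producer.get(t) is None]

def graph_outputs(tensors: List[str], consumers: Dict[str, List[str]]) -> List[str]:
    # A tensor nobody consumes (getInputOf() is empty) is a graph output.
    return [t for t in tensors if not consumers.get(t)]

# Example: c = Add(a, b); a and b are graph inputs, c is the graph output.
ts = ["a", "b", "c"]
prod = {"a": None, "b": None, "c": "Add1"}
cons = {"a": ["Add1"], "b": ["Add1"], "c": []}
assert graph_inputs(ts, prod) == ["a", "b"]
assert graph_outputs(ts, cons) == ["c"]
```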

View File

@ -38,12 +38,14 @@ class GraphHandlerObj {
Tensor tensor(Shape dims, int dtype);
//------ operators
inline OpVec operators() { return g->getOperators(); }
Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
int sh, int sw, int dh, int dw);
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
Tensor bias, ActType act);
Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
Tensor scale, Tensor bias, float momentum, float eps,
bool training);
@ -77,6 +79,10 @@ class GraphHandlerObj {
Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
const optional<vector<int>> &axes);
//------ modifiers
inline bool topo_sort() { return g->topo_sort(); }
//------ runtime
inline void data_malloc() { g->dataMalloc(); }
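
Together with the bindings added later in this commit, the new modifier/runtime entries are meant to be called in order from Python: build the graph, `topo_sort()`, `data_malloc()`, then `run()`. A hedged usage sketch follows; the exact call sequence is an assumption, and dtype code 12 is ONNX `UINT32`, as in the existing tests:

```python
# Hedged usage sketch of the GraphHandler entry points exposed in this commit.
import backend

runtime = backend.cpu_runtime()
handler = backend.GraphHandler(runtime)

a = handler.tensor([1, 2, 3], 12)  # 12 == ONNX UINT32, as used in the tests
b = handler.tensor([1, 2, 3], 12)
c = handler.tensor([1, 2, 3], 12)
handler.add(a, b, c)

assert handler.topo_sort()  # False would mean the graph contains a cycle
handler.data_malloc()       # allocate memory for all tensors
handler.run()               # execute on the CPU runtime
```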

View File

@ -30,6 +30,7 @@ class ReduceMeanObj : public OperatorObj {
int numOutputs() const override { return 1; }
bool isReduced(int idx) const;
const set<int> &getAxes() const { return axes; }
bool getKeepDims() const { return keepDims; }
private:

View File

@ -28,6 +28,8 @@ class ReshapeObj : public OperatorObj {
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
inline Shape getShape() const { return dims; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;

View File

@ -1,17 +1,39 @@
import onnx, backend
import backend
from onnx import (
ModelProto,
TensorProto,
NodeProto,
AttributeProto,
TensorShapeProto,
ValueInfoProto,
)
from onnx.helper import (
make_node,
make_tensor_value_info,
make_tensor,
make_graph,
make_model,
)
from onnx.checker import (
check_graph,
check_model,
check_node,
check_value_info,
check_tensor,
)
from onnx.shape_inference import infer_shapes
from typing import Dict, List, Any
from typing import Dict, List, Any, Tuple, Sequence
from functools import reduce
runtime = backend.cpu_runtime()
def from_onnx(model: onnx.ModelProto):
def from_onnx(model: ModelProto) -> backend.GraphHandler:
model = infer_shapes(model)
handler = backend.GraphHandlerObj(runtime)
handler = backend.GraphHandler(runtime)
tensors: Dict[str, backend.TensorObj] = dict()
data: Dict[str, onnx.TensorProto] = dict()
tensors: Dict[str, backend.Tensor] = dict()
data: Dict[str, TensorProto] = dict()
for input in model.graph.input:
dims = _take_shape_dim(input.type.tensor_type.shape)
@ -303,7 +325,164 @@ def from_onnx(model: onnx.ModelProto):
raise Exception('Unsupported operator "{}"'.format(node.op_type))
def parse_onnx(model: onnx.ModelProto):
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
class Context:
# saves object names, including tensors and operators
names: Dict[Any, str] = dict()
# counts the occurrence times of each operator for naming
count_op: Dict[backend.OpType, int] = dict()
# counts input and output tensors for naming
count_in, count_out = 0, 0
# saves nodes (operators)
nodes: List[NodeProto] = []
# saves global input tensors
inputs: List[ValueInfoProto] = []
# saves global output tensors
outputs: List[ValueInfoProto] = []
# saves constant tensors (initializers)
initializers: List[TensorProto] = []
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
ty = op.op_type()
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
self.names[op] = name
self.count_op[ty] += 1
return ty, name
def push_output(self, name: str, tensor: backend.Tensor) -> str:
self.names[tensor] = name
return name
def push_input(self, tensor: backend.Tensor) -> str:
name = self.names.get(tensor)
# means that this input is a global input
if name is None:
self.count_in += 1
name = "input{}".format(self.count_in)
self.names[tensor] = name
shape = tensor.shape()
dtype = backend.tensor_dtype(tensor)
value_info = make_tensor_value_info(name, dtype, shape)
check_value_info(value_info)
self.inputs.append(value_info)
return name
def push_data_input(
self,
node_name: str,
attr_name: str,
elem_type: int,
shape: Sequence[int],
vals: Any,
) -> str:
name = "{}_{}".format(node_name, attr_name)
value_info = make_tensor_value_info(name, elem_type, shape)
tensor = make_tensor(name, elem_type, shape, vals)
check_value_info(value_info)
check_tensor(tensor)
self.inputs.append(value_info)
self.initializers.append(tensor)
return name
def push_node(self, node: NodeProto) -> None:
check_node(node)
self.nodes.append(node)
def build(self, name: str) -> ModelProto:
print()
print(ctx.names)
print()
print(ctx.inputs)
print()
print(ctx.outputs)
print()
print(ctx.nodes)
graph = make_graph(
self.nodes, name, self.inputs, self.outputs, self.initializers
)
check_graph(graph)
model = make_model(graph)
check_model(model)
return model
# Topological sort
if not graph.topo_sort():
raise Exception("Sorting fails")
ops = graph.operators()  # all operators (nodes) in the graph
ctx = Context()
for op in ops:
ty, name = ctx.name_op(op)
inputs = [ctx.push_input(it) for it in op.inputs()]
outputs = [
ctx.push_output("{}_{}".format(name, i), it)
for (i, it) in enumerate(op.outputs())
]
if ty == backend.OpType.Matmul:
ctx.push_node(make_node("MatMul", inputs, outputs, name))
elif ty == backend.OpType.BatchNorm:
raise Exception("TODO")
elif ty == backend.OpType.MaxPool:
raise Exception("TODO")
elif ty == backend.OpType.AvgPool:
raise Exception("TODO")
elif ty in [
backend.OpType.Add,
backend.OpType.Sub,
backend.OpType.Mul,
backend.OpType.Div,
backend.OpType.Pow,
backend.OpType.Relu,
backend.OpType.Sigmoid,
backend.OpType.Tanh,
backend.OpType.Softmax,
backend.OpType.Abs,
backend.OpType.Identity,
]:
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Flatten:
raise Exception("TODO")
elif ty == backend.OpType.Reshape:
shape = backend.reshape_shape_of(op)
inputs.append(
ctx.push_data_input(
name,
"shape",
TensorProto.INT32,
[len(shape)],
shape,
)
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Concat:
axis = backend.concat_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.Gather:
axis = backend.gather_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.ReduceMean:
axes = backend.reduce_mean_axes_of(op)
inputs.append(
ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
)
ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
elif ty == backend.OpType.Slice:
raise Exception("TODO")
elif ty == backend.OpType.Pad:
raise Exception("TODO")
else:
raise Exception("Unsupported OpType {}".format(ty.name))
return ctx.build(name)
def parse_onnx(model: ModelProto):
print()
for field in [
@ -339,34 +518,32 @@ def parse_onnx(model: onnx.ModelProto):
print(" {}".format(node.name))
def _parse_attribute(
node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
) -> Dict[str, Any]:
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
for attr in node.attribute:
if attr.name in attrs:
if attr.type == onnx.AttributeProto.INT:
if attr.type == AttributeProto.INT:
attrs[attr.name] = attr.i
elif attr.type == onnx.AttributeProto.INTS:
elif attr.type == AttributeProto.INTS:
attrs[attr.name] = attr.ints
elif attr.type == onnx.AttributeProto.FLOAT:
elif attr.type == AttributeProto.FLOAT:
attrs[attr.name] = attr.f
elif attr.type == onnx.AttributeProto.STRING:
elif attr.type == AttributeProto.STRING:
attrs[attr.name] = attr.s
elif attr.type == onnx.AttributeProto.TENSOR:
elif attr.type == AttributeProto.TENSOR:
attrs[attr.name] = attr.t
else:
assert False, "Unsupported Attribute Type: {}".format(attr.type)
return attrs
def _parse_data(tensor: onnx.TensorProto) -> List[int]:
if tensor.data_type == onnx.TensorProto.INT32:
def _parse_data(tensor: TensorProto) -> List[int]:
if tensor.data_type == TensorProto.INT32:
return [int(i) for i in tensor.int32_data]
elif tensor.data_type == onnx.TensorProto.INT64:
elif tensor.data_type == TensorProto.INT64:
return [int(i) for i in tensor.int64_data]
else:
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
def _take_shape_dim(shape: onnx.TensorShapeProto) -> List[int]:
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
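
With both directions in place, the intended round trip is: load an ONNX model, import it with `from_onnx`, and re-export it with `to_onnx`. A hedged sketch ("model.onnx" is a placeholder path; several operator branches above still raise "TODO", so only the supported ops survive the trip):

```python
# Hedged round-trip sketch: ONNX file -> InfiniTensor graph -> ONNX model.
import onnx
from pyinfinitensor.onnx import from_onnx, to_onnx

model = onnx.load("model.onnx")          # placeholder path
handler = from_onnx(model)               # builds a backend.GraphHandler
exported = to_onnx(handler, "exported")  # `name` becomes the ONNX graph name
onnx.save(exported, "exported.onnx")
```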

View File

@ -8,7 +8,7 @@ from onnx.helper import (
make_tensor_value_info,
)
from onnx.checker import check_model
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx
def make_and_import_model(graph: onnx.GraphProto):
@ -293,11 +293,20 @@ class TestStringMethods(unittest.TestCase):
parse_onnx(model)
def test_frontend(self):
handler = backend.GraphHandlerObj(runtime)
i = handler.tensor([1, 2, 3], 12)
w = handler.tensor([1, 3, 4], 12)
o = handler.tensor([1, 2, 4], 12)
handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
handler = backend.GraphHandler(runtime)
a = handler.tensor([1, 2, 3], 12)
b = handler.tensor([1, 2, 3], 12)
c = handler.tensor([1, 2, 3], 12)
d = handler.tensor([1, 2, 3], 12)
e = handler.tensor([1, 2, 3], 12)
x = handler.add(
handler.add(handler.add(handler.add(a, b, None), c, None), d, None), e, None
)
y = handler.tensor([3, 2, 1], 12)
handler.reshape(x, y, [3, 2, 1])
to_onnx(handler, "test_frontend")
if __name__ == "__main__":

View File

@ -1,9 +1,11 @@
#include "core/graph.h"
#include <algorithm>
#include <queue>
namespace infini {
GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
: runtime(runtime), sorted(false) {
map<UidBaseType, Tensor> tensorPool;
// Clone tensors
for (const auto &op : ops_in) {
@ -28,6 +30,7 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
}
void GraphObj::addOperatorAndConnect(const Operator &op) {
sorted = false;
ops.push_back(op);
for (auto &input : op->getInputs()) {
input->addInputOf(op);
@ -66,6 +69,53 @@ string GraphObj::toString() const {
return oss.str();
}
bool GraphObj::topo_sort() {
if (this->sorted)
return true;
// std::unordered_set<Tensor> inputs;
std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
std::vector<Operator> sorted;
while (!waiting.empty()) {
// Whether any node was moved to sorted during this pass.
auto modified = false;
// Find head nodes.
for (auto it = waiting.begin(); it != waiting.end();) {
const auto &this_inputs = (*it)->getInputs();
// If none of the input tensors is in waiting list,
// this node is a head node.
const auto is_head = std::all_of(
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
auto src = input->getOutputOf();
return src // If the source node is still in the waiting list,
// this node is not a head node.
? waiting.find(src) == waiting.end()
// This tensor has no source node,
// so it must be a graph input tensor.
: (/*inputs.insert(input),*/ true);
});
// Moves head node to sorted.
if (is_head) {
modified = true;
sorted.emplace_back(std::move(*it));
it = waiting.erase(it);
} else {
++it;
}
}
// The waiting list was not modified during this pass,
// so sorting fails (the graph contains a cycle).
if (!modified) {
return false;
}
}
// Done.
this->ops = std::move(sorted);
return this->sorted = true;
}
void GraphObj::dataMalloc() {
for (auto &tensor : tensors) {
tensor->dataMalloc();
@ -73,15 +123,12 @@ void GraphObj::dataMalloc() {
}
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
Tensor tensor = make_ref<TensorObj>(dim, dtype, runtime);
tensors.emplace_back(tensor);
return tensor;
return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime));
}
Tensor GraphObj::addTensor(const Tensor &tensor) {
IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
tensors.emplace_back(tensor);
return tensor;
return tensors.emplace_back(tensor);
}
TensorVec GraphObj::addTensor(const TensorVec &tensors) {
@ -98,4 +145,4 @@ OpVec GraphObj::getComputeOps() const {
return opList;
};
} // namespace infini
} // namespace infini
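
`GraphObj::topo_sort` above sweeps the waiting set repeatedly, moving every "head" node (one whose inputs are not produced by any node still waiting) into the sorted list; if a whole pass moves nothing, the graph has a cycle. The same idea in Python, with hypothetical `op_inputs`/`producer_of` tables standing in for `getInputs()`/`getOutputOf()`:

```python
# Hedged Python sketch of the head-node sweep used by GraphObj::topo_sort.
# `op_inputs` and `producer_of` are hypothetical lookup tables, not project APIs.
from typing import Dict, List, Optional

def topo_sort(ops: List[str], op_inputs: Dict[str, List[str]],
              producer_of: Dict[str, Optional[str]]) -> Optional[List[str]]:
    waiting, ordered = set(ops), []
    while waiting:
        moved = False
        for op in list(waiting):
            # Head node: every input is either a graph input (no producer)
            # or produced by a node that is no longer waiting.
            if all(producer_of.get(t) not in waiting for t in op_inputs[op]):
                ordered.append(op)
                waiting.remove(op)
                moved = True
        if not moved:  # a full pass moved nothing: the graph contains a cycle
            return None
    return ordered

# Example in the spirit of the unit test below: chained Adds declared in reverse.
ops = ["Add4", "Add3", "Add2", "Add1"]
ins = {"Add1": ["a", "b"], "Add2": ["ab", "c"],
       "Add3": ["abc", "d"], "Add4": ["abcd", "e"]}
prod = {"ab": "Add1", "abc": "Add2", "abcd": "Add3", "abcde": "Add4"}
print(topo_sort(ops, ins, prod))  # ['Add1', 'Add2', 'Add3', 'Add4']
```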

View File

@ -1,4 +1,8 @@
#include "core/graph_handler.h"
#include "operators/concat.h"
#include "operators/gather.h"
#include "operators/reduce_mean.h"
#include "operators/reshape.h"
#include <pybind11/stl.h>
#ifdef USE_CUDA
@ -21,88 +25,151 @@ void register_operator_timer(py::module &m) {
#endif
}
void export_values(py::module &m) {
#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
py::enum_<ActType>(m, "ActType")
.value("Linear", ActType::None) // `None` is Python keyword
.VALUE(ActType, Relu)
.VALUE(ActType, Sigmoid)
.VALUE(ActType, Tanh)
.export_values();
py::enum_<OpType>(m, "OpType")
.VALUE(OpType, Unknown)
.VALUE(OpType, Conv)
.VALUE(OpType, Matmul)
.VALUE(OpType, ConvTrans)
.VALUE(OpType, G2BMM)
.VALUE(OpType, GBMM)
.VALUE(OpType, Pad)
.VALUE(OpType, Slice)
.VALUE(OpType, Concat)
.VALUE(OpType, Split)
.VALUE(OpType, Transpose)
.VALUE(OpType, Extend)
.VALUE(OpType, MaxPool)
.VALUE(OpType, AvgPool)
.VALUE(OpType, Add)
.VALUE(OpType, Sub)
.VALUE(OpType, Mul)
.VALUE(OpType, Div)
.VALUE(OpType, Pow)
.VALUE(OpType, Gather)
.VALUE(OpType, ReduceMean)
.VALUE(OpType, Reshape)
.VALUE(OpType, Flatten)
.VALUE(OpType, Identity)
.VALUE(OpType, BatchNorm)
.VALUE(OpType, Softmax)
.VALUE(OpType, Activation)
.VALUE(OpType, Relu)
.VALUE(OpType, Sigmoid)
.VALUE(OpType, Tanh)
.VALUE(OpType, Abs)
.VALUE(OpType, Resize)
.VALUE(OpType, MemBound)
.export_values();
#undef VALUE
}
static int tensor_dtype(Tensor t) {
if (t->getDType() == DataType::Float32)
return OnnxDType::FLOAT;
if (t->getDType() == DataType::UInt32)
return OnnxDType::UINT32;
if (t->getDType() == DataType::UInt8)
return OnnxDType::UINT8;
if (t->getDType() == DataType::Int8)
return OnnxDType::INT8;
if (t->getDType() == DataType::UInt16)
return OnnxDType::UINT16;
if (t->getDType() == DataType::Int16)
return OnnxDType::INT16;
if (t->getDType() == DataType::Int32)
return OnnxDType::INT32;
if (t->getDType() == DataType::Int64)
return OnnxDType::INT64;
IT_ASSERT(false, "Unsupported data type");
}
static int concat_axis_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Concat);
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
}
static int gather_axis_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Gather);
return dynamic_cast<const GatherObj *>(op.get())->getAxis();
}
static vector<int> reduce_mean_axes_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::ReduceMean);
auto &set = dynamic_cast<const ReduceMeanObj *>(op.get())->getAxes();
return vector(set.begin(), set.end());
}
static Shape reshape_shape_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Reshape);
return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
}
void export_functions(py::module &m) {
#define FUNCTION(NAME) def(#NAME, &NAME)
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
.FUNCTION(tensor_dtype)
.FUNCTION(reshape_shape_of)
.FUNCTION(concat_axis_of)
.FUNCTION(gather_axis_of)
.FUNCTION(reduce_mean_axes_of);
#undef FUNCTION
}
void init_graph_builder(py::module &m) {
using Handler = GraphHandlerObj;
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
m, "CpuRuntimeObj");
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
py::enum_<ActType>(m, "ActType")
.value("Linear", ActType::None) // `None` is Python keyword
.value("Relu", ActType::Relu)
.value("Sigmoid", ActType::Sigmoid)
.value("Tanh", ActType::Tanh)
.export_values();
py::class_<Handler>(m, "GraphHandlerObj")
m, "CpuRuntime");
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
.def("shape", &TensorObj::getDims, policy::move)
.def("src", &TensorObj::getOutputOf, policy::move);
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
.def("op_type", &OperatorObj::getOpType, policy::automatic)
.def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
policy::reference)
.def("outputs",
py::overload_cast<>(&OperatorObj::getOutputs, py::const_),
policy::reference);
py::class_<Handler>(m, "GraphHandler")
.def(py::init<Runtime>())
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
policy::move)
.def("tensor", &Handler::tensor, policy::move)
.def("conv", &Handler::conv, policy::move)
.def("matmul",
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
ActType>(&Handler::matmul),
policy::move)
.def("batchNorm",
py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
float, float, bool>(&Handler::batchNorm),
policy::move)
.def("maxPool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&Handler::maxPool),
policy::move)
.def("avgPool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&Handler::avgPool),
policy::move)
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
policy::move)
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
policy::move)
.def("mul", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::mul),
policy::move)
.def("div", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::div),
policy::move)
.def("pow", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::pow),
policy::move)
.def("relu", py::overload_cast<Tensor, Tensor>(&Handler::relu),
policy::move)
.def("sigmoid", py::overload_cast<Tensor, Tensor>(&Handler::sigmoid),
policy::move)
.def("tanh", py::overload_cast<Tensor, Tensor>(&Handler::tanh),
policy::move)
.def("softmax", py::overload_cast<Tensor, Tensor>(&Handler::softmax),
policy::move)
.def("abs", py::overload_cast<Tensor, Tensor>(&Handler::abs),
policy::move)
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
policy::move)
.def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten),
policy::move)
.def("reshape",
py::overload_cast<Tensor, Tensor, Shape>(&Handler::reshape),
policy::move)
.def("concat",
py::overload_cast<TensorVec, Tensor, int>(&Handler::concat),
policy::move)
.def("gather",
py::overload_cast<Tensor, Tensor, Tensor, int>(&Handler::gather),
policy::move)
.def("reduceMean",
py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
bool>(&Handler::reduceMean),
policy::move)
.def("slice",
py::overload_cast<
Tensor, Tensor, const vector<int> &, const vector<int> &,
const optional<vector<int>> &, const optional<vector<int>> &>(
&Handler::slice),
policy::move)
.def("pad",
py::overload_cast<Tensor, Tensor, const vector<int> &,
const optional<vector<int>> &>(&Handler::pad),
policy::move)
.def("matmul", &Handler::matmul, policy::move)
.def("batchNorm", &Handler::batchNorm, policy::move)
.def("maxPool", &Handler::maxPool, policy::move)
.def("avgPool", &Handler::avgPool, policy::move)
.def("add", &Handler::add, policy::move)
.def("sub", &Handler::sub, policy::move)
.def("mul", &Handler::mul, policy::move)
.def("div", &Handler::div, policy::move)
.def("pow", &Handler::pow, policy::move)
.def("relu", &Handler::relu, policy::move)
.def("sigmoid", &Handler::sigmoid, policy::move)
.def("tanh", &Handler::tanh, policy::move)
.def("softmax", &Handler::softmax, policy::move)
.def("abs", &Handler::abs, policy::move)
.def("identity", &Handler::identity, policy::move)
.def("flatten", &Handler::flatten, policy::move)
.def("reshape", &Handler::reshape, policy::move)
.def("concat", &Handler::concat, policy::move)
.def("gather", &Handler::gather, policy::move)
.def("reduceMean", &Handler::reduceMean, policy::move)
.def("slice", &Handler::slice, policy::move)
.def("pad", &Handler::pad, policy::move)
.def("topo_sort", &Handler::topo_sort, policy::automatic)
.def("operators", &Handler::operators, policy::move)
.def("data_malloc", &Handler::data_malloc, policy::automatic)
.def("run", &Handler::run, policy::automatic);
}
@ -111,5 +178,7 @@ void init_graph_builder(py::module &m) {
PYBIND11_MODULE(backend, m) {
infini::register_operator_timer(m);
infini::export_values(m);
infini::export_functions(m);
infini::init_graph_builder(m);
}

View File

@ -1,6 +1,7 @@
#include "core/blob.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/element_wise.h"
#include "operators/matmul.h"
#include "operators/unary.h"
#include "test.h"
@ -36,6 +37,45 @@ TEST(Graph, build_and_run) {
EXPECT_TRUE(o0->equalData(ans));
}
TEST(Graph, topological) {
Runtime runtime = CpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(runtime);
Tensor a = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor b = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor ab = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor c = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor abc = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor d = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor abcd = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor e = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor abcde = g->addTensor({1, 2, 3}, DataType::UInt32);
auto ops = std::vector{
g->addOpWithOutputs<AddObj>(abcd, e, abcde),
g->addOpWithOutputs<AddObj>(abc, d, abcd),
g->addOpWithOutputs<AddObj>(ab, c, abc),
g->addOpWithOutputs<AddObj>(a, b, ab),
};
{
auto p = ops.begin();
auto q = g->getOperators().begin();
while (p != ops.end()) {
EXPECT_EQ(*p++, *q++);
}
}
EXPECT_TRUE(g->topo_sort());
{
auto p = ops.rbegin();
auto q = g->getOperators().begin();
while (p != ops.rend()) {
EXPECT_EQ(*p++, *q++);
}
}
}
TEST(Graph, perf_engine) {
Runtime runtime = CpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(runtime);