diff --git a/include/core/graph.h b/include/core/graph.h
index 4ce8697a..8e317a8b 100644
--- a/include/core/graph.h
+++ b/include/core/graph.h
@@ -8,13 +8,10 @@ class GraphObj : public Object {
   protected:
     Runtime runtime;
     TensorVec tensors;
-    // TODO: whether to record input and output tensors
-    // TensorVec inputs;
-    // TensorVec outputs;
     OpVec ops;
 
   public:
-    GraphObj(Runtime runtime) : runtime(runtime){};
+    explicit GraphObj(Runtime runtime) : runtime(runtime), sorted(false){};
     GraphObj(Runtime runtime, OpVec ops_in);
     string toString() const override;
     Runtime getRuntime() const { return runtime; }
@@ -23,10 +20,23 @@ class GraphObj : public Object {
     Tensor addTensor(const Tensor &tensor);
     TensorVec addTensor(const TensorVec &tensors);
     Tensor cloneTensor(const Tensor &tensor) {
-        auto ret = addTensor(tensor->clone(runtime));
-        return ret;
+        return addTensor(tensor->clone(runtime));
     }
 
+    const TensorVec &getTensors() const { return tensors; }
+    const OpVec &getOperators() const { return ops; }
+    OpVec getComputeOps() const;
+
+    /**
+     * Sort the operators in topological order.
+     * Returns true if the sorting succeeds.
+     * Returns false if the graph contains a cycle,
+     * in which case topological sorting is impossible.
+     */
+    bool topo_sort();
+
+    void dataMalloc();
+
     /**
      * @brief Add an operator and create its outputs. Output tensor arguments
      * should be empty Refs (e.g., nullptr).
@@ -47,25 +57,27 @@ class GraphObj : public Object {
         return op;
     }
 
-    const TensorVec &getTensors() const { return tensors; }
-    const TensorVec getInputs() const {
+    /**
+     * @brief Gets input tensors of this graph.
+     */
+    inline TensorVec getInputs() const {
         TensorVec ret;
-        for (auto t : tensors)
+        for (const auto &t : tensors)
             if (!t->getOutputOf())
                 ret.emplace_back(t);
         return ret;
     }
-    const TensorVec getOutputs() const {
+
+    /**
+     * @brief Gets output tensors of this graph.
+     */
+    inline TensorVec getOutputs() const {
         TensorVec ret;
-        for (auto t : tensors)
+        for (const auto &t : tensors)
             if (t->getInputOf().empty())
                 ret.emplace_back(t);
         return ret;
     }
-    const OpVec &getOperators() const { return ops; }
-    OpVec getComputeOps() const;
-
-    void dataMalloc();
 
   private:
     /**
@@ -73,9 +85,10 @@
      */
    void addOperatorAndConnect(const Operator &op);
 
-    // TODO: move to another class
-    // bool exportOnnx(const char *path);
-    // bool importOnnx(const char *net);
+    /**
+     * @brief Whether the operators are sorted in topological order.
+     */
+    bool sorted;
 };
 
 } // namespace infini
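Editor's note: the new `getInputs()`/`getOutputs()` helpers above classify a tensor by whether it has a producer (`getOutputOf`) or any consumers (`getInputOf`). A minimal pure-Python sketch of that rule, using a hypothetical producer/consumer mapping rather than the real GraphObj API:

```python
# Sketch of the getInputs()/getOutputs() rule from graph.h:
# a tensor with no producing operator is a graph input, and a tensor
# with no consuming operators is a graph output. The dicts below are a
# hypothetical stand-in for GraphObj's bookkeeping, not the C++ API.

def graph_inputs(producers, consumers):
    """Tensors that no operator produces."""
    return [t for t in consumers if producers.get(t) is None]

def graph_outputs(producers, consumers):
    """Tensors that no operator consumes."""
    return [t for t in producers if not consumers.get(t)]

if __name__ == "__main__":
    # x --matmul--> y --relu--> z
    producers = {"x": None, "y": "matmul", "z": "relu"}
    consumers = {"x": ["matmul"], "y": ["relu"], "z": []}
    assert graph_inputs(producers, consumers) == ["x"]
    assert graph_outputs(producers, consumers) == ["z"]
```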
diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index b74f8e13..dc221042 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -38,12 +38,14 @@ class GraphHandlerObj {
 
     Tensor tensor(Shape dims, int dtype);
 
+    //------ operators
+
+    inline OpVec operators() { return g->getOperators(); }
+
     Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
                 int sh, int sw, int dh, int dw);
-
     Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
                   Tensor bias, ActType act);
-
     Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
                      Tensor scale, Tensor bias, float momentum, float eps,
                      bool training);
@@ -77,6 +79,10 @@ class GraphHandlerObj {
     Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
                const optional<vector<int>> &axes);
 
+    //------ modifiers
+
+    inline bool topo_sort() { return g->topo_sort(); }
+
     //------ runtime
 
     inline void data_malloc() { g->dataMalloc(); }
diff --git a/include/operators/reduce_mean.h b/include/operators/reduce_mean.h
index 76d3454b..ef74cd2e 100644
--- a/include/operators/reduce_mean.h
+++ b/include/operators/reduce_mean.h
@@ -30,6 +30,7 @@ class ReduceMeanObj : public OperatorObj {
     int numOutputs() const override { return 1; }
 
     bool isReduced(int idx) const;
+    const set<int> &getAxes() const { return axes; }
     bool getKeepDims() const { return keepDims; }
 
   private:
diff --git a/include/operators/reshape.h b/include/operators/reshape.h
index 66fb1bda..31cc3576 100644
--- a/include/operators/reshape.h
+++ b/include/operators/reshape.h
@@ -28,6 +28,8 @@ class ReshapeObj : public OperatorObj {
     int numInputs() const override { return 1; }
     int numOutputs() const override { return 1; }
 
+    inline Shape getShape() const { return dims; }
+
   private:
     vector<int> getWorkloadVector() const override;
     vector<int> getOpAttrVector() const override;
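Editor's note: the handler-level additions above (`operators()` and `topo_sort()`), together with the bindings added later in this diff, give Python a small driver loop. A hedged sketch, assuming the compiled `backend` extension from this PR is importable; dtype code 12 mirrors the values used by the tests:

```python
# Hedged usage sketch of the new GraphHandler surface from Python.
import backend

runtime = backend.cpu_runtime()
handler = backend.GraphHandler(runtime)

a = handler.tensor([1, 2, 3], 12)
b = handler.tensor([1, 2, 3], 12)
c = handler.tensor([1, 2, 3], 12)
handler.add(a, b, c)                # record an Add node

assert handler.topo_sort()          # delegates to GraphObj::topo_sort
for op in handler.operators():      # operators are now in dependency order
    print(op.op_type().name)

handler.data_malloc()               # allocate tensor storage
# handler.run()                     # execute once input data is filled in
```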
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index a429368f..ce315685 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -1,17 +1,39 @@
-import onnx, backend
+import backend
+from onnx import (
+    ModelProto,
+    TensorProto,
+    NodeProto,
+    AttributeProto,
+    TensorShapeProto,
+    ValueInfoProto,
+)
+from onnx.helper import (
+    make_node,
+    make_tensor_value_info,
+    make_tensor,
+    make_graph,
+    make_model,
+)
+from onnx.checker import (
+    check_graph,
+    check_model,
+    check_node,
+    check_value_info,
+    check_tensor,
+)
 from onnx.shape_inference import infer_shapes
-from typing import Dict, List, Any
+from typing import Dict, List, Any, Tuple, Sequence
 from functools import reduce
 
 runtime = backend.cpu_runtime()
 
 
-def from_onnx(model: onnx.ModelProto):
+def from_onnx(model: ModelProto) -> backend.GraphHandler:
     model = infer_shapes(model)
-    handler = backend.GraphHandlerObj(runtime)
+    handler = backend.GraphHandler(runtime)
 
-    tensors: Dict[str, backend.TensorObj] = dict()
-    data: Dict[str, onnx.TensorProto] = dict()
+    tensors: Dict[str, backend.Tensor] = dict()
+    data: Dict[str, TensorProto] = dict()
 
     for input in model.graph.input:
         dims = _take_shape_dim(input.type.tensor_type.shape)
@@ -303,7 +325,164 @@ def from_onnx(model: onnx.ModelProto):
             raise Exception('Unsupported operator "{}"'.format(node.op_type))
 
 
-def parse_onnx(model: onnx.ModelProto):
+def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
+    class Context:
+        # saves object names, including tensors and operators
+        names: Dict[Any, str] = dict()
+        # counts occurrences of each operator type for naming
+        count_op: Dict[backend.OpType, int] = dict()
+        # counts input and output tensors for naming
+        count_in, count_out = 0, 0
+        # saves nodes (operators)
+        nodes: List[NodeProto] = []
+        # saves global input tensors
+        inputs: List[ValueInfoProto] = []
+        # saves global output tensors
+        outputs: List[ValueInfoProto] = []
+        # saves graph initializers (constant data tensors)
+        initializers: List[TensorProto] = []
+
+        def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
+            ty = op.op_type()
+            name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
+            self.names[op] = name
+            self.count_op[ty] += 1
+            return ty, name
+
+        def push_output(self, name: str, tensor: backend.Tensor) -> str:
+            self.names[tensor] = name
+            return name
+
+        def push_input(self, tensor: backend.Tensor) -> str:
+            name = self.names.get(tensor)
+            # a tensor that has not been named yet is a global input
+            if name is None:
+                self.count_in += 1
+                name = "input{}".format(self.count_in)
+                self.names[tensor] = name
+                shape = tensor.shape()
+                dtype = backend.tensor_dtype(tensor)
+                value_info = make_tensor_value_info(name, dtype, shape)
+                check_value_info(value_info)
+                self.inputs.append(value_info)
+
+            return name
+
+        def push_data_input(
+            self,
+            node_name: str,
+            attr_name: str,
+            elem_type: int,
+            shape: Sequence[int],
+            vals: Any,
+        ) -> str:
+            name = "{}_{}".format(node_name, attr_name)
+            value_info = make_tensor_value_info(name, elem_type, shape)
+            tensor = make_tensor(name, elem_type, shape, vals)
+            check_value_info(value_info)
+            check_tensor(tensor)
+            self.inputs.append(value_info)
+            self.initializers.append(tensor)
+            return name
+
+        def push_node(self, node: NodeProto) -> None:
+            check_node(node)
+            self.nodes.append(node)
+
+        def build(self, name: str) -> ModelProto:
+            print()
+            print(ctx.names)
+            print()
+            print(ctx.inputs)
+            print()
+            print(ctx.outputs)
+            print()
+            print(ctx.nodes)
+
+            graph = make_graph(
+                self.nodes, name, self.inputs, self.outputs, self.initializers
+            )
+            check_graph(graph)
+
+            model = make_model(graph)
+            check_model(model)
+
+            return model
+
+    # topological sort
+    if not graph.topo_sort():
+        raise Exception("Topological sort failed: the graph may contain a cycle")
+
+    ops = graph.operators()  # all operators (nodes) in the graph
+
+    ctx = Context()
+
+    for op in ops:
+        ty, name = ctx.name_op(op)
+        inputs = [ctx.push_input(it) for it in op.inputs()]
+        outputs = [
+            ctx.push_output("{}_{}".format(name, i), it)
+            for (i, it) in enumerate(op.outputs())
+        ]
+        if ty == backend.OpType.Matmul:
+            ctx.push_node(make_node("MatMul", inputs, outputs, name))
+        elif ty == backend.OpType.BatchNorm:
+            raise Exception("TODO")
+        elif ty == backend.OpType.MaxPool:
+            raise Exception("TODO")
+        elif ty == backend.OpType.AvgPool:
+            raise Exception("TODO")
+        elif ty in [
+            backend.OpType.Add,
+            backend.OpType.Sub,
+            backend.OpType.Mul,
+            backend.OpType.Div,
+            backend.OpType.Pow,
+            backend.OpType.Relu,
+            backend.OpType.Sigmoid,
+            backend.OpType.Tanh,
+            backend.OpType.Softmax,
+            backend.OpType.Abs,
+            backend.OpType.Identity,
+        ]:
+            ctx.push_node(make_node(ty.name, inputs, outputs, name))
+        elif ty == backend.OpType.Flatten:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Reshape:
+            shape = backend.reshape_shape_of(op)
+            inputs.append(
+                ctx.push_data_input(
+                    name,
+                    "shape",
+                    TensorProto.INT32,
+                    [len(shape)],
+                    shape,
+                )
+            )
+            ctx.push_node(make_node(ty.name, inputs, outputs, name))
+        elif ty == backend.OpType.Concat:
+            axis = backend.concat_axis_of(op)
+            ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
+        elif ty == backend.OpType.Gather:
+            axis = backend.gather_axis_of(op)
+            ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
+        elif ty == backend.OpType.ReduceMean:
+            axes = backend.reduce_mean_axes_of(op)
+            inputs.append(
+                ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
+            )
+            ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
+        elif ty == backend.OpType.Slice:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Pad:
+            raise Exception("TODO")
+        else:
+            raise Exception("Unsupported OpType {}".format(ty.name))
+
+    return ctx.build(name)
+
+
+def parse_onnx(model: ModelProto):
     print()
 
     for field in [
@@ -339,34 +518,32 @@ def parse_onnx(model: onnx.ModelProto):
         print("  {}".format(node.name))
 
 
-def _parse_attribute(
-    node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
-) -> Dict[str, Any]:
+def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
     for attr in node.attribute:
         if attr.name in attrs:
-            if attr.type == onnx.AttributeProto.INT:
+            if attr.type == AttributeProto.INT:
                 attrs[attr.name] = attr.i
-            elif attr.type == onnx.AttributeProto.INTS:
+            elif attr.type == AttributeProto.INTS:
                 attrs[attr.name] = attr.ints
-            elif attr.type == onnx.AttributeProto.FLOAT:
+            elif attr.type == AttributeProto.FLOAT:
                 attrs[attr.name] = attr.f
-            elif attr.type == onnx.AttributeProto.STRING:
+            elif attr.type == AttributeProto.STRING:
                 attrs[attr.name] = attr.s
-            elif attr.type == onnx.AttributeProto.TENSOR:
+            elif attr.type == AttributeProto.TENSOR:
                 attrs[attr.name] = attr.t
             else:
                 assert False, "Unsupported Attribute Type: {}".format(attr.type)
    return attrs
 
 
-def _parse_data(tensor: onnx.TensorProto) -> List[int]:
-    if tensor.data_type == onnx.TensorProto.INT32:
+def _parse_data(tensor: TensorProto) -> List[int]:
+    if tensor.data_type == TensorProto.INT32:
         return [int(i) for i in tensor.int32_data]
-    elif tensor.data_type == onnx.TensorProto.INT64:
+    elif tensor.data_type == TensorProto.INT64:
         return [int(i) for i in tensor.int64_data]
     else:
         assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
 
 
-def _take_shape_dim(shape: onnx.TensorShapeProto) -> List[int]:
+def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
     return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 46328f76..1f839256 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -8,7 +8,7 @@ from onnx.helper import (
     make_tensor_value_info,
 )
 from onnx.checker import check_model
-from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
+from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx
 
 
 def make_and_import_model(graph: onnx.GraphProto):
@@ -293,11 +293,20 @@ class TestStringMethods(unittest.TestCase):
         parse_onnx(model)
 
     def test_frontend(self):
-        handler = backend.GraphHandlerObj(runtime)
-        i = handler.tensor([1, 2, 3], 12)
-        w = handler.tensor([1, 3, 4], 12)
-        o = handler.tensor([1, 2, 4], 12)
-        handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
+        handler = backend.GraphHandler(runtime)
+        a = handler.tensor([1, 2, 3], 12)
+        b = handler.tensor([1, 2, 3], 12)
+        c = handler.tensor([1, 2, 3], 12)
+        d = handler.tensor([1, 2, 3], 12)
+        e = handler.tensor([1, 2, 3], 12)
+
+        x = handler.add(
+            handler.add(handler.add(handler.add(a, b, None), c, None), d, None), e, None
+        )
+        y = handler.tensor([3, 2, 1], 12)
+        handler.reshape(x, y, [3, 2, 1])
+
+        to_onnx(handler, "test_frontend")
 
 
 if __name__ == "__main__":
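Editor's note: with the exporter in place, the intended round trip looks roughly as follows. This is a hedged sketch: it assumes `from_onnx` returns the populated GraphHandler (as its new return annotation states) and that a local `model.onnx` file exists.

```python
# Hedged sketch of an ONNX round trip through the new exporter.
import onnx
from pyinfinitensor.onnx import from_onnx, to_onnx, parse_onnx

model = onnx.load("model.onnx")
handler = from_onnx(model)                  # import: ONNX -> GraphHandler
exported = to_onnx(handler, "round_trip")   # export: GraphHandler -> ModelProto
onnx.checker.check_model(exported)
parse_onnx(exported)                        # print a summary of the re-exported model
onnx.save(exported, "round_trip.onnx")
```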
"__main__": diff --git a/src/core/graph.cc b/src/core/graph.cc index a9edda64..04ce2581 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -1,9 +1,11 @@ #include "core/graph.h" +#include #include namespace infini { -GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) { +GraphObj::GraphObj(Runtime runtime, OpVec ops_in) + : runtime(runtime), sorted(false) { map tensorPool; // Clone tensors for (const auto &op : ops_in) { @@ -28,6 +30,7 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) { } void GraphObj::addOperatorAndConnect(const Operator &op) { + sorted = false; ops.push_back(op); for (auto &input : op->getInputs()) { input->addInputOf(op); @@ -66,6 +69,53 @@ string GraphObj::toString() const { return oss.str(); } +bool GraphObj::topo_sort() { + if (this->sorted) + return true; + + // std::unordered_set inputs; + std::unordered_set waiting(this->ops.begin(), this->ops.end()); + std::vector sorted; + + while (!waiting.empty()) { + // Any node is move to sorted in this loop. + auto modified = false; + // Find head nodes. + for (auto it = waiting.begin(); it != waiting.end();) { + const auto &this_inputs = (*it)->getInputs(); + // If none of the input tensors is in waiting list, + // this node is a head node. + const auto is_head = std::all_of( + this_inputs.begin(), this_inputs.end(), [&](const auto &input) { + auto src = input->getOutputOf(); + return src // If the source node is in the waiting list, + // means that this node is not the head node. + ? waiting.find(src) == waiting.end() + // This tensor has no source node, + // it must be a input tensor. + : (/*inputs.insert(input),*/ true); + }); + // Moves head node to sorted. + if (is_head) { + modified = true; + sorted.emplace_back(std::move(*it)); + it = waiting.erase(it); + } else { + ++it; + } + } + // Waiting list never modifies during a pass, + // sorting fails. + if (!modified) { + return false; + } + } + + // Done. 
 void GraphObj::dataMalloc() {
     for (auto &tensor : tensors) {
         tensor->dataMalloc();
@@ -73,15 +123,12 @@
 }
 
 Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
-    Tensor tensor = make_ref<TensorObj>(dim, dtype, runtime);
-    tensors.emplace_back(tensor);
-    return tensor;
+    return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime));
 }
 
 Tensor GraphObj::addTensor(const Tensor &tensor) {
     IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
-    tensors.emplace_back(tensor);
-    return tensor;
+    return tensors.emplace_back(tensor);
 }
 
 TensorVec GraphObj::addTensor(const TensorVec &tensors) {
@@ -98,4 +145,4 @@ OpVec GraphObj::getComputeOps() const {
     return opList;
 };
 
-} // namespace infini
\ No newline at end of file
+} // namespace infini
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 37b7d5da..e545f43c 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -1,4 +1,8 @@
 #include "core/graph_handler.h"
+#include "operators/concat.h"
+#include "operators/gather.h"
+#include "operators/reduce_mean.h"
+#include "operators/reshape.h"
 #include <pybind11/stl.h>
 
 #ifdef USE_CUDA
@@ -21,88 +25,151 @@ void register_operator_timer(py::module &m) {
 #endif
 }
 
+void export_values(py::module &m) {
+#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
+
+    py::enum_<ActType>(m, "ActType")
+        .value("Linear", ActType::None) // `None` is Python keyword
+        .VALUE(ActType, Relu)
+        .VALUE(ActType, Sigmoid)
+        .VALUE(ActType, Tanh)
+        .export_values();
+
+    py::enum_<OpType>(m, "OpType")
+        .VALUE(OpType, Unknown)
+        .VALUE(OpType, Conv)
+        .VALUE(OpType, Matmul)
+        .VALUE(OpType, ConvTrans)
+        .VALUE(OpType, G2BMM)
+        .VALUE(OpType, GBMM)
+        .VALUE(OpType, Pad)
+        .VALUE(OpType, Slice)
+        .VALUE(OpType, Concat)
+        .VALUE(OpType, Split)
+        .VALUE(OpType, Transpose)
+        .VALUE(OpType, Extend)
+        .VALUE(OpType, MaxPool)
+        .VALUE(OpType, AvgPool)
+        .VALUE(OpType, Add)
+        .VALUE(OpType, Sub)
+        .VALUE(OpType, Mul)
+        .VALUE(OpType, Div)
+        .VALUE(OpType, Pow)
+        .VALUE(OpType, Gather)
+        .VALUE(OpType, ReduceMean)
+        .VALUE(OpType, Reshape)
+        .VALUE(OpType, Flatten)
+        .VALUE(OpType, Identity)
+        .VALUE(OpType, BatchNorm)
+        .VALUE(OpType, Softmax)
+        .VALUE(OpType, Activation)
+        .VALUE(OpType, Relu)
+        .VALUE(OpType, Sigmoid)
+        .VALUE(OpType, Tanh)
+        .VALUE(OpType, Abs)
+        .VALUE(OpType, Resize)
+        .VALUE(OpType, MemBound)
+        .export_values();
+
+#undef VALUE
+}
+
+static int tensor_dtype(Tensor t) {
+    if (t->getDType() == DataType::Float32)
+        return OnnxDType::FLOAT;
+    if (t->getDType() == DataType::UInt32)
+        return OnnxDType::UINT32;
+    if (t->getDType() == DataType::UInt8)
+        return OnnxDType::UINT8;
+    if (t->getDType() == DataType::Int8)
+        return OnnxDType::INT8;
+    if (t->getDType() == DataType::UInt16)
+        return OnnxDType::UINT16;
+    if (t->getDType() == DataType::Int16)
+        return OnnxDType::INT16;
+    if (t->getDType() == DataType::Int32)
+        return OnnxDType::INT32;
+    if (t->getDType() == DataType::Int64)
+        return OnnxDType::INT64;
+    IT_ASSERT(false, "Unsupported data type");
+}
+
+static int concat_axis_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Concat);
+    return dynamic_cast<const ConcatObj *>(op.get())->getDim();
+}
+
+static int gather_axis_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Gather);
+    return dynamic_cast<const GatherObj *>(op.get())->getAxis();
+}
+
+static vector<int> reduce_mean_axes_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::ReduceMean);
+    auto &set = dynamic_cast<const ReduceMeanObj *>(op.get())->getAxes();
+    return vector<int>(set.begin(), set.end());
+}
+
+static Shape reshape_shape_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Reshape);
+    return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
+}
+
+void export_functions(py::module &m) {
+#define FUNCTION(NAME) def(#NAME, &NAME)
+    m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
+        .FUNCTION(tensor_dtype)
+        .FUNCTION(reshape_shape_of)
+        .FUNCTION(concat_axis_of)
+        .FUNCTION(gather_axis_of)
+        .FUNCTION(reduce_mean_axes_of);
+#undef FUNCTION
+}
+
 void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;
 
-    m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
-    py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
+    py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
-        m, "CpuRuntimeObj");
-    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
-    py::enum_<ActType>(m, "ActType")
-        .value("Linear", ActType::None) // `None` is Python keyword
-        .value("Relu", ActType::Relu)
-        .value("Sigmoid", ActType::Sigmoid)
-        .value("Tanh", ActType::Tanh)
-        .export_values();
-    py::class_<Handler>(m, "GraphHandlerObj")
+        m, "CpuRuntime");
+    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
+        .def("shape", &TensorObj::getDims, policy::move)
+        .def("src", &TensorObj::getOutputOf, policy::move);
+    py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
+        .def("op_type", &OperatorObj::getOpType, policy::automatic)
+        .def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
+             policy::reference)
+        .def("outputs",
+             py::overload_cast<>(&OperatorObj::getOutputs, py::const_),
+             policy::reference);
+    py::class_<Handler>(m, "GraphHandler")
         .def(py::init<Runtime>())
-        .def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
-             policy::move)
+        .def("tensor", &Handler::tensor, policy::move)
         .def("conv", &Handler::conv, policy::move)
-        .def("matmul",
-             py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
                               ActType>(&Handler::matmul),
-             policy::move)
-        .def("batchNorm",
-             py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
                               float, float, bool>(&Handler::batchNorm),
-             policy::move)
-        .def("maxPool",
-             py::overload_cast(&Handler::maxPool),
-             policy::move)
-        .def("avgPool",
-             py::overload_cast(&Handler::avgPool),
-             policy::move)
-        .def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
-             policy::move)
-        .def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
-             policy::move)
-        .def("mul", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::mul),
-             policy::move)
-        .def("div", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::div),
-             policy::move)
-        .def("pow", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::pow),
-             policy::move)
-        .def("relu", py::overload_cast(&Handler::relu),
-             policy::move)
-        .def("sigmoid", py::overload_cast(&Handler::sigmoid),
-             policy::move)
-        .def("tanh", py::overload_cast(&Handler::tanh),
-             policy::move)
-        .def("softmax", py::overload_cast(&Handler::softmax),
-             policy::move)
-        .def("abs", py::overload_cast(&Handler::abs),
-             policy::move)
-        .def("identity", py::overload_cast(&Handler::identity),
-             policy::move)
-        .def("flatten", py::overload_cast(&Handler::flatten),
-             policy::move)
-        .def("reshape",
-             py::overload_cast(&Handler::reshape),
-             policy::move)
-        .def("concat",
-             py::overload_cast(&Handler::concat),
-             policy::move)
-        .def("gather",
-             py::overload_cast(&Handler::gather),
-             policy::move)
-        .def("reduceMean",
-             py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
                               bool>(&Handler::reduceMean),
-             policy::move)
-        .def("slice",
-             py::overload_cast<
                 Tensor, Tensor, const vector<int> &, const vector<int> &,
                 const optional<vector<int>> &, const optional<vector<int>> &>(
                 &Handler::slice),
-             policy::move)
-        .def("pad",
-             py::overload_cast<Tensor, Tensor, const vector<int> &,
                               const optional<vector<int>> &>(&Handler::pad),
-             policy::move)
+        .def("matmul", &Handler::matmul, policy::move)
+        .def("batchNorm", &Handler::batchNorm, policy::move)
+        .def("maxPool", &Handler::maxPool, policy::move)
+        .def("avgPool", &Handler::avgPool, policy::move)
+        .def("add", &Handler::add, policy::move)
+        .def("sub", &Handler::sub, policy::move)
+        .def("mul", &Handler::mul, policy::move)
+        .def("div", &Handler::div, policy::move)
+        .def("pow", &Handler::pow, policy::move)
+        .def("relu", &Handler::relu, policy::move)
+        .def("sigmoid", &Handler::sigmoid, policy::move)
+        .def("tanh", &Handler::tanh, policy::move)
+        .def("softmax", &Handler::softmax, policy::move)
+        .def("abs", &Handler::abs, policy::move)
+        .def("identity", &Handler::identity, policy::move)
+        .def("flatten", &Handler::flatten, policy::move)
+        .def("reshape", &Handler::reshape, policy::move)
+        .def("concat", &Handler::concat, policy::move)
+        .def("gather", &Handler::gather, policy::move)
+        .def("reduceMean", &Handler::reduceMean, policy::move)
+        .def("slice", &Handler::slice, policy::move)
+        .def("pad", &Handler::pad, policy::move)
+        .def("topo_sort", &Handler::topo_sort, policy::automatic)
+        .def("operators", &Handler::operators, policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
         .def("run", &Handler::run, policy::automatic);
 }
@@ -111,5 +178,7 @@ void init_graph_builder(py::module &m) {
 
 PYBIND11_MODULE(backend, m) {
     infini::register_operator_timer(m);
+    infini::export_values(m);
+    infini::export_functions(m);
     infini::init_graph_builder(m);
 }
diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc
index d208c21f..65bcf68a 100644
--- a/test/core/test_graph.cc
+++ b/test/core/test_graph.cc
@@ -1,6 +1,7 @@
 #include "core/blob.h"
 #include "core/graph.h"
 #include "core/runtime.h"
+#include "operators/element_wise.h"
 #include "operators/matmul.h"
 #include "operators/unary.h"
 #include "test.h"
@@ -36,6 +37,45 @@ TEST(Graph, build_and_run) {
     EXPECT_TRUE(o0->equalData(ans));
 }
 
+TEST(Graph, topological) {
+    Runtime runtime = CpuRuntimeObj::getInstance();
+    Graph g = make_ref<GraphObj>(runtime);
+    Tensor a = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor b = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor ab = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor c = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor abc = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor d = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor abcd = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor e = g->addTensor({1, 2, 3}, DataType::UInt32);
+    Tensor abcde = g->addTensor({1, 2, 3}, DataType::UInt32);
+
+    auto ops = std::vector{
+        g->addOpWithOutputs<AddObj>(abcd, e, abcde),
+        g->addOpWithOutputs<AddObj>(abc, d, abcd),
+        g->addOpWithOutputs<AddObj>(ab, c, abc),
+        g->addOpWithOutputs<AddObj>(a, b, ab),
+    };
+
+    {
+        auto p = ops.begin();
+        auto q = g->getOperators().begin();
+        while (p != ops.end()) {
+            EXPECT_EQ(*p++, *q++);
+        }
+    }
+
+    EXPECT_TRUE(g->topo_sort());
+
+    {
+        auto p = ops.rbegin();
+        auto q = g->getOperators().begin();
+        while (p != ops.rend()) {
+            EXPECT_EQ(*p++, *q++);
+        }
+    }
+}
+
 TEST(Graph, perf_engine) {
     Runtime runtime = CpuRuntimeObj::getInstance();
     Graph g = make_ref<GraphObj>(runtime);
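Editor's note: a hedged sketch of how the accessors exported by export_functions() above are meant to be consumed from Python, mirroring the pattern already used in to_onnx. It assumes the compiled `backend` module and a handler that already holds operators.

```python
# Hedged sketch: pretty-print a handler's graph using the new FFI helpers.
import backend

def describe(op) -> str:
    ty = op.op_type()
    if ty == backend.OpType.Reshape:
        return "Reshape to {}".format(backend.reshape_shape_of(op))
    if ty == backend.OpType.Concat:
        return "Concat on axis {}".format(backend.concat_axis_of(op))
    if ty == backend.OpType.Gather:
        return "Gather on axis {}".format(backend.gather_axis_of(op))
    if ty == backend.OpType.ReduceMean:
        return "ReduceMean over axes {}".format(backend.reduce_mean_axes_of(op))
    return ty.name

def describe_graph(handler: backend.GraphHandler) -> None:
    if not handler.topo_sort():
        raise RuntimeError("graph contains a cycle")
    for op in handler.operators():
        # output shapes come from the Tensor.shape binding added in this diff
        print(describe(op), [t.shape() for t in op.outputs()])
```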