feat: 支持导入和推导通信算子

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-09-18 12:07:25 +08:00
parent 97439bd860
commit e637f9e7dd
4 changed files with 10 additions and 2 deletions

View File

@ -93,6 +93,7 @@ include_directories(RefactorGraph/src/00common/include)
include_directories(RefactorGraph/src/01graph_topo/include)
include_directories(RefactorGraph/src/02computation/include)
include_directories(RefactorGraph/src/03onnx/include)
include_directories(RefactorGraph/src/04communication/include)
endif()
include_directories(include)
@ -158,7 +159,7 @@ if(USE_PROTOBUF)
endif()
if(USE_REFACTOR_GRAPH)
target_link_libraries(InfiniTensor onnx)
target_link_libraries(InfiniTensor onnx communication)
endif()
target_link_libraries(InfiniTensor pybind11::embed)

@ -1 +1 @@
Subproject commit dd769552b295672be13893dc34ac370807ca2706
Subproject commit 953712f681dea08eed67092f53287e5f34c63a62

View File

@ -24,6 +24,11 @@ def build_graph(model: ModelProto) -> backend.Graph:
None,
)
x = set()
for node in model.graph.node:
x.add(node.op_type)
print(x)
return refactor_graph(
{node.name: (node.input, node.output) for node in model.graph.node},
{

View File

@ -1,4 +1,5 @@
#include "common/error_handler.h"
#include "communication/operators.h"
#include "computation/graph.h"
#include "onnx/operators.h"
#include <pybind11/numpy.h>
@ -185,6 +186,7 @@ graph(std::unordered_map<Name, std::pair<std::vector<Name>, std::vector<Name>>>
void register_refactor(py::module &m) {
onnx::register_();
communication::register_();
py::class_<DimExpr>(m, "DimExpr")
.def(py::init<int64_t>())