forked from jiuyuan/InfiniTensor
feat: 支持导入和推导通信算子
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
97439bd860
commit
e637f9e7dd
|
@ -93,6 +93,7 @@ include_directories(RefactorGraph/src/00common/include)
|
|||
include_directories(RefactorGraph/src/01graph_topo/include)
|
||||
include_directories(RefactorGraph/src/02computation/include)
|
||||
include_directories(RefactorGraph/src/03onnx/include)
|
||||
include_directories(RefactorGraph/src/04communication/include)
|
||||
endif()
|
||||
|
||||
include_directories(include)
|
||||
|
@ -158,7 +159,7 @@ if(USE_PROTOBUF)
|
|||
endif()
|
||||
|
||||
if(USE_REFACTOR_GRAPH)
|
||||
target_link_libraries(InfiniTensor onnx)
|
||||
target_link_libraries(InfiniTensor onnx communication)
|
||||
endif()
|
||||
|
||||
target_link_libraries(InfiniTensor pybind11::embed)
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit dd769552b295672be13893dc34ac370807ca2706
|
||||
Subproject commit 953712f681dea08eed67092f53287e5f34c63a62
|
|
@ -24,6 +24,11 @@ def build_graph(model: ModelProto) -> backend.Graph:
|
|||
None,
|
||||
)
|
||||
|
||||
x = set()
|
||||
for node in model.graph.node:
|
||||
x.add(node.op_type)
|
||||
print(x)
|
||||
|
||||
return refactor_graph(
|
||||
{node.name: (node.input, node.output) for node in model.graph.node},
|
||||
{
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#include "common/error_handler.h"
|
||||
#include "communication/operators.h"
|
||||
#include "computation/graph.h"
|
||||
#include "onnx/operators.h"
|
||||
#include <pybind11/numpy.h>
|
||||
|
@ -185,6 +186,7 @@ graph(std::unordered_map<Name, std::pair<std::vector<Name>, std::vector<Name>>>
|
|||
|
||||
void register_refactor(py::module &m) {
|
||||
onnx::register_();
|
||||
communication::register_();
|
||||
|
||||
py::class_<DimExpr>(m, "DimExpr")
|
||||
.def(py::init<int64_t>())
|
||||
|
|
Loading…
Reference in New Issue