forked from jiuyuan/InfiniTensor
feat: 创建 pyinfinitensor 前端
- python 前端项目结构及打包和安装脚本 - 后端编译出 so 改名为 backend,增加 GraphHandler 修改图结构 - ci 支持测试这些功能 Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
c7ec9ee6e7
commit
296fcc5aa0
|
@ -6,8 +6,11 @@ on:
|
|||
- 'LICENSE'
|
||||
|
||||
env:
|
||||
protobuf_download: https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protobuf-cpp-3.21.12.tar.gz
|
||||
protobuf_version: 3.21.12
|
||||
protobuf-download: https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protobuf-cpp-3.21.12.tar.gz
|
||||
protobuf-version: "3.21.12"
|
||||
python-version: "3.10"
|
||||
resnet-download: https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx
|
||||
resnet-file: resnet18-v2-7.onnx
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
@ -18,6 +21,11 @@ jobs:
|
|||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Python ${{ env.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ env.python-version }}
|
||||
|
||||
- name: Install libdw
|
||||
run: sudo apt-get update && sudo apt-get install libdw-dev
|
||||
|
||||
|
@ -25,33 +33,38 @@ jobs:
|
|||
id: cache-protobuf
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: protobuf-${{ env.protobuf_version }}
|
||||
key: protobuf-${{ env.protobuf_version }}
|
||||
path: protobuf-${{ env.protobuf-version }}
|
||||
key: protobuf-${{ env.protobuf-version }}
|
||||
|
||||
- name: Download and compile protobuf
|
||||
if: steps.cache-protobuf.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
wget ${{ env.protobuf_download }}
|
||||
tar xf protobuf-cpp-${{ env.protobuf_version }}.tar.gz
|
||||
cd protobuf-${{ env.protobuf_version }}
|
||||
wget ${{ env.protobuf-download }}
|
||||
tar xf protobuf-cpp-${{ env.protobuf-version }}.tar.gz
|
||||
cd protobuf-${{ env.protobuf-version }}
|
||||
./autogen.sh
|
||||
./configure CFLAGS="-fPIC" CXXFLAGS="-fPIC"
|
||||
make -j2
|
||||
make -j8
|
||||
|
||||
- name: Install protobuf
|
||||
run: |
|
||||
cd protobuf-${{ env.protobuf_version }}
|
||||
cd protobuf-${{ env.protobuf-version }}
|
||||
sudo make install
|
||||
sudo ldconfig
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake ..
|
||||
make -j2
|
||||
run: make
|
||||
|
||||
- name: Test cpu
|
||||
run: make test-cpp
|
||||
|
||||
- name: Install python-frontend
|
||||
run: |
|
||||
cd build
|
||||
make test
|
||||
python -m pip install --upgrade pip
|
||||
make install-python
|
||||
|
||||
- name: Download test model
|
||||
run: wget ${{ env.resnet-download }}
|
||||
|
||||
- name: Test onnx frontend
|
||||
run: make test-onnx
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
name: clang-format Check
|
||||
on:
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'LICENSE'
|
||||
|
||||
paths:
|
||||
- 'include/**'
|
||||
- 'src/**'
|
||||
- 'test/**'
|
||||
- '3rd-party/**'
|
||||
- '.clang-format'
|
||||
jobs:
|
||||
formatting-check:
|
||||
name: Formatting Check
|
||||
|
|
|
@ -37,4 +37,8 @@ build_debug/
|
|||
.vscode/
|
||||
|
||||
# python
|
||||
*.pyc
|
||||
*.egg-info
|
||||
*.pyc
|
||||
|
||||
# onnx model
|
||||
*.onnx
|
||||
|
|
|
@ -96,8 +96,8 @@ target_link_libraries(InfiniTensor pybind11::embed)
|
|||
|
||||
# Python bindings
|
||||
file(GLOB_RECURSE FFIS src/ffi/ffi_infinitensor.cc)
|
||||
pybind11_add_module(pyinfinitensor MODULE ${FFIS})
|
||||
target_link_libraries(pyinfinitensor PRIVATE InfiniTensor)
|
||||
pybind11_add_module(backend MODULE ${FFIS})
|
||||
target_link_libraries(backend PRIVATE InfiniTensor)
|
||||
|
||||
if(USE_BACKTRACE)
|
||||
add_definitions(-D BACKWARD_TRACE)
|
||||
|
@ -134,13 +134,13 @@ if(USE_BANG)
|
|||
set(NEUWARE_HOME $ENV{NEUWARE_HOME} CACHE STRING "NEUWARE_HOME directory for Cambricon Neuware development")
|
||||
endif()
|
||||
message(STATUS "NEUWARE_HOME: ${NEUWARE_HOME}")
|
||||
|
||||
|
||||
include_directories("${NEUWARE_HOME}/include")
|
||||
find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64")
|
||||
find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64")
|
||||
find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
|
||||
|
||||
|
||||
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
|
||||
execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH")
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
.PHONY : build clean install-python test-cpp test-onnx
|
||||
|
||||
TYPE ?= release
|
||||
|
||||
build:
|
||||
mkdir -p build/$(TYPE)
|
||||
cd build/$(TYPE) && cmake -DCMAKE_BUILD_TYPE=$(TYPE) ../.. && make -j8
|
||||
|
||||
clean:
|
||||
rm -rf build
|
||||
|
||||
install-python: build
|
||||
cp build/$(TYPE)/backend*.so pyinfinitensor/src/pyinfinitensor
|
||||
pip install pyinfinitensor/
|
||||
|
||||
test-cpp: build
|
||||
@echo
|
||||
cd build/$(TYPE) && make test
|
||||
|
||||
test-onnx:
|
||||
@echo
|
||||
python3 pyinfinitensor/tests/test_onnx.py
|
|
@ -0,0 +1,44 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
// 借用 onnx 的定义减小传递开销,来自 onnx 但不只用于 onnx。
|
||||
//
|
||||
// see https://onnx.ai/onnx/intro/concepts.html#element-type
|
||||
enum OnnxDType : int {
|
||||
UNDEFINED = 0,
|
||||
FLOAT,
|
||||
UINT8,
|
||||
INT8,
|
||||
UINT16,
|
||||
INT16,
|
||||
INT32,
|
||||
INT64,
|
||||
STRING,
|
||||
BOOL,
|
||||
FLOAT16,
|
||||
DOUBLE,
|
||||
UINT32,
|
||||
UINT64,
|
||||
COMPLEX64,
|
||||
COMPLEX128,
|
||||
BFLOAT16,
|
||||
};
|
||||
|
||||
class GraphHandlerObj {
|
||||
Graph g;
|
||||
|
||||
public:
|
||||
GraphHandlerObj(Runtime runtime)
|
||||
: g(make_ref<GraphObj>(std::move(runtime))) {}
|
||||
|
||||
Tensor tensor(Shape dims, int dtype);
|
||||
|
||||
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
||||
Tensor bias, ActType act);
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -10,6 +10,7 @@ class TensorBaseObj;
|
|||
class TensorObj;
|
||||
class OperatorObj;
|
||||
class GraphObj;
|
||||
class GraphHandlerObj;
|
||||
class RuntimeObj;
|
||||
class BlobObj;
|
||||
|
||||
|
@ -17,6 +18,7 @@ using TensorBase = Ref<TensorBaseObj>;
|
|||
using Tensor = Ref<TensorObj>;
|
||||
using Operator = Ref<OperatorObj>;
|
||||
using Graph = Ref<GraphObj>;
|
||||
using GraphHandler = Ref<GraphHandlerObj>;
|
||||
using Runtime = Ref<RuntimeObj>;
|
||||
using Blob = Ref<BlobObj>;
|
||||
enum class OpType;
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# pyinfinitensor
|
||||
|
||||
InfiniTensor 的 Python 前端。
|
||||
|
||||
- [ ] 从 ONNX 导入模型
|
||||
- [ ] 从 Pytorch 导入模型
|
||||
- [ ] 从 PaddlePaddle 导入模型
|
||||
- [ ] 模型导出到 ONNX
|
||||
- [ ] 模型导出到 Pytorch
|
||||
- [ ] 模型导出到 PaddlePaddle
|
||||
|
||||
## python 工程结构及打包方法
|
||||
|
||||
本项目使用 [pyproject.toml] 文件定义,目录结构采用 [src 布局](https://setuptools.pypa.io/en/latest/userguide/package_discovery.html#src-layout),后端库会被[打包](https://setuptools.pypa.io/en/latest/userguide/datafiles.html#package-data),并支持[自动的依赖安装](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#declaring-required-dependency)。
|
||||
|
||||
参考 [setuptools 的文档](https://setuptools.pypa.io/en/latest/userguide/index.html)。
|
|
@ -0,0 +1 @@
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "pyinfinitensor"
|
||||
version = "0.0.0"
|
||||
authors = [{ name = "YdrMaster", email = "ydrml@hotmail.com" }]
|
||||
description = "Python frontend of InfiniTensor"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
keywords = ["optimizer"]
|
||||
license = { text = "Apache" }
|
||||
classifiers = ["Programming Language :: Python :: 3"]
|
||||
dependencies = ["onnx"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
pyinfinitensor = ["*.so"]
|
|
@ -0,0 +1,8 @@
|
|||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(str(__file__)).parent))
|
||||
|
||||
import backend
|
||||
|
||||
print("import backend: {}".format(backend))
|
|
@ -0,0 +1,66 @@
|
|||
import onnx
|
||||
import backend
|
||||
|
||||
runtime = backend.cpu_runtime()
|
||||
|
||||
|
||||
def from_onnx(model: onnx.ModelProto):
|
||||
handler = backend.GraphHandlerObj(runtime)
|
||||
|
||||
tensors = dict()
|
||||
|
||||
for input in model.graph.input:
|
||||
dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
|
||||
tensors[input.name] = handler.tensor(dims, input.type.tensor_type.elem_type)
|
||||
|
||||
for output in model.graph.output:
|
||||
dims = [d.dim_value for d in output.type.tensor_type.shape.dim]
|
||||
tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type)
|
||||
|
||||
for node in model.graph.node:
|
||||
if node.op_type == "MatMul":
|
||||
handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors[node.output[0]],
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
backend.ActType.Linear,
|
||||
)
|
||||
|
||||
|
||||
def parse_onnx(model: onnx.ModelProto):
|
||||
print()
|
||||
|
||||
for field in [
|
||||
"doc_string",
|
||||
"domain",
|
||||
"functions",
|
||||
"metadata_props",
|
||||
"model_version",
|
||||
"producer_name",
|
||||
"producer_version",
|
||||
"training_info",
|
||||
]:
|
||||
print("{}: {}".format(field, getattr(model, field)))
|
||||
|
||||
print("ir_version:", model.ir_version)
|
||||
for opset in model.opset_import:
|
||||
print("opset domain={} version={}".format(opset.domain, opset.version))
|
||||
|
||||
print("layout:")
|
||||
for node in model.graph.node:
|
||||
print(
|
||||
' {o} <- {op}"{name}"{a} <- {i}'.format(
|
||||
name=node.name,
|
||||
op=node.op_type,
|
||||
i=node.input,
|
||||
o=node.output,
|
||||
a=[a.name for a in node.attribute],
|
||||
)
|
||||
)
|
||||
|
||||
print("weight:")
|
||||
for node in model.graph.initializer:
|
||||
print(" {}".format(node.name))
|
|
@ -0,0 +1,43 @@
|
|||
import os, onnx, unittest
|
||||
from onnx import TensorProto
|
||||
from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info
|
||||
from onnx.checker import check_model
|
||||
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
|
||||
|
||||
|
||||
class TestStringMethods(unittest.TestCase):
|
||||
def test_load(self):
|
||||
model_file = next(
|
||||
(name for name in os.listdir() if name.endswith(".onnx")), None
|
||||
)
|
||||
if model_file != None:
|
||||
print(
|
||||
"model: {file}({size:.2f} MiB)".format(
|
||||
file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
|
||||
)
|
||||
)
|
||||
parse_onnx(onnx.load(model_file))
|
||||
|
||||
def test_import(self):
|
||||
i = make_tensor_value_info("i", TensorProto.FLOAT, [1, 2, 3])
|
||||
w = make_tensor_value_info("w", TensorProto.FLOAT, [1, 3, 4])
|
||||
o = make_tensor_value_info("o", TensorProto.FLOAT, [1, 2, 4])
|
||||
matmul = make_node("MatMul", ["i", "w"], ["o"], name="matmul")
|
||||
graph = make_graph([matmul], "mm", [i, w], [o])
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
print(model)
|
||||
|
||||
from_onnx(model)
|
||||
parse_onnx(model)
|
||||
|
||||
def test_frontend(self):
|
||||
handler = backend.GraphHandlerObj(runtime)
|
||||
i = handler.tensor([1, 2, 3], 12)
|
||||
w = handler.tensor([1, 3, 4], 12)
|
||||
o = handler.tensor([1, 2, 4], 12)
|
||||
handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,37 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "operators/matmul.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
static DataType dtype_repr_convert(int);
|
||||
|
||||
Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
|
||||
return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
|
||||
bool transB, Tensor bias, ActType act) {
|
||||
if (y) {
|
||||
g->addOpWithOutputs<MatmulObj>(std::move(a), std::move(b), y, transA,
|
||||
transB, std::move(bias), act);
|
||||
return y;
|
||||
} else {
|
||||
return g
|
||||
->addOp<MatmulObj>(std::move(a), std::move(b), y, transA, transB,
|
||||
std::move(bias), act)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
static DataType dtype_repr_convert(int dtype) {
|
||||
switch ((OnnxDType)dtype) {
|
||||
case OnnxDType::FLOAT:
|
||||
return DataType::Float32;
|
||||
case OnnxDType::UINT32:
|
||||
return DataType::UInt32;
|
||||
default:
|
||||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -1,7 +1,10 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include "cuda/operator_timer.h"
|
||||
#endif
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace infini {
|
||||
|
@ -18,6 +21,34 @@ void register_operator_timer(py::module &m) {
|
|||
#endif
|
||||
}
|
||||
|
||||
void init_graph_builder(py::module &m) {
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
|
||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
|
||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||
m, "CpuRuntimeObj");
|
||||
py::class_<Shape>(m, "Shape");
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
|
||||
py::class_<Tensor>(m, "Tensor");
|
||||
py::enum_<ActType>(m, "ActType")
|
||||
.value("Linear", ActType::None) // None 是 Python 关键字,不能用
|
||||
.value("Relu", ActType::Relu)
|
||||
.value("Sigmoid", ActType::Sigmoid)
|
||||
.value("Tanh", ActType::Tanh)
|
||||
.export_values();
|
||||
py::class_<GraphHandler>(m, "GraphHandler");
|
||||
py::class_<GraphHandlerObj>(m, "GraphHandlerObj")
|
||||
.def(py::init<Runtime>())
|
||||
.def("tensor", py::overload_cast<Shape, int>(&GraphHandlerObj::tensor),
|
||||
policy::reference_internal)
|
||||
.def("matmul",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||
ActType>(&GraphHandlerObj::matmul),
|
||||
policy::reference_internal);
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
||||
PYBIND11_MODULE(pyinfinitensor, m) { infini::register_operator_timer(m); }
|
||||
PYBIND11_MODULE(backend, m) {
|
||||
infini::register_operator_timer(m);
|
||||
infini::init_graph_builder(m);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
def inc(x: int) -> int:
|
||||
return x + 1
|
|
@ -0,0 +1,16 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "core/runtime.h"
|
||||
#include <test.h>
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(Handler, matmul) {
|
||||
auto runtime = CpuRuntimeObj::getInstance();
|
||||
auto handler = make_ref<GraphHandlerObj>(runtime);
|
||||
auto i = handler->tensor({1, 2, 3}, OnnxDType::UINT32);
|
||||
auto w = handler->tensor({1, 3, 4}, OnnxDType::UINT32);
|
||||
auto o = handler->tensor({1, 2, 4}, OnnxDType::UINT32);
|
||||
handler->matmul(i, w, o, false, false, nullptr, ActType::None);
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,16 @@
|
|||
#include <filesystem>
|
||||
#include <pybind11/embed.h>
|
||||
#include <test.h>
|
||||
|
||||
TEST(Python, pybind) {
|
||||
namespace fs = std::filesystem;
|
||||
namespace py = pybind11;
|
||||
using mod = py::module;
|
||||
|
||||
py::scoped_interpreter _python;
|
||||
|
||||
auto sys_path_append = mod::import("sys").attr("path").attr("append");
|
||||
sys_path_append(fs::path(__FILE__).parent_path().c_str());
|
||||
auto ans = mod::import("python").attr("inc")(1);
|
||||
EXPECT_EQ(ans.cast<int>(), 2);
|
||||
}
|
Loading…
Reference in New Issue