feat: Create the pyinfinitensor frontend

- Python frontend project structure, plus packaging and installation scripts
- Rename the compiled backend .so to `backend`; add GraphHandler for modifying the graph structure
- CI support for testing these features

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster 2023-02-12 08:23:49 +08:00
parent c7ec9ee6e7
commit 296fcc5aa0
18 changed files with 370 additions and 26 deletions
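
Taken together, the changes below enable a workflow along these lines — a hypothetical session, assuming the project has been built and installed via `make install-python` and that the CI's test model `resnet18-v2-7.onnx` sits in the working directory:

```python
import onnx
from pyinfinitensor.onnx import parse_onnx

# Dump the metadata, node layout and initializers of the ONNX model
# that the CI downloads for its smoke test.
parse_onnx(onnx.load("resnet18-v2-7.onnx"))
```

`from_onnx` itself is exercised on a hand-built MatMul graph in `pyinfinitensor/tests/test_onnx.py`, since MatMul is the only operator the new `GraphHandler` supports so far.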

.github/workflows/build.yml

@@ -6,8 +6,11 @@ on:
- 'LICENSE'
env:
protobuf_download: https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protobuf-cpp-3.21.12.tar.gz
protobuf_version: 3.21.12
protobuf-download: https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protobuf-cpp-3.21.12.tar.gz
protobuf-version: "3.21.12"
python-version: "3.10"
resnet-download: https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx
resnet-file: resnet18-v2-7.onnx
jobs:
build:
@@ -18,6 +21,11 @@ jobs:
with:
submodules: recursive
- name: Set up Python ${{ env.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ env.python-version }}
- name: Install libdw
run: sudo apt-get update && sudo apt-get install libdw-dev
@@ -25,33 +33,38 @@ jobs:
id: cache-protobuf
uses: actions/cache@v3
with:
path: protobuf-${{ env.protobuf_version }}
key: protobuf-${{ env.protobuf_version }}
path: protobuf-${{ env.protobuf-version }}
key: protobuf-${{ env.protobuf-version }}
- name: Download and compile protobuf
if: steps.cache-protobuf.outputs.cache-hit != 'true'
run: |
wget ${{ env.protobuf_download }}
tar xf protobuf-cpp-${{ env.protobuf_version }}.tar.gz
cd protobuf-${{ env.protobuf_version }}
wget ${{ env.protobuf-download }}
tar xf protobuf-cpp-${{ env.protobuf-version }}.tar.gz
cd protobuf-${{ env.protobuf-version }}
./autogen.sh
./configure CFLAGS="-fPIC" CXXFLAGS="-fPIC"
make -j2
make -j8
- name: Install protobuf
run: |
cd protobuf-${{ env.protobuf_version }}
cd protobuf-${{ env.protobuf-version }}
sudo make install
sudo ldconfig
- name: Build
run: |
mkdir build
cd build
cmake ..
make -j2
run: make
- name: Test cpu
run: make test-cpp
- name: Install python-frontend
run: |
cd build
make test
python -m pip install --upgrade pip
make install-python
- name: Download test model
run: wget ${{ env.resnet-download }}
- name: Test onnx frontend
run: make test-onnx


@@ -1,10 +1,12 @@
name: clang-format Check
on:
pull_request:
paths-ignore:
- '**.md'
- 'LICENSE'
paths:
- 'include/**'
- 'src/**'
- 'test/**'
- '3rd-party/**'
- '.clang-format'
jobs:
formatting-check:
name: Formatting Check

.gitignore

@@ -37,4 +37,8 @@ build_debug/
.vscode/
# python
*.pyc
*.egg-info
*.pyc
# onnx model
*.onnx

CMakeLists.txt

@@ -96,8 +96,8 @@ target_link_libraries(InfiniTensor pybind11::embed)
# Python bindings
file(GLOB_RECURSE FFIS src/ffi/ffi_infinitensor.cc)
pybind11_add_module(pyinfinitensor MODULE ${FFIS})
target_link_libraries(pyinfinitensor PRIVATE InfiniTensor)
pybind11_add_module(backend MODULE ${FFIS})
target_link_libraries(backend PRIVATE InfiniTensor)
if(USE_BACKTRACE)
add_definitions(-D BACKWARD_TRACE)
@@ -134,13 +134,13 @@ if(USE_BANG)
set(NEUWARE_HOME $ENV{NEUWARE_HOME} CACHE STRING "NEUWARE_HOME directory for Cambricon Neuware development")
endif()
message(STATUS "NEUWARE_HOME: ${NEUWARE_HOME}")
include_directories("${NEUWARE_HOME}/include")
find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64")
find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64")
find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE)
set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH")

Makefile (new file)

@@ -0,0 +1,22 @@
.PHONY : build clean install-python test-cpp test-onnx
TYPE ?= release
build:
mkdir -p build/$(TYPE)
cd build/$(TYPE) && cmake -DCMAKE_BUILD_TYPE=$(TYPE) ../.. && make -j8
clean:
rm -rf build
install-python: build
cp build/$(TYPE)/backend*.so pyinfinitensor/src/pyinfinitensor
pip install pyinfinitensor/
test-cpp: build
@echo
cd build/$(TYPE) && make test
test-onnx:
@echo
python3 pyinfinitensor/tests/test_onnx.py

include/core/graph_handler.h (new file)

@@ -0,0 +1,44 @@
#pragma once
#include "core/graph.h"
#include "core/runtime.h"
namespace infini {
// Borrow ONNX's type codes to reduce conversion overhead; taken from ONNX, but not only used for ONNX.
//
// see https://onnx.ai/onnx/intro/concepts.html#element-type
enum OnnxDType : int {
UNDEFINED = 0,
FLOAT,
UINT8,
INT8,
UINT16,
INT16,
INT32,
INT64,
STRING,
BOOL,
FLOAT16,
DOUBLE,
UINT32,
UINT64,
COMPLEX64,
COMPLEX128,
BFLOAT16,
};
class GraphHandlerObj {
Graph g;
public:
GraphHandlerObj(Runtime runtime)
: g(make_ref<GraphObj>(std::move(runtime))) {}
Tensor tensor(Shape dims, int dtype);
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
Tensor bias, ActType act);
};
} // namespace infini
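
The enum values above deliberately mirror ONNX's element-type codes, so the integers read from an ONNX protobuf can be passed into `GraphHandlerObj::tensor` unchanged. A minimal Python check of that correspondence (using only the public `onnx` package):

```python
from onnx import TensorProto

# OnnxDType is numbered to match ONNX's element types; FLOAT and
# UINT32 are the two codes dtype_repr_convert currently accepts.
assert TensorProto.FLOAT == 1
assert TensorProto.UINT32 == 12
```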


@@ -10,6 +10,7 @@ class TensorBaseObj;
class TensorObj;
class OperatorObj;
class GraphObj;
class GraphHandlerObj;
class RuntimeObj;
class BlobObj;
@@ -17,6 +18,7 @@ using TensorBase = Ref<TensorBaseObj>;
using Tensor = Ref<TensorObj>;
using Operator = Ref<OperatorObj>;
using Graph = Ref<GraphObj>;
using GraphHandler = Ref<GraphHandlerObj>;
using Runtime = Ref<RuntimeObj>;
using Blob = Ref<BlobObj>;
enum class OpType;

pyinfinitensor/README.md (new file)

@@ -0,0 +1,16 @@
# pyinfinitensor
The Python frontend of InfiniTensor.
- [ ] Import models from ONNX
- [ ] Import models from PyTorch
- [ ] Import models from PaddlePaddle
- [ ] Export models to ONNX
- [ ] Export models to PyTorch
- [ ] Export models to PaddlePaddle
## Python project structure and packaging
This project is defined by a `pyproject.toml` file. The directory tree follows the [src layout](https://setuptools.pypa.io/en/latest/userguide/package_discovery.html#src-layout), the backend library is bundled as [package data](https://setuptools.pypa.io/en/latest/userguide/datafiles.html#package-data), and [required dependencies](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#declaring-required-dependency) are installed automatically.
See the [setuptools documentation](https://setuptools.pypa.io/en/latest/userguide/index.html) for details.
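
As a quick post-install sanity check — a hypothetical session assuming `make install-python` has already copied `backend*.so` into the package — the bundled backend is reachable through `pyinfinitensor.onnx`:

```python
from pyinfinitensor.onnx import backend, runtime

# The pybind11 backend ships inside the package as package data;
# importing pyinfinitensor.onnx loads it and creates a CPU runtime.
handler = backend.GraphHandlerObj(runtime)
t = handler.tensor([1, 2, 3], 1)  # dtype 1 == ONNX FLOAT
print(t)
```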


@@ -0,0 +1 @@


pyinfinitensor/pyproject.toml (new file)

@@ -0,0 +1,21 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[project]
name = "pyinfinitensor"
version = "0.0.0"
authors = [{ name = "YdrMaster", email = "ydrml@hotmail.com" }]
description = "Python frontend of InfiniTensor"
readme = "README.md"
requires-python = ">=3.8"
keywords = ["optimizer"]
license = { text = "Apache" }
classifiers = ["Programming Language :: Python :: 3"]
dependencies = ["onnx"]
[tool.setuptools.packages.find]
where = ["src"]
[tool.setuptools.package-data]
pyinfinitensor = ["*.so"]

pyinfinitensor/src/pyinfinitensor/__init__.py (new file)

@@ -0,0 +1,8 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent))
import backend
print("import backend: {}".format(backend))

pyinfinitensor/src/pyinfinitensor/onnx.py (new file)

@@ -0,0 +1,66 @@
import onnx
import backend
runtime = backend.cpu_runtime()
def from_onnx(model: onnx.ModelProto):
handler = backend.GraphHandlerObj(runtime)
tensors = dict()
for input in model.graph.input:
dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
tensors[input.name] = handler.tensor(dims, input.type.tensor_type.elem_type)
for output in model.graph.output:
dims = [d.dim_value for d in output.type.tensor_type.shape.dim]
tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type)
for node in model.graph.node:
if node.op_type == "MatMul":
handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors[node.output[0]],
False,
False,
None,
backend.ActType.Linear,
)
def parse_onnx(model: onnx.ModelProto):
print()
for field in [
"doc_string",
"domain",
"functions",
"metadata_props",
"model_version",
"producer_name",
"producer_version",
"training_info",
]:
print("{}: {}".format(field, getattr(model, field)))
print("ir_version:", model.ir_version)
for opset in model.opset_import:
print("opset domain={} version={}".format(opset.domain, opset.version))
print("layout:")
for node in model.graph.node:
print(
' {o} <- {op}"{name}"{a} <- {i}'.format(
name=node.name,
op=node.op_type,
i=node.input,
o=node.output,
a=[a.name for a in node.attribute],
)
)
print("weight:")
for node in model.graph.initializer:
print(" {}".format(node.name))

pyinfinitensor/tests/test_onnx.py (new file)

@@ -0,0 +1,43 @@
import os, onnx, unittest
from onnx import TensorProto
from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info
from onnx.checker import check_model
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
class TestStringMethods(unittest.TestCase):
def test_load(self):
model_file = next(
(name for name in os.listdir() if name.endswith(".onnx")), None
)
if model_file is not None:
print(
"model: {file}({size:.2f} MiB)".format(
file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
)
)
parse_onnx(onnx.load(model_file))
def test_import(self):
i = make_tensor_value_info("i", TensorProto.FLOAT, [1, 2, 3])
w = make_tensor_value_info("w", TensorProto.FLOAT, [1, 3, 4])
o = make_tensor_value_info("o", TensorProto.FLOAT, [1, 2, 4])
matmul = make_node("MatMul", ["i", "w"], ["o"], name="matmul")
graph = make_graph([matmul], "mm", [i, w], [o])
model = make_model(graph)
check_model(model)
print(model)
from_onnx(model)
parse_onnx(model)
def test_frontend(self):
handler = backend.GraphHandlerObj(runtime)
i = handler.tensor([1, 2, 3], 12)
w = handler.tensor([1, 3, 4], 12)
o = handler.tensor([1, 2, 4], 12)
handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
if __name__ == "__main__":
unittest.main()

src/core/graph_handler.cc (new file)

@@ -0,0 +1,37 @@
#include "core/graph_handler.h"
#include "operators/matmul.h"
namespace infini {
static DataType dtype_repr_convert(int);
Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
}
Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
bool transB, Tensor bias, ActType act) {
if (y) {
g->addOpWithOutputs<MatmulObj>(std::move(a), std::move(b), y, transA,
transB, std::move(bias), act);
return y;
} else {
return g
->addOp<MatmulObj>(std::move(a), std::move(b), y, transA, transB,
std::move(bias), act)
->getOutput();
}
}
static DataType dtype_repr_convert(int dtype) {
switch ((OnnxDType)dtype) {
case OnnxDType::FLOAT:
return DataType::Float32;
case OnnxDType::UINT32:
return DataType::UInt32;
default:
IT_ASSERT(false, "Unsupported data type");
}
}
} // namespace infini

src/ffi/ffi_infinitensor.cc

@@ -1,7 +1,10 @@
#include "core/graph_handler.h"
#include <pybind11/stl.h>
#ifdef USE_CUDA
#include "cuda/operator_timer.h"
#endif
namespace py = pybind11;
namespace infini {
@@ -18,6 +21,34 @@ void register_operator_timer(py::module &m) {
#endif
}
void init_graph_builder(py::module &m) {
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
m, "CpuRuntimeObj");
py::class_<Shape>(m, "Shape");
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
py::class_<Tensor>(m, "Tensor");
py::enum_<ActType>(m, "ActType")
.value("Linear", ActType::None) // None 是 Python 关键字,不能用
.value("Relu", ActType::Relu)
.value("Sigmoid", ActType::Sigmoid)
.value("Tanh", ActType::Tanh)
.export_values();
py::class_<GraphHandler>(m, "GraphHandler");
py::class_<GraphHandlerObj>(m, "GraphHandlerObj")
.def(py::init<Runtime>())
.def("tensor", py::overload_cast<Shape, int>(&GraphHandlerObj::tensor),
policy::reference_internal)
.def("matmul",
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
ActType>(&GraphHandlerObj::matmul),
policy::reference_internal);
}
} // namespace infini
PYBIND11_MODULE(pyinfinitensor, m) { infini::register_operator_timer(m); }
PYBIND11_MODULE(backend, m) {
infini::register_operator_timer(m);
infini::init_graph_builder(m);
}

test/core/python.py (new file)

@@ -0,0 +1,2 @@
def inc(x: int) -> int:
return x + 1


@@ -0,0 +1,16 @@
#include "core/graph_handler.h"
#include "core/runtime.h"
#include <test.h>
namespace infini {
TEST(Handler, matmul) {
auto runtime = CpuRuntimeObj::getInstance();
auto handler = make_ref<GraphHandlerObj>(runtime);
auto i = handler->tensor({1, 2, 3}, OnnxDType::UINT32);
auto w = handler->tensor({1, 3, 4}, OnnxDType::UINT32);
auto o = handler->tensor({1, 2, 4}, OnnxDType::UINT32);
handler->matmul(i, w, o, false, false, nullptr, ActType::None);
}
} // namespace infini

test/core/test_python.cc (new file)

@@ -0,0 +1,16 @@
#include <filesystem>
#include <pybind11/embed.h>
#include <test.h>
TEST(Python, pybind) {
namespace fs = std::filesystem;
namespace py = pybind11;
using mod = py::module;
py::scoped_interpreter _python;
auto sys_path_append = mod::import("sys").attr("path").attr("append");
sys_path_append(fs::path(__FILE__).parent_path().c_str());
auto ans = mod::import("python").attr("inc")(1);
EXPECT_EQ(ans.cast<int>(), 2);
}