diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 7117b1fa..7dae8509 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -6,8 +6,11 @@ on:
     - 'LICENSE'
 
 env:
-  protobuf_download: https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protobuf-cpp-3.21.12.tar.gz
-  protobuf_version: 3.21.12
+  protobuf-download: https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protobuf-cpp-3.21.12.tar.gz
+  protobuf-version: "3.21.12"
+  python-version: "3.10"
+  resnet-download: https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx
+  resnet-file: resnet18-v2-7.onnx
 
 jobs:
   build:
@@ -18,6 +21,11 @@ jobs:
       with:
         submodules: recursive
 
+    - name: Set up Python ${{ env.python-version }}
+      uses: actions/setup-python@v4
+      with:
+        python-version: ${{ env.python-version }}
+
     - name: Install libdw
       run: sudo apt-get update && sudo apt-get install libdw-dev
 
@@ -25,33 +33,38 @@ jobs:
       id: cache-protobuf
       uses: actions/cache@v3
       with:
-        path: protobuf-${{ env.protobuf_version }}
-        key: protobuf-${{ env.protobuf_version }}
+        path: protobuf-${{ env.protobuf-version }}
+        key: protobuf-${{ env.protobuf-version }}
 
     - name: Download and compile protobuf
       if: steps.cache-protobuf.outputs.cache-hit != 'true'
       run: |
-        wget ${{ env.protobuf_download }}
-        tar xf protobuf-cpp-${{ env.protobuf_version }}.tar.gz
-        cd protobuf-${{ env.protobuf_version }}
+        wget ${{ env.protobuf-download }}
+        tar xf protobuf-cpp-${{ env.protobuf-version }}.tar.gz
+        cd protobuf-${{ env.protobuf-version }}
         ./autogen.sh
        ./configure CFLAGS="-fPIC" CXXFLAGS="-fPIC"
-        make -j2
+        make -j8
 
     - name: Install protobuf
       run: |
-        cd protobuf-${{ env.protobuf_version }}
+        cd protobuf-${{ env.protobuf-version }}
         sudo make install
         sudo ldconfig
 
     - name: Build
-      run: |
-        mkdir build
-        cd build
-        cmake ..
-        make -j2
+      run: make
 
     - name: Test cpu
+      run: make test-cpp
+
+    - name: Install python-frontend
       run: |
-        cd build
-        make test
+        python -m pip install --upgrade pip
+        make install-python
+
+    - name: Download test model
+      run: wget ${{ env.resnet-download }}
+
+    - name: Test onnx frontend
+      run: make test-onnx
diff --git a/.github/workflows/clang-format-check.yml b/.github/workflows/clang-format-check.yml
index 536c74b7..35828a29 100644
--- a/.github/workflows/clang-format-check.yml
+++ b/.github/workflows/clang-format-check.yml
@@ -1,10 +1,12 @@
 name: clang-format Check
 on:
   pull_request:
-    paths-ignore:
-      - '**.md'
-      - 'LICENSE'
-
+    paths:
+      - 'include/**'
+      - 'src/**'
+      - 'test/**'
+      - '3rd-party/**'
+      - '.clang-format'
 jobs:
   formatting-check:
     name: Formatting Check
diff --git a/.gitignore b/.gitignore
index 840dbccb..863f1a48 100644
--- a/.gitignore
+++ b/.gitignore
@@ -37,4 +37,8 @@ build_debug/
 .vscode/
 
 # python
-*.pyc
\ No newline at end of file
+*.egg-info
+*.pyc
+
+# onnx model
+*.onnx
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 52f4ff4e..786f9078 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -96,8 +96,8 @@ target_link_libraries(InfiniTensor pybind11::embed)
 
 # Python bindings
 file(GLOB_RECURSE FFIS src/ffi/ffi_infinitensor.cc)
-pybind11_add_module(pyinfinitensor MODULE ${FFIS})
-target_link_libraries(pyinfinitensor PRIVATE InfiniTensor)
+pybind11_add_module(backend MODULE ${FFIS})
+target_link_libraries(backend PRIVATE InfiniTensor)
 
 if(USE_BACKTRACE)
   add_definitions(-D BACKWARD_TRACE)
@@ -134,13 +134,13 @@ if(USE_BANG)
     set(NEUWARE_HOME $ENV{NEUWARE_HOME} CACHE STRING "NEUWARE_HOME directory for Cambricon Neuware development")
   endif()
   message(STATUS "NEUWARE_HOME: ${NEUWARE_HOME}")
-  
+
   include_directories("${NEUWARE_HOME}/include")
   find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64")
   find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64")
   find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
-  
+
   if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
     execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE)
     set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH")
diff --git a/Makefile b/Makefile
new file mode 100644
index 00000000..71033bbb
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,22 @@
+.PHONY : build clean install-python test-cpp test-onnx
+
+TYPE ?= release
+
+build:
+	mkdir -p build/$(TYPE)
+	cd build/$(TYPE) && cmake -DCMAKE_BUILD_TYPE=$(TYPE) ../.. && make -j8
+
+clean:
+	rm -rf build
+
+install-python: build
+	cp build/$(TYPE)/backend*.so pyinfinitensor/src/pyinfinitensor
+	pip install pyinfinitensor/
+
+test-cpp: build
+	@echo
+	cd build/$(TYPE) && make test
+
+test-onnx:
+	@echo
+	python3 pyinfinitensor/tests/test_onnx.py
diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
new file mode 100644
index 00000000..9fe92940
--- /dev/null
+++ b/include/core/graph_handler.h
@@ -0,0 +1,44 @@
+#pragma once
+
+#include "core/graph.h"
+#include "core/runtime.h"
+
+namespace infini {
+
+// Borrow ONNX's definitions to reduce conversion overhead; taken from ONNX but not used only for ONNX.
+//
+// see https://onnx.ai/onnx/intro/concepts.html#element-type
+enum OnnxDType : int {
+    UNDEFINED = 0,
+    FLOAT,
+    UINT8,
+    INT8,
+    UINT16,
+    INT16,
+    INT32,
+    INT64,
+    STRING,
+    BOOL,
+    FLOAT16,
+    DOUBLE,
+    UINT32,
+    UINT64,
+    COMPLEX64,
+    COMPLEX128,
+    BFLOAT16,
+};
+
+class GraphHandlerObj {
+    Graph g;
+
+  public:
+    GraphHandlerObj(Runtime runtime)
+        : g(make_ref<GraphObj>(std::move(runtime))) {}
+
+    Tensor tensor(Shape dims, int dtype);
+
+    Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
+                  Tensor bias, ActType act);
+};
+
+} // namespace infini
diff --git a/include/core/runtime.h b/include/core/runtime.h
index 6aa3c4fe..10103b4d 100644
--- a/include/core/runtime.h
+++ b/include/core/runtime.h
@@ -10,6 +10,7 @@ class TensorBaseObj;
 class TensorObj;
 class OperatorObj;
 class GraphObj;
+class GraphHandlerObj;
 class RuntimeObj;
 class BlobObj;
 
@@ -17,6 +18,7 @@ using TensorBase = Ref<TensorBaseObj>;
 using Tensor = Ref<TensorObj>;
 using Operator = Ref<OperatorObj>;
 using Graph = Ref<GraphObj>;
+using GraphHandler = Ref<GraphHandlerObj>;
 using Runtime = Ref<RuntimeObj>;
 using Blob = Ref<BlobObj>;
 enum class OpType;
diff --git a/pyinfinitensor/README.md b/pyinfinitensor/README.md
new file mode 100644
index 00000000..0f08921e
--- /dev/null
+++ b/pyinfinitensor/README.md
@@ -0,0 +1,16 @@
+# pyinfinitensor
+
+The Python frontend of InfiniTensor.
+
+- [ ] Import models from ONNX
+- [ ] Import models from PyTorch
+- [ ] Import models from PaddlePaddle
+- [ ] Export models to ONNX
+- [ ] Export models to PyTorch
+- [ ] Export models to PaddlePaddle
+
+## Python project layout and packaging
+
+The project is defined by a [pyproject.toml] file and follows the [src layout](https://setuptools.pypa.io/en/latest/userguide/package_discovery.html#src-layout); the backend library is [bundled as package data](https://setuptools.pypa.io/en/latest/userguide/datafiles.html#package-data), and [required dependencies are installed automatically](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#declaring-required-dependency).
+
+See the [setuptools documentation](https://setuptools.pypa.io/en/latest/userguide/index.html).
diff --git a/pyinfinitensor/docs/placeholder.md b/pyinfinitensor/docs/placeholder.md
new file mode 100644
index 00000000..5f282702
--- /dev/null
+++ b/pyinfinitensor/docs/placeholder.md
@@ -0,0 +1 @@
+ 
\ No newline at end of file
diff --git a/pyinfinitensor/pyproject.toml b/pyinfinitensor/pyproject.toml
new file mode 100644
index 00000000..aca83ee3
--- /dev/null
+++ b/pyinfinitensor/pyproject.toml
@@ -0,0 +1,21 @@
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "pyinfinitensor"
+version = "0.0.0"
+authors = [{ name = "YdrMaster", email = "ydrml@hotmail.com" }]
+description = "Python frontend of InfiniTensor"
+readme = "README.md"
+requires-python = ">=3.8"
+keywords = ["optimizer"]
+license = { text = "Apache" }
+classifiers = ["Programming Language :: Python :: 3"]
+dependencies = ["onnx"]
+
+[tool.setuptools.packages.find]
+where = ["src"]
+
+[tool.setuptools.package-data]
+pyinfinitensor = ["*.so"]
diff --git a/pyinfinitensor/src/pyinfinitensor/__init__.py b/pyinfinitensor/src/pyinfinitensor/__init__.py
new file mode 100644
index 00000000..c9558bfe
--- /dev/null
+++ b/pyinfinitensor/src/pyinfinitensor/__init__.py
@@ -0,0 +1,8 @@
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent))
+
+import backend
+
+print("import backend: {}".format(backend))
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
new file mode 100644
index 00000000..ff023914
--- /dev/null
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -0,0 +1,66 @@
+import onnx
+import backend
+
+runtime = backend.cpu_runtime()
+
+
+def from_onnx(model: onnx.ModelProto):
+    handler = backend.GraphHandlerObj(runtime)
+
+    tensors = dict()
+
+    for input in model.graph.input:
+        dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
+        tensors[input.name] = handler.tensor(dims, input.type.tensor_type.elem_type)
+
+    for output in model.graph.output:
+        dims = [d.dim_value for d in output.type.tensor_type.shape.dim]
+        tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type)
+
+    for node in model.graph.node:
+        if node.op_type == "MatMul":
+            handler.matmul(
+                tensors[node.input[0]],
+                tensors[node.input[1]],
+                tensors[node.output[0]],
+                False,
+                False,
+                None,
+                backend.ActType.Linear,
+            )
+
+
+def parse_onnx(model: onnx.ModelProto):
+    print()
+
+    for field in [
+        "doc_string",
+        "domain",
+        "functions",
+        "metadata_props",
+        "model_version",
+        "producer_name",
+        "producer_version",
+        "training_info",
+    ]:
+        print("{}: {}".format(field, getattr(model, field)))
+
+    print("ir_version:", model.ir_version)
+    for opset in model.opset_import:
+        print("opset domain={} version={}".format(opset.domain, opset.version))
+
+    print("layout:")
+    for node in model.graph.node:
+        print(
+            '  {o} <- {op}"{name}"{a} <- {i}'.format(
+                name=node.name,
+                op=node.op_type,
+                i=node.input,
+                o=node.output,
+                a=[a.name for a in node.attribute],
+            )
+        )
+
+    print("weight:")
+    for node in model.graph.initializer:
+        print("  {}".format(node.name))
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
new file mode 100644
index 00000000..3178ebd2
--- /dev/null
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -0,0 +1,43 @@
+import os, onnx, unittest
+from onnx import TensorProto
+from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info
+from onnx.checker import check_model
+from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
+
+
+class TestStringMethods(unittest.TestCase):
+    def test_load(self):
+        model_file = next(
+            (name for name in os.listdir() if name.endswith(".onnx")), None
+        )
+        if model_file is not None:
+            print(
+                "model: {file}({size:.2f} MiB)".format(
+                    file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
+                )
+            )
+            parse_onnx(onnx.load(model_file))
+
+    def test_import(self):
+        i = make_tensor_value_info("i", TensorProto.FLOAT, [1, 2, 3])
+        w = make_tensor_value_info("w", TensorProto.FLOAT, [1, 3, 4])
+        o = make_tensor_value_info("o", TensorProto.FLOAT, [1, 2, 4])
+        matmul = make_node("MatMul", ["i", "w"], ["o"], name="matmul")
+        graph = make_graph([matmul], "mm", [i, w], [o])
+        model = make_model(graph)
+        check_model(model)
+        print(model)
+
+        from_onnx(model)
+        parse_onnx(model)
+
+    def test_frontend(self):
+        handler = backend.GraphHandlerObj(runtime)
+        i = handler.tensor([1, 2, 3], 12)
+        w = handler.tensor([1, 3, 4], 12)
+        o = handler.tensor([1, 2, 4], 12)
+        handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
new file mode 100644
index 00000000..ffc68473
--- /dev/null
+++ b/src/core/graph_handler.cc
@@ -0,0 +1,37 @@
+#include "core/graph_handler.h"
+#include "operators/matmul.h"
+
+namespace infini {
+
+static DataType dtype_repr_convert(int);
+
+Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
+    return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
+}
+
+Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
+                               bool transB, Tensor bias, ActType act) {
+    if (y) {
+        g->addOpWithOutputs<MatmulObj>(std::move(a), std::move(b), y, transA,
+                                       transB, std::move(bias), act);
+        return y;
+    } else {
+        return g
+            ->addOp<MatmulObj>(std::move(a), std::move(b), y, transA, transB,
+                               std::move(bias), act)
+            ->getOutput();
+    }
+}
+
+static DataType dtype_repr_convert(int dtype) {
+    switch ((OnnxDType)dtype) {
+    case OnnxDType::FLOAT:
+        return DataType::Float32;
+    case OnnxDType::UINT32:
+        return DataType::UInt32;
+    default:
+        IT_ASSERT(false, "Unsupported data type");
+    }
+}
+
+} // namespace infini
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 6df1895a..b8d4354e 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -1,7 +1,10 @@
+#include "core/graph_handler.h"
 #include <pybind11/stl.h>
+
 #ifdef USE_CUDA
 #include "cuda/operator_timer.h"
 #endif
+
 namespace py = pybind11;
 
 namespace infini {
@@ -18,6 +21,34 @@
 #endif
 }
 
+void init_graph_builder(py::module &m) {
+    using policy = py::return_value_policy;
+    m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
+    py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
+    py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
+        m, "CpuRuntimeObj");
+    py::class_<Shape>(m, "Shape");
+    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
+    py::class_<Tensor>(m, "Tensor");
+    py::enum_<ActType>(m, "ActType")
+        .value("Linear", ActType::None) // None is a Python keyword and cannot be used
+        .value("Relu", ActType::Relu)
+        .value("Sigmoid", ActType::Sigmoid)
+        .value("Tanh", ActType::Tanh)
+        .export_values();
+    py::class_<GraphHandler>(m, "GraphHandler");
+    py::class_<GraphHandlerObj>(m, "GraphHandlerObj")
+        .def(py::init<Runtime>())
+        .def("tensor", py::overload_cast<Shape, int>(&GraphHandlerObj::tensor),
+             policy::reference_internal)
+        .def("matmul",
+             py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor, ActType>(&GraphHandlerObj::matmul),
+             policy::reference_internal);
+}
+
 } // namespace infini
 
-PYBIND11_MODULE(pyinfinitensor, m) { infini::register_operator_timer(m); }
\ No newline at end of file
+PYBIND11_MODULE(backend, m) {
+    infini::register_operator_timer(m);
+    infini::init_graph_builder(m);
+}
diff --git a/test/core/python.py b/test/core/python.py
new file mode 100644
index 00000000..09159704
--- /dev/null
+++ b/test/core/python.py
@@ -0,0 +1,2 @@
+def inc(x: int) -> int:
+    return x + 1
diff --git a/test/core/test_graph_handler.cc b/test/core/test_graph_handler.cc
new file mode 100644
index 00000000..cac81dab
--- /dev/null
+++ b/test/core/test_graph_handler.cc
@@ -0,0 +1,16 @@
+#include "core/graph_handler.h"
+#include "core/runtime.h"
+#include <gtest/gtest.h>
+
+namespace infini {
+
+TEST(Handler, matmul) {
+    auto runtime = CpuRuntimeObj::getInstance();
+    auto handler = make_ref<GraphHandlerObj>(runtime);
+    auto i = handler->tensor({1, 2, 3}, OnnxDType::UINT32);
+    auto w = handler->tensor({1, 3, 4}, OnnxDType::UINT32);
+    auto o = handler->tensor({1, 2, 4}, OnnxDType::UINT32);
+    handler->matmul(i, w, o, false, false, nullptr, ActType::None);
+}
+
+} // namespace infini
diff --git a/test/core/test_python.cc b/test/core/test_python.cc
new file mode 100644
index 00000000..a3933387
--- /dev/null
+++ b/test/core/test_python.cc
@@ -0,0 +1,16 @@
+#include <filesystem>
+#include <gtest/gtest.h>
+#include <pybind11/embed.h>
+
+TEST(Python, pybind) {
+    namespace fs = std::filesystem;
+    namespace py = pybind11;
+    using mod = py::module;
+
+    py::scoped_interpreter _python;
+
+    auto sys_path_append = mod::import("sys").attr("path").attr("append");
+    sys_path_append(fs::path(__FILE__).parent_path().c_str());
+    auto ans = mod::import("python").attr("inc")(1);
+    EXPECT_EQ(ans.cast<int>(), 2);
+}
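
A minimal end-to-end sketch of the new ONNX frontend added by this patch, mirroring pyinfinitensor/tests/test_onnx.py; it assumes the backend module has already been built and installed with `make install-python`:

    from onnx import TensorProto
    from onnx.checker import check_model
    from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
    from pyinfinitensor.onnx import from_onnx, parse_onnx

    # A single-node ONNX graph computing o = i @ w.
    i = make_tensor_value_info("i", TensorProto.FLOAT, [1, 2, 3])
    w = make_tensor_value_info("w", TensorProto.FLOAT, [1, 3, 4])
    o = make_tensor_value_info("o", TensorProto.FLOAT, [1, 2, 4])
    model = make_model(make_graph([make_node("MatMul", ["i", "w"], ["o"])], "mm", [i, w], [o]))
    check_model(model)

    parse_onnx(model)  # print the model's metadata, node layout and initializers
    from_onnx(model)   # translate the model into an InfiniTensor graph via GraphHandlerObj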
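The dtype literal 12 that test_frontend passes to handler.tensor is OnnxDType::UINT32 (counting from UNDEFINED = 0 in include/core/graph_handler.h); FLOAT (1) and UINT32 (12) are the only codes dtype_repr_convert accepts so far. A sketch of driving the pybind11 backend directly, under the same assumption that pyinfinitensor is installed:

    from pyinfinitensor.onnx import backend, runtime

    handler = backend.GraphHandlerObj(runtime)

    UINT32 = 12  # OnnxDType::UINT32
    i = handler.tensor([1, 2, 3], UINT32)
    w = handler.tensor([1, 3, 4], UINT32)
    o = handler.tensor([1, 2, 4], UINT32)
    # ActType.Linear stands in for ActType::None, since None is a Python keyword.
    handler.matmul(i, w, o, False, False, None, backend.ActType.Linear)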