forked from jiuyuan/InfiniTensor
update doc (#83)
* update doc * update doc * update doc * update doc * add code * add code * update doc * update doc * add env.sh and update install guide * fix * fix bug * fix * add code * code format * Update exception.cc --------- Co-authored-by: wanghailu <wanghailu@qiyuanlab.com> Co-authored-by: wanghailu <wanghailu0717@163.com>
This commit is contained in:
parent
26f0d13c26
commit
19d7dc871d
|
@ -0,0 +1,126 @@
|
|||
# 安装部署手册
|
||||
|
||||
## 目录
|
||||
|
||||
- [环境准备](#环境准备)
|
||||
- [编译本项目](#编译本项目)
|
||||
- [技术支持](#技术支持)
|
||||
|
||||
## 环境准备
|
||||
|
||||
目前的软硬件环境支持矩阵
|
||||
|
||||
| Host CPU | Device | OS | Support |
|
||||
| -------- | ------------ | ----------- | ---------- |
|
||||
| X86-64 | Nvidia GPU | Ubuntu-22.04 | Yes |
|
||||
| X86-64 | Cambricon MLU | Ubuntu-22.04 | Yes |
|
||||
|
||||
推荐使用 X86-64 机器以及 Ubuntu-22.04,本文以此环境为例。
|
||||
|
||||
1. 确认 GCC 版本为 11.3 及以上的稳定版本,如若您的机器 GCC 版本不满足此条件,请自行编译安装,下述方式二选一:
|
||||
|
||||
> [GCC 官方文档](https://gcc.gnu.org/onlinedocs/gcc-11.3.0/gcc/)
|
||||
|
||||
> [网友安装分享](https://zhuanlan.zhihu.com/p/509695395)
|
||||
|
||||
2. 确认 CMake 版本为 3.17 及以上的稳定版本, 如若您的机器 CMake 版本不满足此条件,请自行编译安装,下述方式二选一:
|
||||
|
||||
> [CMake 官方文档](https://cmake.org/install/)
|
||||
|
||||
> [网友安装分享](https://zhuanlan.zhihu.com/p/110793004)
|
||||
|
||||
3. 第三方加速卡软件资源安装,目前本项目已经适配了如下的第三方加速卡:
|
||||
|
||||
> 如您的第三方加速卡为英伟达 GPU,请参考英伟达官方文档进行:
|
||||
> > [驱动安装](https://www.nvidia.cn/geforce/drivers/),
|
||||
> > [CUDA Toolkit 安装](https://developer.nvidia.com/cuda-toolkit),
|
||||
> > [Cudnn 安装](https://developer.nvidia.com/rdp/cudnn-download),
|
||||
> > [Cublas 安装](https://developer.nvidia.com/cublas),
|
||||
> > 安装完成后请进行相应的环境变量配置,将可执行文件目录与库目录添加到操作系统识别的路径中,
|
||||
我们强烈建议您规范安装,统一到一个目录下,以免不必要的麻烦。
|
||||
|
||||
> 如您的第三方加速卡为寒武纪 MLU,请参考寒武纪官方文档进行:
|
||||
> > [驱动安装](https://www.cambricon.com/docs/sdk_1.11.0/driver_5.10.6/user_guide_5.10.6/index.html),
|
||||
> > [CNToolkit 安装](https://www.cambricon.com/docs/sdk_1.11.0/cntoolkit_3.4.1/cntoolkit_install_3.4.1/index.html),
|
||||
> > [CNNL 安装](https://www.cambricon.com/docs/sdk_1.11.0/cambricon_cnnl_1.16.1/user_guide/index.html),
|
||||
> > 安装完成后请进行相应的环境变量配置,将可执行文件目录与库目录添加到操作系统识别的路径中,例如
|
||||
> > ```bash
|
||||
> > # 将如下内容写入到你的 bashrc 文件并 source 该文件
|
||||
> > export NEUWARE_HOME="/usr/local/neuware"
|
||||
> > export PATH="${NEUWARE_HOME}/bin:${PATH}"
|
||||
> > export LD_LIBRARY_PATH="${NEUWARE_HOME}/lib64:${LD_LIBRARY_PATH}"
|
||||
> > ```
|
||||
> > 我们强烈建议您规范安装,统一到一个目录下,以免不必要的麻烦。另外请注意,由于 MLU 上层软件建设适配程度有限,如您在其覆盖的机器,操作系统之外运行,需要在安装驱动之后使用上层软件的 Docker。
|
||||
|
||||
4. 确认您安装了 make,build-essential, python-is-python3, python-dev-is-python3, python3-pip, libdw-dev,如您的机器没有上述基础依赖,请自行按需安装。
|
||||
|
||||
> 在使用 apt-get 工具情况下,您可以这样子执行。
|
||||
|
||||
```bash
|
||||
sudo apt-get install make cmake build-essential python-is-python3 python-dev-is-python3 python3-pip libdw-dev
|
||||
```
|
||||
|
||||
> 其他工具安装方式请自行上网搜寻
|
||||
|
||||
5. 更新pip并切换到清华源
|
||||
|
||||
```bash
|
||||
python -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --upgrade pip
|
||||
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
```
|
||||
|
||||
6. 安装一些不必要的项目(可选)
|
||||
|
||||
> 如您需要运行本项目下的 example 代码,您需要安装一些辅助项目。请注意这些项目不是必要的,若您不需要运行样例代码,这些项目无需安装。
|
||||
> > [Pytorch](https://pytorch.org/get-started/locally/):业界内流行的神经网络编程框架
|
||||
> > [ONNX](https://onnx.ai/get-started.html):业界内流行的神经网络模型存储文件与转换器
|
||||
> > [onnxsim](https://pypi.org/project/onnxsim/):一个简化onnx模型的小工具
|
||||
> > [onnx2torch](https://github.com/ENOT-AutoDL/onnx2torch):一个将onnx模型转换pytorch模型的小工具
|
||||
> > [tqdm](https://pypi.org/project/tqdm/):一个显示程序运行进度条的小工具
|
||||
|
||||
> 如您需要使用本项目下的 InfiniTest 测试工具,你还需要安装如下的项目:
|
||||
> > [protobuf](https://github.com/protocolbuffers/protobuf): 一种序列化文件的格式及其编译、序列化、解析工具
|
||||
|
||||
## 编译本项目
|
||||
|
||||
推荐使用 X86-64 机器以及 Ubuntu-22.04,本文以此环境为例。
|
||||
|
||||
1. 配置环境
|
||||
|
||||
打开 env.sh 文件进行环境变量配置,之后执行
|
||||
|
||||
```bash
|
||||
source env.sh
|
||||
```
|
||||
|
||||
2. 编译本项目并打包成 Python 库进行安装
|
||||
|
||||
我们提供了意见编译参数,您可以在项目根目录下执行下面的命令。第一次执行会同时安装 python 依赖库,耗时略长,请耐心等待
|
||||
|
||||
仅编译 CPU 部分,不编译第三方计算卡:
|
||||
|
||||
```bash
|
||||
make install-python
|
||||
```
|
||||
|
||||
编译 CPU 部分,同时编译英伟达 GPU 部分:
|
||||
|
||||
```bash
|
||||
export CUDA_HOME=/path/to/your/cuda_home
|
||||
make install-python CUDA=ON
|
||||
```
|
||||
|
||||
编译 CPU 部分,同时编译寒武纪 MLU 部分:
|
||||
|
||||
```bash
|
||||
export NEUWARE_HOME=/path/to/your/neuware_home
|
||||
make install-python BANG=ON
|
||||
```
|
||||
|
||||
3. 使用方法
|
||||
|
||||
安装成功后,您就可以使用本项目的 Python 接口进行编码并运行。具体使用方式可以参考项目样例代码 example/Resnet/resnet.py 以及用户使用手册
|
||||
|
||||
## 技术支持
|
||||
|
||||
如遇到问题,请联系我们技术支持团队
|
64
README_CN.md
64
README_CN.md
|
@ -2,50 +2,73 @@
|
|||
|
||||
## 目录
|
||||
|
||||
- [编译](#编译)
|
||||
- [使用](#使用)
|
||||
- [环境准备](#环境准备)
|
||||
- [编译本项目](#编译本项目)
|
||||
- [使用方法](#使用方法)
|
||||
- [python-前端应用指南](#python-前端应用指南)
|
||||
- [导入-onnx-模型](#导入-onnx-模型)
|
||||
- [导出-onnx-模型](#导出-onnx-模型)
|
||||
- [执行推理](#执行推理)
|
||||
- [样例代码](#样例代码)
|
||||
- [技术支持](#技术支持)
|
||||
- [测试](#测试)
|
||||
|
||||
## 编译
|
||||
## 环境准备
|
||||
|
||||
推荐使用 Ubuntu-22.04,本文以此环境为例。
|
||||
推荐使用 X86-64 机器以及 Ubuntu-22.04,本文以此环境为例。
|
||||
|
||||
1. 使用 apt 安装依赖
|
||||
|
||||
> 如果不使用 Ubuntu-22.04,部分软件版本可能不够高。
|
||||
1. 确认 GCC 版本为 11.3 及以上的稳定版本,如若您的机器 GCC 版本不满足此条件,请自行编译安装,下述方式二选一:
|
||||
> [GCC 官方文档](https://gcc.gnu.org/onlinedocs/gcc-11.3.0/gcc/)
|
||||
> [网友安装分享](https://zhuanlan.zhihu.com/p/509695395)
|
||||
2. 确认 CMake 版本为 3.17 及以上的稳定版本, 如若您的机器 CMake 版本不满足此条件,请自行编译安装,下述方式二选一:
|
||||
> [CMake 官方文档](https://cmake.org/install/)
|
||||
> [网友安装分享](https://zhuanlan.zhihu.com/p/110793004)
|
||||
3. 第三方加速卡软件资源安装,目前本项目已经适配了如下的第三方加速卡:
|
||||
> 如您的第三方加速卡为英伟达 GPU,请参考英伟达官方文档进行[驱动安装](https://www.nvidia.cn/geforce/drivers/),[CUDA Toolkit 安装](https://developer.nvidia.com/cuda-toolkit),[Cudnn 安装](https://developer.nvidia.com/rdp/cudnn-download),[Cublas 安装](https://developer.nvidia.com/cublas),我们强烈建议您规范安装,统一到一个目录下,以免不必要的麻烦。
|
||||
> 如您的第三方加速卡为寒武纪 MLU,请参考寒武纪官方文档进行[驱动安装](https://www.cambricon.com/docs/sdk_1.11.0/driver_5.10.6/user_guide_5.10.6/index.html),[CNToolkit 安装](https://www.cambricon.com/docs/sdk_1.11.0/cntoolkit_3.4.1/cntoolkit_install_3.4.1/index.html),[CNNL 安装](https://www.cambricon.com/docs/sdk_1.11.0/cambricon_cnnl_1.16.1/user_guide/index.html),我们强烈建议您规范安装,统一到一个目录下,以免不必要的麻烦。另外请注意,由于 MLU 上层软件建设适配程度有限,如您在其覆盖的机器,操作系统之外运行,需要在安装驱动之后使用上层软件的 Docker。
|
||||
4. 确认您安装了 make,build-essential, python-is-python3, python-dev-is-python3, python3-pip, libdw-dev,如您的机器没有上述基础依赖,请自行按需安装。
|
||||
> 在使用 apt-get 工具情况下,您可以这样子执行。
|
||||
|
||||
```bash
|
||||
sudo apt-get install make cmake build-essential python-is-python3 python-dev-is-python3 python3-pip libdw-dev
|
||||
```
|
||||
|
||||
2. 更新 pip 并换清华源
|
||||
> 其他工具请自行上网搜寻
|
||||
|
||||
5. 更新pip并切换到清华源
|
||||
|
||||
```bash
|
||||
python -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --upgrade pip
|
||||
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
```
|
||||
|
||||
3. 编译并安装 python 库
|
||||
## 编译本项目
|
||||
|
||||
> 第一次执行会同时安装 python 依赖库,比较慢
|
||||
推荐使用 X86-64 机器以及 Ubuntu-22.04,本文以此环境为例。
|
||||
|
||||
仅编译 CPU 部分:
|
||||
1. 编译并安装 python 库
|
||||
|
||||
> 第一次执行会同时安装 python 依赖库,耗时略长,请耐心等待
|
||||
|
||||
仅编译 CPU 部分,不编译第三方计算卡:
|
||||
|
||||
```bash
|
||||
make install-python
|
||||
```
|
||||
|
||||
编译 GPU 部分:
|
||||
编译 CPU 部分,同时编译英伟达 GPU 部分:
|
||||
|
||||
```bash
|
||||
make install-python CUDA=ON
|
||||
```
|
||||
|
||||
## 使用
|
||||
编译 CPU 部分,同时编译寒武纪 MLU 部分:
|
||||
|
||||
```bash
|
||||
make install-python BANG=ON
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
项目管理功能已写到 [Makefile](Makefile),支持下列功能:
|
||||
|
||||
|
@ -168,6 +191,21 @@ for name, tensor in stub.outputs.items():
|
|||
print(tensor.copyout_float())
|
||||
```
|
||||
|
||||
### 样例代码
|
||||
|
||||
您可以参照[./example/Resnet/resnet.py](./example/ResNet/resnet.py)的样例代码进行了解,并尝试运行。在这个文件中,我们使用了 Pytorch 构建了 resnet 网络。您可以查阅该脚本使用方式:
|
||||
|
||||
```python
|
||||
python resnet.py -h
|
||||
```
|
||||
|
||||
在样例代码中,我们对定义的网络进行了序列化操作,并存储为模型文件。之后加载该模型文件,并转换为本项目的模型进行优化操作,再进行推理。您可以关注一下代码中 242 行之后的代码。请注意,您可以按照您的需求来进行操作,通常来说,您所需要撰写的代码就是加载模型,转换为本项目的模型进行优化,推理运行。
|
||||
|
||||
## 技术支持
|
||||
|
||||
如若您遇到了本项目的问题,请联系我们的技术支持团队
|
||||
|
||||
|
||||
## 测试
|
||||
|
||||
除了单元测试 `make test-cpp` 和 `make test-onnx` 之外,还可以用其他方式来测试单个模型导入导出和优化的正确性。
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# 支持矩阵
|
||||
|
||||
## 目录
|
||||
|
||||
- [环境支持](#环境支持)
|
||||
- [神经网络支持](#神经网络支持)
|
||||
- [技术支持](#技术支持)
|
||||
|
||||
## 环境支持
|
||||
|
||||
目前的软硬件环境支持矩阵
|
||||
|
||||
| Host CPU | Device | OS | Support |
|
||||
| -------- | ------------ | ----------- | ---------- |
|
||||
| X86-64 | Nvidia GPU | Ubuntu-22.04 | Yes |
|
||||
| X86-64 | Cambricon MLU | Ubuntu-22.04 | Yes |
|
||||
|
||||
## 神经网络支持
|
||||
|
||||
目前已经验证过的神经网络模型有
|
||||
|
||||
- [x] [ResNet18-v2](https://github.com/onnx/models/blob/main/vision/classification/resnet/model/resnet18-v2-7.onnx)
|
||||
- [x] [DenseNet-121-12](https://github.com/onnx/models/blob/main/vision/classification/densenet-121/model/densenet-12.onnx)
|
||||
- [x] [Inception-2](https://github.com/onnx/models/blob/main/vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-9.onnx)
|
||||
- [x] [EfficientNet-Lite4](https://github.com/onnx/models/blob/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx)
|
||||
|
||||
## 技术支持
|
||||
|
||||
如若您遇到了本项目的问题,请联系我们的技术支持团队
|
|
@ -0,0 +1,216 @@
|
|||
# 使用指南
|
||||
|
||||
## 目录
|
||||
|
||||
- [项目简介](#项目简介)
|
||||
- [项目设计](#项目设计)
|
||||
- [使用方法](#使用方法)
|
||||
- [python-前端应用指南](#python-前端应用指南)
|
||||
- [导入-onnx-模型](#导入-onnx-模型)
|
||||
- [导出-onnx-模型](#导出-onnx-模型)
|
||||
- [执行推理](#执行推理)
|
||||
- [样例代码](#样例代码)
|
||||
- [技术支持](#技术支持)
|
||||
- [测试](#测试)
|
||||
|
||||
## 项目简介
|
||||
|
||||
本项目是深度学习领域的一个编译器集合,本项目旨在缩小深度学习应用与后端硬件之间的鸿沟。本项目通过使用编译器超优化技术,对神经网络模型进行优化,从而获得更好的性能。同时,本项目与深度学习框架相互配合,为不同的硬件后端提供端倒端的编译,方便用户迁移部署。
|
||||
|
||||
## 项目设计
|
||||
|
||||
本项目的设计是前后端解耦合的,主要有三个模块,分别为:
|
||||
|
||||
- Runtime 模块:该模式负责对不同的加速卡后端进行包装与支持,支撑后端运行。另外提供统一的向上接口,方便上层建设。
|
||||
- Compiler 模块:该模式负责对神经网络模型进行优化变换,获得更加高效的等价模型。
|
||||
- Interface 模块:该模式负责给用户提供编程与交互的接口,方便用户使用本系统。
|
||||
|
||||
## 使用方法
|
||||
|
||||
项目管理功能已写到 [Makefile](Makefile),支持下列功能:
|
||||
|
||||
- 编译项目:`make`/`make build`
|
||||
- 清理生成文件:`make clean`
|
||||
- 安装 python 库:`make install-python`
|
||||
- 测试 c++ 后端:`make test-cpp`
|
||||
- 测试 python 前端:`make test-onnx`
|
||||
|
||||
并使用下列环境变量传递选项参数:
|
||||
|
||||
- `TYPE`:编译模式(`debug`/`release`),默认值为 `release`
|
||||
- `CUDA`:是否编译 CUDA 后端,默认为 `OFF`,`ON` 打开
|
||||
- `BANG`:是否编译寒武纪后端,默认为 `OFF`,`ON` 打开
|
||||
- `BACKTRACE`:是否启用栈回溯,默认为 `ON`,`OFF` 关闭,建议调试时打开
|
||||
- `TEST`:是否编译 `googletest`,默认为 `ON`,`OFF` 关闭,只有 `test-cpp` 时必要
|
||||
|
||||
## python 前端应用指南
|
||||
|
||||
`make install-python` 会将项目的 python 前端以 `pyinfinitensor` 为名字安装到系统目录,可以直接 `import pyinfinitensor` 来使用。现阶段,项目的主要用法是从 onnx 导入模型进行优化,然后可以再导出优化后的模型到 onnx,也可以直接运行推理。
|
||||
|
||||
### 导入 onnx 模型
|
||||
|
||||
支持的模型:
|
||||
|
||||
- [x] [ResNet18-v2](https://github.com/onnx/models/blob/main/vision/classification/resnet/model/resnet18-v2-7.onnx)
|
||||
- [x] [DenseNet-121-12](https://github.com/onnx/models/blob/main/vision/classification/densenet-121/model/densenet-12.onnx)
|
||||
- [x] [Inception-2](https://github.com/onnx/models/blob/main/vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-9.onnx)
|
||||
- [x] [EfficientNet-Lite4](https://github.com/onnx/models/blob/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx)
|
||||
|
||||
```python
|
||||
import onnx
|
||||
from pyinfinitensor.onnx import OnnxStub
|
||||
from pyinfinitensor import backend
|
||||
|
||||
stub = OnnxStub(onnx.load("model_file"), backend.cpu_runtime())
|
||||
```
|
||||
|
||||
[`onnx.load`](https://onnx.ai/onnx/api/serialization.html#load-a-model) 是 onnx 提供的加载函数,将 onnx 文件读取为保存在内存中的 onnx 模型。
|
||||
|
||||
`OnnxStub` 是 onnx 模型在项目中的表示,通过构造这个对象,将 onnx 模型导入到项目中。其构造器的第一个参数是 onnx 模型文件;第二个参数是模型运行的后端运行时,可以是 `backend.cpu_runtime()`、`backend.cuda_runtime()` 或 `backend.bang_runtime()`。
|
||||
|
||||
构造出的 stub 对象可以用于操作项目中的模型和运行时。
|
||||
|
||||
### 优化
|
||||
|
||||
TODO
|
||||
|
||||
### 导出 onnx 模型
|
||||
|
||||
优化后的模型可以导出成 onnx 文件提供给其他运行时。
|
||||
|
||||
```python
|
||||
with open("optimized.onnx", "wb") as f:
|
||||
f.write(stub.to_onnx("optimized").SerializeToString())
|
||||
```
|
||||
|
||||
`stub.to_onnx(<name>)` 将模型转换为 onnx 模型对象,`<name>` 将填写到 onnx 模型的 `name` 字段。序列化到文件的代码见[官方示例](https://onnx.ai/onnx/intro/python.html#model-serialization)。
|
||||
|
||||
要可视化检查导出的模型文件,可以利用 [onnx 提供的功能](https://onnx.ai/onnx/api/shape_inference.html#infer-shapes)将所有的张量的形状推理出来再导出:
|
||||
|
||||
```python
|
||||
from onnx.shape_inference import infer_shapes
|
||||
|
||||
with open("optimized.onnx", "wb") as f:
|
||||
f.write(infer_shapes(stub.to_onnx("optimized")).SerializeToString())
|
||||
```
|
||||
|
||||
然后用 [Netron](https://netron.app/) 绘制计算图。
|
||||
|
||||
### 执行推理
|
||||
|
||||
也可以使用项目的运行时执行推理。
|
||||
|
||||
第一步是将数据传入计算图。`OnnxStub.inputs` 是一个 `Dict[str, Tensor]`,保存着模型的所有输入的名字和对象。可以用 [`items()`](https://docs.python.org/zh-cn/3/library/stdtypes.html#dict.items) 来遍历。
|
||||
|
||||
这个代码片段显示了如何打印出模型所有输入张量的名字、形状和对象指针:
|
||||
|
||||
```python
|
||||
for name, tensor in stub.inputs.items():
|
||||
print(name, tensor.shape(), tensor)
|
||||
```
|
||||
|
||||
对于 [resnet18-v2-7.onnx](https://github.com/onnx/models/blob/main/vision/classification/resnet/model/resnet18-v2-7.onnx),会打印出:
|
||||
|
||||
```plaintext
|
||||
data [1, 3, 224, 224] <backend.Tensor object at 0x7efeb828e3b0>
|
||||
```
|
||||
|
||||
当然,地址是随机的。这个输出表明需要输入一个名为 “data”,形为 1×3×224×224 的数据。通常来说,这表示一张 224×224 的 rgb 图片。而这个模型是一个 1000 分类的图像分类模型。
|
||||
|
||||
为了方便,这里我们向模型传入一个随机的数据。
|
||||
|
||||
```python
|
||||
import numpy
|
||||
|
||||
stub.init()
|
||||
for name, tensor in stub.inputs.items():
|
||||
print(name, tensor.shape(), tensor)
|
||||
input = numpy.random.random(tensor.shape()).astype(numpy.float32)
|
||||
tensor.copyin_float(input.flatten().tolist())
|
||||
```
|
||||
|
||||
`stub.init()` 为所有张量分配空间。空间是预分配的,所以不支持动态 size 的模型。
|
||||
|
||||
`tensor.copyin_float(<data>)` 向张量传入数据。其参数必须是一个 `List[float]`,即压平的数据。类似的函数还有 `copyin_int32(<data>)` 和 `copyin_int64(<data>)`
|
||||
|
||||
然后,调用 `stub.run()` 执行推理:
|
||||
|
||||
```python
|
||||
stub.run()
|
||||
```
|
||||
|
||||
最后,将结果拷贝出来,传入类似:
|
||||
|
||||
```python
|
||||
stub.init()
|
||||
for name, tensor in stub.outputs.items():
|
||||
print(name, tensor.shape(), tensor)
|
||||
print(tensor.copyout_float())
|
||||
```
|
||||
|
||||
### 样例代码
|
||||
|
||||
您可以参照[./example/Resnet/resnet.py](./example/ResNet/resnet.py)的样例代码进行了解,并尝试运行。在这个文件中,我们使用了 Pytorch 构建了 resnet 网络。您可以查阅该脚本使用方式:
|
||||
|
||||
```python
|
||||
python resnet.py -h
|
||||
```
|
||||
|
||||
在样例代码中,我们对定义的网络进行了序列化操作,并存储为模型文件。之后加载该模型文件,并转换为本项目的模型进行优化操作,再进行推理。您可以关注一下代码中 242 行之后的代码。请注意,您可以按照您的需求来进行操作,通常来说,您所需要撰写的代码就是加载模型,转换为本项目的模型进行优化,推理运行。
|
||||
|
||||
## 技术支持
|
||||
|
||||
如若您遇到了本项目的问题,请联系我们的技术支持团队
|
||||
|
||||
|
||||
## 测试
|
||||
|
||||
除了单元测试 `make test-cpp` 和 `make test-onnx` 之外,还可以用其他方式来测试单个模型导入导出和优化的正确性。
|
||||
|
||||
这个脚本利用 onnxruntime 来测试导出的模型是否与导入的模型等价:
|
||||
|
||||
```python
|
||||
import onnx
|
||||
import numpy
|
||||
import sys
|
||||
from onnx import ModelProto, ValueInfoProto
|
||||
from pyinfinitensor.onnx import OnnxStub
|
||||
from pyinfinitensor import backend
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
|
||||
def infer(model: ModelProto, input) -> dict:
|
||||
collection = set()
|
||||
for node in model.graph.node:
|
||||
for output in node.output:
|
||||
collection.add(output)
|
||||
model.graph.output.extend([ValueInfoProto(name=x) for x in collection])
|
||||
session = InferenceSession(model.SerializeToString())
|
||||
i = session.get_inputs()[0].name
|
||||
return dict(
|
||||
zip(
|
||||
[x.name for x in session.get_outputs()],
|
||||
[x.flatten() for x in session.run(None, {i: input})],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
model0 = onnx.load(sys.argv[1])
|
||||
model1 = OnnxStub(model0, backend.cpu_runtime()).to_onnx("new")
|
||||
|
||||
input_shape = [x.dim_value for x in model1.graph.input[0].type.tensor_type.shape.dim]
|
||||
input = numpy.random.random(input_shape).astype(numpy.float32)
|
||||
|
||||
output0 = infer(model0, input)[model0.graph.output[0].name]
|
||||
output1 = infer(model1, input)[model1.graph.output[0].name]
|
||||
|
||||
print("error =", sum((output1 - output0) ** 2) / len(output0))
|
||||
```
|
||||
|
||||
要运行脚本,先安装 onnxruntime:
|
||||
|
||||
```bash
|
||||
pip install onnxruntime
|
||||
```
|
||||
|
||||
打印出的 `error = ...` 是两个模型输出张量的均方误差。对于不同的模型,这个误差最小为 0,最大不超过 1e-9。
|
|
@ -0,0 +1,34 @@
|
|||
# 配置英伟达 CUDA 的 HOME 路径,请注意安装 CUDA Toolkit, CUDNN 并将路径配置到下述环境变量。
|
||||
export CUDA_HOME=/PATH/TO/YOUR/CUDA/HOME
|
||||
export CUDNN_HOME=/PATH/TO/YOUR/CUDNN/HOME
|
||||
|
||||
# 配置寒武纪 BANG 的 HOME 路径,请注意 /usr/local/neuware 是寒武纪软件栈建议的,同时也是默认的安装路径。
|
||||
# 如若用户有其他的路径安装方式,请自行配置正确的路径。
|
||||
# 这里是 neuware 目录下一个可能的结构图,请参考。
|
||||
# .
|
||||
# ├── bin
|
||||
# ├── cmake
|
||||
# ├── data
|
||||
# ├── edge
|
||||
# ├── include
|
||||
# ├── lib
|
||||
# ├── lib64
|
||||
# ├── LICENSE
|
||||
# ├── mlvm
|
||||
# ├── README
|
||||
# ├── samples
|
||||
# ├── share
|
||||
# └── version.txt
|
||||
export NEUWARE_HOME=/usr/local/neuware
|
||||
|
||||
# 配置昆仑芯 XPU 的 HOME 路径,请注意 /usr/local/xpu 是昆仑芯软件栈提供的软件包路径。
|
||||
# 如若用户有其他的路径安装方式,请自行配置正确的路径。
|
||||
# 这里是 xpu 目录下一个可能的结构图,请参考。
|
||||
# .
|
||||
# ├── bin
|
||||
# ├── include
|
||||
# ├── lib64
|
||||
# ├── tools
|
||||
# ├── version
|
||||
# └── XTDK
|
||||
export XPU_HOME=/usr/local/xpu
|
|
@ -0,0 +1,71 @@
|
|||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
#include "operators/batch_norm.h"
|
||||
|
||||
namespace infini {
|
||||
class BatchNormCnnl : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<BatchNormObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const mean = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const var = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
void *const scale = (op->getInputs(3)->getRawDataPtr<void *>());
|
||||
void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
|
||||
void *const output = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
|
||||
if (dims.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
int dimArray[4], strideArray[4], dimPArray[1], stridePArray[1];
|
||||
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
dimArray[i] = dims[i];
|
||||
strideArray[i] = op->getInputs(0)->getStride()[i];
|
||||
}
|
||||
int w = dimArray[3];
|
||||
dimArray[3] = dimArray[1];
|
||||
int h = dimArray[2];
|
||||
dimArray[1] = h;
|
||||
dimArray[2] = w;
|
||||
|
||||
dimPArray[0] = op->getInputs(1)->getDims()[0];
|
||||
stridePArray[0] = op->getInputs(1)->getDims()[0];
|
||||
// get inputs
|
||||
cnnlTensorDescriptor_t inDesc;
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptorEx(inDesc, CNNL_LAYOUT_NHWC,
|
||||
CNNL_DTYPE_FLOAT, dims.size(),
|
||||
dimArray, strideArray));
|
||||
|
||||
// get bnScaleBiasMeanVarDesc
|
||||
cnnlTensorDescriptor_t paraDesc;
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(¶Desc));
|
||||
checkCnnlError(cnnlSetTensorDescriptorEx(paraDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, 1, dimPArray,
|
||||
stridePArray));
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
// This mode is intended for use after convolutional layers
|
||||
cnnlStatus_t stat = cnnlBatchNormForwardInference(
|
||||
context->cnnlHandle(), &alpha, &beta, inDesc, input, paraDesc,
|
||||
scale, bias, mean, var, op->getEps(), inDesc, output);
|
||||
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
// Destories in BANG does not require sync. But cnnl does not state
|
||||
// whether sync is required before destories.
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(paraDesc));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BatchNorm, DataType::Float32,
|
||||
BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32");
|
||||
|
||||
}; // namespace infini
|
|
@ -18,24 +18,26 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
|
|||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
if (dim.size() != 4)
|
||||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
|
||||
if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
|
||||
// get inputs
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, a_dim.data()));
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, b_dim.data()));
|
||||
|
||||
// get outputs
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, c_dim.data()));
|
||||
|
||||
// get op descriptor
|
||||
cnnlOpTensorDescriptor_t opDesc;
|
||||
|
|
|
@ -18,45 +18,53 @@ class MatmulCnnl : public BangKernelWithoutConfig {
|
|||
auto dimInputs0 = op->getInputs(0)->getDims();
|
||||
auto dimInputs1 = op->getInputs(1)->getDims();
|
||||
auto dimOutput = op->getOutput()->getDims();
|
||||
int input0_batch_size = 1;
|
||||
int input1_batch_size = 1;
|
||||
int output_batch_size = 1;
|
||||
for (size_t i = 0; i < dimInputs0.size() - 2; ++i) {
|
||||
input0_batch_size *= dimInputs0[i];
|
||||
input1_batch_size *= dimInputs1[i];
|
||||
output_batch_size *= dimOutput[i];
|
||||
}
|
||||
|
||||
bool transA = op->getTransA();
|
||||
bool transB = op->getTransB();
|
||||
|
||||
int inputs0Array[3] = {input0_batch_size,
|
||||
dimInputs0[dimInputs0.size() - 2],
|
||||
dimInputs0[dimInputs0.size() - 1]};
|
||||
int inputs1Array[3] = {input1_batch_size,
|
||||
dimInputs1[dimInputs1.size() - 2],
|
||||
dimInputs1[dimInputs1.size() - 1]};
|
||||
int outputArray[3] = {output_batch_size,
|
||||
dimOutput[dimOutput.size() - 2],
|
||||
dimOutput[dimOutput.size() - 1]};
|
||||
int32_t transA = op->getTransA();
|
||||
int32_t transB = op->getTransB();
|
||||
|
||||
// get inputs
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, 3, inputs0Array));
|
||||
checkCnnlError(
|
||||
cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
|
||||
dimInputs0.size(), dimInputs0.data()));
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
bDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, 3, inputs1Array));
|
||||
checkCnnlError(
|
||||
cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
|
||||
dimInputs1.size(), dimInputs1.data()));
|
||||
|
||||
// get outputs
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, 3, outputArray));
|
||||
checkCnnlError(
|
||||
cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
|
||||
dimOutput.size(), dimOutput.data()));
|
||||
|
||||
cnnlStatus_t stat =
|
||||
cnnlBatchMatMul(context->cnnlHandle(), transA, transB, aDesc, aData,
|
||||
bDesc, bData, cDesc, cData);
|
||||
cnnlMatMulDescriptor_t bmm_desc;
|
||||
cnnlMatMulDescCreate(&bmm_desc);
|
||||
cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSA, &transA,
|
||||
sizeof(int32_t));
|
||||
cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSB, &transB,
|
||||
sizeof(int32_t));
|
||||
|
||||
cnnlMatMulAlgo_t bmm_algo;
|
||||
cnnlMatMulAlgoCreate(&bmm_algo);
|
||||
|
||||
float alpha = 1.0;
|
||||
float beta = 0.0;
|
||||
int count = 0;
|
||||
|
||||
cnnlMatMulHeuristicResult_t desc;
|
||||
cnnlCreateMatMulHeuristicResult(&desc);
|
||||
|
||||
cnnlGetBatchMatMulAlgoHeuristic(context->cnnlHandle(), bmm_desc, aDesc,
|
||||
bDesc, cDesc, NULL, 1, &desc, &count);
|
||||
size_t wsSize;
|
||||
cnnlGetBatchMatMulHeuristicResult(desc, bmm_algo, &wsSize);
|
||||
BangPtr wsData = context->getWorkspace(wsSize);
|
||||
|
||||
cnnlStatus_t stat = cnnlBatchMatMulBCast_v2(
|
||||
context->cnnlHandle(), bmm_desc, bmm_algo, &alpha, aDesc, aData,
|
||||
bDesc, bData, &beta, cDesc, cData, wsData, wsSize);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
|
@ -65,6 +73,9 @@ class MatmulCnnl : public BangKernelWithoutConfig {
|
|||
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(bDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||
checkCnnlError(cnnlMatMulDescDestroy(bmm_desc));
|
||||
checkCnnlError(cnnlMatMulAlgoDestroy(bmm_algo));
|
||||
checkCnnlError(cnnlDestroyMatMulHeuristicResult(desc));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -4,9 +4,8 @@
|
|||
|
||||
namespace infini {
|
||||
class CopyBang : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
void compute(const Operator &op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ReshapeObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
auto inData = op->getInputs(0)->getRawDataPtr<void *>();
|
||||
auto outData = op->getOutputs()[0]->getRawDataPtr<void *>();
|
||||
|
|
|
@ -3,16 +3,16 @@
|
|||
#ifdef BACKWARD_TRACE
|
||||
#include "backward.hpp"
|
||||
|
||||
namespace backtrace = backward;
|
||||
namespace backward_trace = backward;
|
||||
|
||||
// signal handler
|
||||
backtrace::SignalHandling sh;
|
||||
backward_trace::SignalHandling sh;
|
||||
|
||||
namespace infini {
|
||||
Exception::Exception(const std::string &msg) : std::runtime_error(msg) {
|
||||
backtrace::StackTrace st;
|
||||
backward_trace::StackTrace st;
|
||||
st.load_here(32);
|
||||
backtrace::Printer p;
|
||||
backward_trace::Printer p;
|
||||
p.print(st);
|
||||
}
|
||||
}; // namespace infini
|
||||
|
|
Loading…
Reference in New Issue