forked from jiuyuan/InfiniTensor
add onnx simplify (#153)
* add onnx simplify * fix test bug * update ci policy * fix onnx simpilfy bug * update ci workflow
This commit is contained in:
parent
7f16fa353e
commit
90b9a80f72
|
@ -1,12 +1,11 @@
|
|||
name: Build and test cpu
|
||||
on:
|
||||
push:
|
||||
branch: 'master'
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'LICENSE'
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
paths:
|
||||
- '**.md'
|
||||
- 'LICENSE'
|
||||
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
name: clang-format Check
|
||||
on:
|
||||
push:
|
||||
branch: 'master'
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'LICENSE'
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
paths:
|
||||
- '**.md'
|
||||
- 'LICENSE'
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ requires-python = ">=3.7"
|
|||
keywords = ["optimizer"]
|
||||
license = { text = "Apache" }
|
||||
classifiers = ["Programming Language :: Python :: 3"]
|
||||
dependencies = ["onnx"]
|
||||
dependencies = ["onnx","onnx-simplifier"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
|
|
@ -20,11 +20,14 @@ from onnx.checker import (
|
|||
check_node,
|
||||
check_value_info,
|
||||
check_tensor,
|
||||
ValidationError,
|
||||
)
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from onnx.numpy_helper import to_array
|
||||
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
|
||||
from functools import reduce
|
||||
from onnxsim import simplify
|
||||
import copy
|
||||
|
||||
|
||||
class OnnxStub:
|
||||
|
@ -33,6 +36,14 @@ class OnnxStub:
|
|||
It can be generated from an Onnx model object.
|
||||
"""
|
||||
def __init__(self, model: ModelProto, runtime):
|
||||
# We use some user-defined operators for distributed inference
|
||||
try:
|
||||
# onnx simplifier performs inplace simplify
|
||||
model_simp, check = simplify(copy.deepcopy(model))
|
||||
if check:
|
||||
model = model_simp
|
||||
except ValidationError:
|
||||
pass
|
||||
self.inputs: Dict[str, backend.Tensor] = {}
|
||||
self.outputs: Dict[str, backend.Tensor] = {}
|
||||
self.initializer: Dict[int, TensorProto] = {}
|
||||
|
|
Loading…
Reference in New Issue