forked from jiuyuan/InfiniTensor
add onnx simplify (#153)
* add onnx simplify * fix test bug * update ci policy * fix onnx simpilfy bug * update ci workflow
This commit is contained in:
parent
7f16fa353e
commit
90b9a80f72
|
@ -1,12 +1,11 @@
|
||||||
name: Build and test cpu
|
name: Build and test cpu
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branch: 'master'
|
|
||||||
paths-ignore:
|
paths-ignore:
|
||||||
- '**.md'
|
- '**.md'
|
||||||
- 'LICENSE'
|
- 'LICENSE'
|
||||||
pull_request:
|
pull_request:
|
||||||
paths-ignore:
|
paths:
|
||||||
- '**.md'
|
- '**.md'
|
||||||
- 'LICENSE'
|
- 'LICENSE'
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,11 @@
|
||||||
name: clang-format Check
|
name: clang-format Check
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branch: 'master'
|
|
||||||
paths-ignore:
|
paths-ignore:
|
||||||
- '**.md'
|
- '**.md'
|
||||||
- 'LICENSE'
|
- 'LICENSE'
|
||||||
pull_request:
|
pull_request:
|
||||||
paths-ignore:
|
paths:
|
||||||
- '**.md'
|
- '**.md'
|
||||||
- 'LICENSE'
|
- 'LICENSE'
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ requires-python = ">=3.7"
|
||||||
keywords = ["optimizer"]
|
keywords = ["optimizer"]
|
||||||
license = { text = "Apache" }
|
license = { text = "Apache" }
|
||||||
classifiers = ["Programming Language :: Python :: 3"]
|
classifiers = ["Programming Language :: Python :: 3"]
|
||||||
dependencies = ["onnx"]
|
dependencies = ["onnx","onnx-simplifier"]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|
|
@ -20,11 +20,14 @@ from onnx.checker import (
|
||||||
check_node,
|
check_node,
|
||||||
check_value_info,
|
check_value_info,
|
||||||
check_tensor,
|
check_tensor,
|
||||||
|
ValidationError,
|
||||||
)
|
)
|
||||||
from onnx.shape_inference import infer_shapes
|
from onnx.shape_inference import infer_shapes
|
||||||
from onnx.numpy_helper import to_array
|
from onnx.numpy_helper import to_array
|
||||||
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
|
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
from onnxsim import simplify
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
class OnnxStub:
|
class OnnxStub:
|
||||||
|
@ -33,6 +36,14 @@ class OnnxStub:
|
||||||
It can be generated from an Onnx model object.
|
It can be generated from an Onnx model object.
|
||||||
"""
|
"""
|
||||||
def __init__(self, model: ModelProto, runtime):
|
def __init__(self, model: ModelProto, runtime):
|
||||||
|
# We use some user-defined operators for distributed inference
|
||||||
|
try:
|
||||||
|
# onnx simplifier performs inplace simplify
|
||||||
|
model_simp, check = simplify(copy.deepcopy(model))
|
||||||
|
if check:
|
||||||
|
model = model_simp
|
||||||
|
except ValidationError:
|
||||||
|
pass
|
||||||
self.inputs: Dict[str, backend.Tensor] = {}
|
self.inputs: Dict[str, backend.Tensor] = {}
|
||||||
self.outputs: Dict[str, backend.Tensor] = {}
|
self.outputs: Dict[str, backend.Tensor] = {}
|
||||||
self.initializer: Dict[int, TensorProto] = {}
|
self.initializer: Dict[int, TensorProto] = {}
|
||||||
|
|
Loading…
Reference in New Issue