add onnx simplify (#153)

* add onnx simplify

* fix test bug

* update ci policy

* fix onnx simpilfy bug

* update ci workflow
This commit is contained in:
Haojie Wang 2023-10-10 15:45:27 +08:00 committed by GitHub
parent 7f16fa353e
commit 90b9a80f72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 5 deletions

View File

@ -1,12 +1,11 @@
name: Build and test cpu name: Build and test cpu
on: on:
push: push:
branch: 'master'
paths-ignore: paths-ignore:
- '**.md' - '**.md'
- 'LICENSE' - 'LICENSE'
pull_request: pull_request:
paths-ignore: paths:
- '**.md' - '**.md'
- 'LICENSE' - 'LICENSE'

View File

@ -1,12 +1,11 @@
name: clang-format Check name: clang-format Check
on: on:
push: push:
branch: 'master'
paths-ignore: paths-ignore:
- '**.md' - '**.md'
- 'LICENSE' - 'LICENSE'
pull_request: pull_request:
paths-ignore: paths:
- '**.md' - '**.md'
- 'LICENSE' - 'LICENSE'

View File

@ -12,7 +12,7 @@ requires-python = ">=3.7"
keywords = ["optimizer"] keywords = ["optimizer"]
license = { text = "Apache" } license = { text = "Apache" }
classifiers = ["Programming Language :: Python :: 3"] classifiers = ["Programming Language :: Python :: 3"]
dependencies = ["onnx"] dependencies = ["onnx","onnx-simplifier"]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["src"] where = ["src"]

View File

@ -20,11 +20,14 @@ from onnx.checker import (
check_node, check_node,
check_value_info, check_value_info,
check_tensor, check_tensor,
ValidationError,
) )
from onnx.shape_inference import infer_shapes from onnx.shape_inference import infer_shapes
from onnx.numpy_helper import to_array from onnx.numpy_helper import to_array
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
from functools import reduce from functools import reduce
from onnxsim import simplify
import copy
class OnnxStub: class OnnxStub:
@ -33,6 +36,14 @@ class OnnxStub:
It can be generated from an Onnx model object. It can be generated from an Onnx model object.
""" """
def __init__(self, model: ModelProto, runtime): def __init__(self, model: ModelProto, runtime):
# We use some user-defined operators for distributed inference
try:
# onnx simplifier performs inplace simplify
model_simp, check = simplify(copy.deepcopy(model))
if check:
model = model_simp
except ValidationError:
pass
self.inputs: Dict[str, backend.Tensor] = {} self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {} self.outputs: Dict[str, backend.Tensor] = {}
self.initializer: Dict[int, TensorProto] = {} self.initializer: Dict[int, TensorProto] = {}