add onnx simplify (#153)

* add onnx simplify

* fix test bug

* update ci policy

* fix onnx simpilfy bug

* update ci workflow
This commit is contained in:
Haojie Wang 2023-10-10 15:45:27 +08:00 committed by GitHub
parent 7f16fa353e
commit 90b9a80f72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 5 deletions

View File

@ -1,12 +1,11 @@
name: Build and test cpu
on:
push:
branch: 'master'
paths-ignore:
- '**.md'
- 'LICENSE'
pull_request:
paths-ignore:
paths:
- '**.md'
- 'LICENSE'

View File

@ -1,12 +1,11 @@
name: clang-format Check
on:
push:
branch: 'master'
paths-ignore:
- '**.md'
- 'LICENSE'
pull_request:
paths-ignore:
paths:
- '**.md'
- 'LICENSE'

View File

@ -12,7 +12,7 @@ requires-python = ">=3.7"
keywords = ["optimizer"]
license = { text = "Apache" }
classifiers = ["Programming Language :: Python :: 3"]
dependencies = ["onnx"]
dependencies = ["onnx","onnx-simplifier"]
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -20,11 +20,14 @@ from onnx.checker import (
check_node,
check_value_info,
check_tensor,
ValidationError,
)
from onnx.shape_inference import infer_shapes
from onnx.numpy_helper import to_array
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
from functools import reduce
from onnxsim import simplify
import copy
class OnnxStub:
@ -33,6 +36,14 @@ class OnnxStub:
It can be generated from an Onnx model object.
"""
def __init__(self, model: ModelProto, runtime):
# We use some user-defined operators for distributed inference
try:
# onnx simplifier performs inplace simplify
model_simp, check = simplify(copy.deepcopy(model))
if check:
model = model_simp
except ValidationError:
pass
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
self.initializer: Dict[int, TensorProto] = {}