diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c595a3b6..1f6b2758 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,12 +1,11 @@ name: Build and test cpu on: push: - branch: 'master' paths-ignore: - '**.md' - 'LICENSE' pull_request: - paths-ignore: + paths: - '**.md' - 'LICENSE' diff --git a/.github/workflows/clang-format-check.yml b/.github/workflows/clang-format-check.yml index dfd15b07..257da404 100644 --- a/.github/workflows/clang-format-check.yml +++ b/.github/workflows/clang-format-check.yml @@ -1,12 +1,11 @@ name: clang-format Check on: push: - branch: 'master' paths-ignore: - '**.md' - 'LICENSE' pull_request: - paths-ignore: + paths: - '**.md' - 'LICENSE' diff --git a/pyinfinitensor/pyproject.toml b/pyinfinitensor/pyproject.toml index 8c438812..5b09e91e 100644 --- a/pyinfinitensor/pyproject.toml +++ b/pyinfinitensor/pyproject.toml @@ -12,7 +12,7 @@ requires-python = ">=3.7" keywords = ["optimizer"] license = { text = "Apache" } classifiers = ["Programming Language :: Python :: 3"] -dependencies = ["onnx"] +dependencies = ["onnx","onnx-simplifier"] [tool.setuptools.packages.find] where = ["src"] diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index af1e1f95..103af8e4 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -20,11 +20,14 @@ from onnx.checker import ( check_node, check_value_info, check_tensor, + ValidationError, ) from onnx.shape_inference import infer_shapes from onnx.numpy_helper import to_array from typing import Dict, List, Any, Tuple, Sequence, Union, Optional from functools import reduce +from onnxsim import simplify +import copy class OnnxStub: @@ -33,6 +36,14 @@ class OnnxStub: It can be generated from an Onnx model object. """ def __init__(self, model: ModelProto, runtime): + # We use some user-defined operators for distributed inference + try: + # onnx simplifier performs inplace simplify + model_simp, check = simplify(copy.deepcopy(model)) + if check: + model = model_simp + except ValidationError: + pass self.inputs: Dict[str, backend.Tensor] = {} self.outputs: Dict[str, backend.Tensor] = {} self.initializer: Dict[int, TensorProto] = {}