diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index ffb61529..1f6afe19 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -572,6 +572,24 @@ class OnnxStub: return ctx.build(name) + def init(self) -> None: + self.handler.data_malloc() + + def run(self) -> None: + self.handler.run() + + def put_int32(self, name: str) -> None: + self.handler.copy_int32(self.inputs[name]) + + def put_int64(self, name: str) -> None: + self.handler.copy_int64(self.inputs[name]) + + def put_float(self, name: str) -> None: + self.handler.copy_float(self.inputs[name]) + + def take_float(self) -> List[float]: + return next(self.handler.outputs.values()).copyFloats() + def from_onnx(model: ModelProto, runtime): stub = OnnxStub(model, runtime)