diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index ef7d4de2..f815c037 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -28,6 +28,10 @@ from functools import reduce cpu_runtime = backend.cpu_runtime() +def cuda_runtime(): + return backend.cuda_runtime() + + def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler: model = infer_shapes(model) handler = backend.GraphHandler(runtime)