diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 207280ef..ef7d4de2 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -25,10 +25,10 @@ from onnx.shape_inference import infer_shapes
 from typing import Dict, List, Any, Tuple, Sequence, Union
 from functools import reduce
 
-runtime = backend.cpu_runtime()
+cpu_runtime = backend.cpu_runtime()
 
 
-def from_onnx(model: ModelProto) -> backend.GraphHandler:
+def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
     model = infer_shapes(model)
     handler = backend.GraphHandler(runtime)
 
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index b791e66e..5d6085ea 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -6,6 +6,7 @@
 #include <pybind11/stl.h>
 
 #ifdef USE_CUDA
+#include "cuda/cuda_runtime.h"
 #include "cuda/operator_timer.h"
 #endif
 
@@ -94,6 +95,10 @@ static int tensor_dtype(Tensor t) {
     IT_ASSERT(false, "Unsupported data type");
 }
 
+#ifdef USE_CUDA
+static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); }
+#endif
+
 static int concat_axis_of(Operator op) {
     IT_ASSERT(op->getOpType() == OpType::Concat);
     return dynamic_cast<const ConcatObj *>(op.get())->getDim();
@@ -118,6 +123,9 @@ static Shape reshape_shape_of(Operator op) {
 void export_functions(py::module &m) {
 #define FUNCTION(NAME) def(#NAME, &NAME)
     m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
+#ifdef USE_CUDA
+        .FUNCTION(cuda_runtime)
+#endif
         .FUNCTION(tensor_dtype)
         .FUNCTION(reshape_shape_of)
         .FUNCTION(concat_axis_of)
@@ -132,6 +140,10 @@ void init_graph_builder(py::module &m) {
     py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
         m, "CpuRuntime");
+#ifdef USE_CUDA
+    py::class_<CudaRuntimeObj, std::shared_ptr<CudaRuntimeObj>, RuntimeObj>(
+        m, "CudaRuntime");
+#endif
     py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
        .def("shape", &TensorObj::getDims, policy::move)
        .def("src", &TensorObj::getOutputOf, policy::move);
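For context, this patch decouples ONNX import from a fixed CPU runtime: `from_onnx` now takes the target runtime as an argument, and a `cuda_runtime()` factory plus a `CudaRuntime` class are exposed to Python when the extension is built with `USE_CUDA`. Below is a minimal usage sketch; the import paths, the `hasattr` guard, and the `model.onnx` filename are illustrative assumptions, not part of the patch.

```python
import onnx
import backend  # pybind11 extension built from ffi_infinitensor.cc (assumed import path)
from pyinfinitensor.onnx import from_onnx, cpu_runtime

model = onnx.load("model.onnx")  # hypothetical model file

# The former module-level singleton survives under its new name, cpu_runtime.
cpu_handler = from_onnx(model, cpu_runtime)

# In a USE_CUDA build, the new factory constructs a CudaRuntime; the same
# entry point accepts it because from_onnx no longer hard-codes a runtime.
if hasattr(backend, "cuda_runtime"):
    cuda_handler = from_onnx(model, backend.cuda_runtime())
```

Registering `CudaRuntimeObj` with `RuntimeObj` as its pybind11 base is what lets the single `from_onnx(model, runtime)` signature accept either runtime polymorphically.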