feat: 导出 cuda_runtime

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-23 16:29:20 +08:00
parent 60c5d6b5b8
commit 9ab78f13f7
2 changed files with 14 additions and 2 deletions

View File

@ -25,10 +25,10 @@ from onnx.shape_inference import infer_shapes
from typing import Dict, List, Any, Tuple, Sequence, Union
from functools import reduce
runtime = backend.cpu_runtime()
cpu_runtime = backend.cpu_runtime()
def from_onnx(model: ModelProto) -> backend.GraphHandler:
def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
model = infer_shapes(model)
handler = backend.GraphHandler(runtime)

View File

@ -6,6 +6,7 @@
#include <pybind11/stl.h>
#ifdef USE_CUDA
#include "cuda/cuda_runtime.h"
#include "cuda/operator_timer.h"
#endif
@ -94,6 +95,10 @@ static int tensor_dtype(Tensor t) {
IT_ASSERT(false, "Unsupported data type");
}
#ifdef USE_CUDA
static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); }
#endif
static int concat_axis_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Concat);
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
@ -118,6 +123,9 @@ static Shape reshape_shape_of(Operator op) {
void export_functions(py::module &m) {
#define FUNCTION(NAME) def(#NAME, &NAME)
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
#ifdef USE_CUDA
.FUNCTION(cuda_runtime)
#endif
.FUNCTION(tensor_dtype)
.FUNCTION(reshape_shape_of)
.FUNCTION(concat_axis_of)
@ -132,6 +140,10 @@ void init_graph_builder(py::module &m) {
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
m, "CpuRuntime");
#ifdef USE_CUDA
py::class_<CudaRuntimeObj, std::shared_ptr<CudaRuntimeObj>, RuntimeObj>(
m, "CudaRuntime");
#endif
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
.def("shape", &TensorObj::getDims, policy::move)
.def("src", &TensorObj::getOutputOf, policy::move);