forked from jiuyuan/InfiniTensor
feat: 导出 cuda_runtime
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
60c5d6b5b8
commit
9ab78f13f7
|
@ -25,10 +25,10 @@ from onnx.shape_inference import infer_shapes
|
|||
from typing import Dict, List, Any, Tuple, Sequence, Union
|
||||
from functools import reduce
|
||||
|
||||
runtime = backend.cpu_runtime()
|
||||
cpu_runtime = backend.cpu_runtime()
|
||||
|
||||
|
||||
def from_onnx(model: ModelProto) -> backend.GraphHandler:
|
||||
def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
|
||||
model = infer_shapes(model)
|
||||
handler = backend.GraphHandler(runtime)
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include <pybind11/stl.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/operator_timer.h"
|
||||
#endif
|
||||
|
||||
|
@ -94,6 +95,10 @@ static int tensor_dtype(Tensor t) {
|
|||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); }
|
||||
#endif
|
||||
|
||||
static int concat_axis_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Concat);
|
||||
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
|
||||
|
@ -118,6 +123,9 @@ static Shape reshape_shape_of(Operator op) {
|
|||
void export_functions(py::module &m) {
|
||||
#define FUNCTION(NAME) def(#NAME, &NAME)
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||
#ifdef USE_CUDA
|
||||
.FUNCTION(cuda_runtime)
|
||||
#endif
|
||||
.FUNCTION(tensor_dtype)
|
||||
.FUNCTION(reshape_shape_of)
|
||||
.FUNCTION(concat_axis_of)
|
||||
|
@ -132,6 +140,10 @@ void init_graph_builder(py::module &m) {
|
|||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||
m, "CpuRuntime");
|
||||
#ifdef USE_CUDA
|
||||
py::class_<CudaRuntimeObj, std::shared_ptr<CudaRuntimeObj>, RuntimeObj>(
|
||||
m, "CudaRuntime");
|
||||
#endif
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||
.def("shape", &TensorObj::getDims, policy::move)
|
||||
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||
|
|
Loading…
Reference in New Issue