From 40fb8390b12d4d77b55345c85175017ee2775182 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 24 Feb 2023 16:49:53 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AF=BC=E5=85=A5=E6=97=B6=E4=BF=9D?= =?UTF-8?q?=E5=AD=98=E6=9D=83=E9=87=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- include/core/tensor.h | 2 +- pyinfinitensor/src/pyinfinitensor/onnx.py | 2 ++ src/ffi/ffi_infinitensor.cc | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/include/core/tensor.h b/include/core/tensor.h index 1edc950a..c5823be0 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -32,7 +32,7 @@ class TensorObj : public TensorBaseObj { using TensorBaseObj::getData; VType getData(const Shape &pos) const; void dataMalloc(); - UidBaseType getFuid() const { return fuid; } + inline UidBaseType getFuid() const { return fuid; } void load(std::string file_path); void save(std::string file_path); diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 28b8f514..d554883c 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -35,6 +35,7 @@ def cuda_runtime(): class OnnxStub: inputs: Dict[str, backend.Tensor] = {} outputs: Dict[str, backend.Tensor] = {} + initializer: Dict[int, TensorProto] = {} handler: backend.GraphHandler def __init__(self, model: ModelProto, runtime): @@ -348,6 +349,7 @@ class OnnxStub: if any(input.name == name for input in model.graph.input): self.inputs[name] = obj else: + self.initializer[obj.fuid()] = tensor if tensor.data_type == TensorProto.INT32: self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data]) elif tensor.data_type == TensorProto.INT64: diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 87599f28..d7230f42 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -154,6 +154,7 @@ void init_graph_builder(py::module &m) { m, "CudaRuntime"); #endif py::class_>(m, "Tensor") + .def("fuid", &TensorObj::getFuid, policy::automatic) .def("shape", &TensorObj::getDims, policy::move) .def("cloneFloats", &TensorObj::cloneFloats, policy::move) .def("has_target", &TensorObj::hasTarget, policy::automatic)