feat: 导入时保存权重

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-24 16:49:53 +08:00
parent a5e692baea
commit 40fb8390b1
3 changed files with 4 additions and 1 deletions

View File

@ -32,7 +32,7 @@ class TensorObj : public TensorBaseObj {
using TensorBaseObj::getData; using TensorBaseObj::getData;
VType getData(const Shape &pos) const; VType getData(const Shape &pos) const;
void dataMalloc(); void dataMalloc();
UidBaseType getFuid() const { return fuid; } inline UidBaseType getFuid() const { return fuid; }
void load(std::string file_path); void load(std::string file_path);
void save(std::string file_path); void save(std::string file_path);

View File

@ -35,6 +35,7 @@ def cuda_runtime():
class OnnxStub: class OnnxStub:
inputs: Dict[str, backend.Tensor] = {} inputs: Dict[str, backend.Tensor] = {}
outputs: Dict[str, backend.Tensor] = {} outputs: Dict[str, backend.Tensor] = {}
initializer: Dict[int, TensorProto] = {}
handler: backend.GraphHandler handler: backend.GraphHandler
def __init__(self, model: ModelProto, runtime): def __init__(self, model: ModelProto, runtime):
@ -348,6 +349,7 @@ class OnnxStub:
if any(input.name == name for input in model.graph.input): if any(input.name == name for input in model.graph.input):
self.inputs[name] = obj self.inputs[name] = obj
else: else:
self.initializer[obj.fuid()] = tensor
if tensor.data_type == TensorProto.INT32: if tensor.data_type == TensorProto.INT32:
self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data]) self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
elif tensor.data_type == TensorProto.INT64: elif tensor.data_type == TensorProto.INT64:

View File

@ -154,6 +154,7 @@ void init_graph_builder(py::module &m) {
m, "CudaRuntime"); m, "CudaRuntime");
#endif #endif
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor") py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
.def("fuid", &TensorObj::getFuid, policy::automatic)
.def("shape", &TensorObj::getDims, policy::move) .def("shape", &TensorObj::getDims, policy::move)
.def("cloneFloats", &TensorObj::cloneFloats, policy::move) .def("cloneFloats", &TensorObj::cloneFloats, policy::move)
.def("has_target", &TensorObj::hasTarget, policy::automatic) .def("has_target", &TensorObj::hasTarget, policy::automatic)