forked from jiuyuan/InfiniTensor
feat: 导入时保存权重
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
a5e692baea
commit
40fb8390b1
|
@ -32,7 +32,7 @@ class TensorObj : public TensorBaseObj {
|
||||||
using TensorBaseObj::getData;
|
using TensorBaseObj::getData;
|
||||||
VType getData(const Shape &pos) const;
|
VType getData(const Shape &pos) const;
|
||||||
void dataMalloc();
|
void dataMalloc();
|
||||||
UidBaseType getFuid() const { return fuid; }
|
inline UidBaseType getFuid() const { return fuid; }
|
||||||
|
|
||||||
void load(std::string file_path);
|
void load(std::string file_path);
|
||||||
void save(std::string file_path);
|
void save(std::string file_path);
|
||||||
|
|
|
@ -35,6 +35,7 @@ def cuda_runtime():
|
||||||
class OnnxStub:
|
class OnnxStub:
|
||||||
inputs: Dict[str, backend.Tensor] = {}
|
inputs: Dict[str, backend.Tensor] = {}
|
||||||
outputs: Dict[str, backend.Tensor] = {}
|
outputs: Dict[str, backend.Tensor] = {}
|
||||||
|
initializer: Dict[int, TensorProto] = {}
|
||||||
handler: backend.GraphHandler
|
handler: backend.GraphHandler
|
||||||
|
|
||||||
def __init__(self, model: ModelProto, runtime):
|
def __init__(self, model: ModelProto, runtime):
|
||||||
|
@ -348,6 +349,7 @@ class OnnxStub:
|
||||||
if any(input.name == name for input in model.graph.input):
|
if any(input.name == name for input in model.graph.input):
|
||||||
self.inputs[name] = obj
|
self.inputs[name] = obj
|
||||||
else:
|
else:
|
||||||
|
self.initializer[obj.fuid()] = tensor
|
||||||
if tensor.data_type == TensorProto.INT32:
|
if tensor.data_type == TensorProto.INT32:
|
||||||
self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
|
self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
|
||||||
elif tensor.data_type == TensorProto.INT64:
|
elif tensor.data_type == TensorProto.INT64:
|
||||||
|
|
|
@ -154,6 +154,7 @@ void init_graph_builder(py::module &m) {
|
||||||
m, "CudaRuntime");
|
m, "CudaRuntime");
|
||||||
#endif
|
#endif
|
||||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||||
|
.def("fuid", &TensorObj::getFuid, policy::automatic)
|
||||||
.def("shape", &TensorObj::getDims, policy::move)
|
.def("shape", &TensorObj::getDims, policy::move)
|
||||||
.def("cloneFloats", &TensorObj::cloneFloats, policy::move)
|
.def("cloneFloats", &TensorObj::cloneFloats, policy::move)
|
||||||
.def("has_target", &TensorObj::hasTarget, policy::automatic)
|
.def("has_target", &TensorObj::hasTarget, policy::automatic)
|
||||||
|
|
Loading…
Reference in New Issue