forked from jiuyuan/InfiniTensor
feat: 导入时保存权重
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
a5e692baea
commit
40fb8390b1
|
@ -32,7 +32,7 @@ class TensorObj : public TensorBaseObj {
|
|||
using TensorBaseObj::getData;
|
||||
VType getData(const Shape &pos) const;
|
||||
void dataMalloc();
|
||||
UidBaseType getFuid() const { return fuid; }
|
||||
inline UidBaseType getFuid() const { return fuid; }
|
||||
|
||||
void load(std::string file_path);
|
||||
void save(std::string file_path);
|
||||
|
|
|
@ -35,6 +35,7 @@ def cuda_runtime():
|
|||
class OnnxStub:
|
||||
inputs: Dict[str, backend.Tensor] = {}
|
||||
outputs: Dict[str, backend.Tensor] = {}
|
||||
initializer: Dict[int, TensorProto] = {}
|
||||
handler: backend.GraphHandler
|
||||
|
||||
def __init__(self, model: ModelProto, runtime):
|
||||
|
@ -348,6 +349,7 @@ class OnnxStub:
|
|||
if any(input.name == name for input in model.graph.input):
|
||||
self.inputs[name] = obj
|
||||
else:
|
||||
self.initializer[obj.fuid()] = tensor
|
||||
if tensor.data_type == TensorProto.INT32:
|
||||
self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
|
||||
elif tensor.data_type == TensorProto.INT64:
|
||||
|
|
|
@ -154,6 +154,7 @@ void init_graph_builder(py::module &m) {
|
|||
m, "CudaRuntime");
|
||||
#endif
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||
.def("fuid", &TensorObj::getFuid, policy::automatic)
|
||||
.def("shape", &TensorObj::getDims, policy::move)
|
||||
.def("cloneFloats", &TensorObj::cloneFloats, policy::move)
|
||||
.def("has_target", &TensorObj::hasTarget, policy::automatic)
|
||||
|
|
Loading…
Reference in New Issue