forked from jiuyuan/InfiniTensor
parent
62be816f53
commit
877db21021
|
@ -60,7 +60,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
def shard_concat(node: NodeProto):
|
||||
# hack for kvcache
|
||||
in_plc = place[node.input[1]]
|
||||
if in_plc.is_sharded():
|
||||
if in_plc.is_shard():
|
||||
seq_len_dim = vinfo[node.input[0]].type.tensor_type.shape.dim.pop(1)
|
||||
seq_len_dim.dim_value //= tp_world_size
|
||||
vinfo[node.input[0]].type.tensor_type.shape.dim.insert(1, seq_len_dim)
|
||||
|
|
|
@ -46,6 +46,7 @@ class OnnxStub:
|
|||
dims = [d for d in initializer.dims]
|
||||
tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
|
||||
data[initializer.name] = initializer
|
||||
tensors[initializer.name].set_weight()
|
||||
|
||||
for input in model.graph.input:
|
||||
dims = _take_shape_dim(input.type.tensor_type.shape)
|
||||
|
@ -53,12 +54,14 @@ class OnnxStub:
|
|||
tensors[input.name] = self.handler.tensor(
|
||||
dims, input.type.tensor_type.elem_type
|
||||
)
|
||||
tensors[input.name].set_input()
|
||||
|
||||
for output in model.graph.output:
|
||||
dims = _take_shape_dim(output.type.tensor_type.shape)
|
||||
tensors[output.name] = self.handler.tensor(
|
||||
dims, output.type.tensor_type.elem_type
|
||||
)
|
||||
tensors[output.name].set_output()
|
||||
|
||||
|
||||
node_name = []
|
||||
|
@ -666,19 +669,6 @@ class OnnxStub:
|
|||
# update the node_list
|
||||
node_list = list(set(node_name) - set(new_node_name))
|
||||
|
||||
################################
|
||||
# Set tensor type
|
||||
################################
|
||||
for initializer in model.graph.initializer:
|
||||
tensors[initializer.name].set_weight()
|
||||
|
||||
for input in model.graph.input:
|
||||
tensors[input.name].set_input()
|
||||
|
||||
for output in model.graph.output:
|
||||
tensors[output.name].set_output()
|
||||
|
||||
|
||||
################################
|
||||
# Allocate memory space for data
|
||||
################################
|
||||
|
|
Loading…
Reference in New Issue