forked from jiuyuan/InfiniTensor
parent 62be816f53
commit 877db21021
@@ -60,7 +60,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
     def shard_concat(node: NodeProto):
         # hack for kvcache
         in_plc = place[node.input[1]]
-        if in_plc.is_sharded():
+        if in_plc.is_shard():
            seq_len_dim = vinfo[node.input[0]].type.tensor_type.shape.dim.pop(1)
            seq_len_dim.dim_value //= tp_world_size
            vinfo[node.input[0]].type.tensor_type.shape.dim.insert(1, seq_len_dim)
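The kvcache hunk above shrinks the sequence-length dimension of the Concat input's value info by the tensor-parallel world size. A minimal sketch of that dim manipulation using ONNX's Python helpers; the tensor name and shape here are made up for illustration:

    from onnx import TensorProto, helper

    tp_world_size = 2
    # A [batch, seq_len, hidden] value info, standing in for the Concat input.
    vi = helper.make_tensor_value_info("kv_demo", TensorProto.FLOAT, [1, 8, 64])

    # Same pattern as shard_concat: pop dim 1, shrink it, re-insert it.
    seq_len_dim = vi.type.tensor_type.shape.dim.pop(1)
    seq_len_dim.dim_value //= tp_world_size
    vi.type.tensor_type.shape.dim.insert(1, seq_len_dim)

    print([d.dim_value for d in vi.type.tensor_type.shape.dim])  # [1, 4, 64]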
@@ -46,6 +46,7 @@ class OnnxStub:
             dims = [d for d in initializer.dims]
             tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
             data[initializer.name] = initializer
+            tensors[initializer.name].set_weight()

         for input in model.graph.input:
             dims = _take_shape_dim(input.type.tensor_type.shape)
@@ -53,12 +54,14 @@ class OnnxStub:
             tensors[input.name] = self.handler.tensor(
                 dims, input.type.tensor_type.elem_type
             )
+            tensors[input.name].set_input()

         for output in model.graph.output:
             dims = _take_shape_dim(output.type.tensor_type.shape)
             tensors[output.name] = self.handler.tensor(
                 dims, output.type.tensor_type.elem_type
             )
+            tensors[output.name].set_output()


         node_name = []
@@ -666,19 +669,6 @@ class OnnxStub:
         # update the node_list
         node_list = list(set(node_name) - set(new_node_name))

-        ################################
-        # Set tensor type
-        ################################
-        for initializer in model.graph.initializer:
-            tensors[initializer.name].set_weight()
-
-        for input in model.graph.input:
-            tensors[input.name].set_input()
-
-        for output in model.graph.output:
-            tensors[output.name].set_output()
-
-
         ################################
         # Allocate memory space for data
         ################################
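Taken together, the remaining hunks move tensor-role tagging to the point of construction: each tensor is marked with set_weight(), set_input(), or set_output() as soon as it is created, which makes the standalone "Set tensor type" pass redundant, so the last hunk deletes it. A rough sketch of the shape of that refactor, using a stand-in class rather than InfiniTensor's real handler API:

    class DemoTensor:
        """Stand-in for the handler's tensor object; only tracks a role tag."""
        def __init__(self):
            self.role = None

        def set_weight(self):
            self.role = "weight"

    weights = ["w0", "w1"]          # stands in for model.graph.initializer
    tensors = {}
    for name in weights:
        tensors[name] = DemoTensor()
        tensors[name].set_weight()  # tagged at creation; no late pass over
                                    # the graph's initializers is needed

Tagging at creation keeps the role next to the code that builds the tensor and avoids a second walk over the graph's initializers, inputs, and outputs.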