Fix support kvcache (#142)

* - fix onnx.py

* - fix shard_concat
This commit is contained in:
kilinchange 2023-09-27 11:08:44 +08:00 committed by GitHub
parent 62be816f53
commit 877db21021
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 14 deletions

View File

@ -60,7 +60,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
def shard_concat(node: NodeProto):
# hack for kvcache
in_plc = place[node.input[1]]
if in_plc.is_sharded():
if in_plc.is_shard():
seq_len_dim = vinfo[node.input[0]].type.tensor_type.shape.dim.pop(1)
seq_len_dim.dim_value //= tp_world_size
vinfo[node.input[0]].type.tensor_type.shape.dim.insert(1, seq_len_dim)

View File

@ -46,6 +46,7 @@ class OnnxStub:
dims = [d for d in initializer.dims]
tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
data[initializer.name] = initializer
tensors[initializer.name].set_weight()
for input in model.graph.input:
dims = _take_shape_dim(input.type.tensor_type.shape)
@ -53,12 +54,14 @@ class OnnxStub:
tensors[input.name] = self.handler.tensor(
dims, input.type.tensor_type.elem_type
)
tensors[input.name].set_input()
for output in model.graph.output:
dims = _take_shape_dim(output.type.tensor_type.shape)
tensors[output.name] = self.handler.tensor(
dims, output.type.tensor_type.elem_type
)
tensors[output.name].set_output()
node_name = []
@ -666,19 +669,6 @@ class OnnxStub:
# update the node_list
node_list = list(set(node_name) - set(new_node_name))
################################
# Set tensor type
################################
for initializer in model.graph.initializer:
tensors[initializer.name].set_weight()
for input in model.graph.input:
tensors[input.name].set_input()
for output in model.graph.output:
tensors[output.name].set_output()
################################
# Allocate memory space for data
################################