diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py index b16386a7..42465a69 100644 --- a/examples/distributed/parallel_opt.py +++ b/examples/distributed/parallel_opt.py @@ -60,7 +60,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): def shard_concat(node: NodeProto): # hack for kvcache in_plc = place[node.input[1]] - if in_plc.is_sharded(): + if in_plc.is_shard(): seq_len_dim = vinfo[node.input[0]].type.tensor_type.shape.dim.pop(1) seq_len_dim.dim_value //= tp_world_size vinfo[node.input[0]].type.tensor_type.shape.dim.insert(1, seq_len_dim) diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 813a5e8e..121ad22e 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -46,6 +46,7 @@ class OnnxStub: dims = [d for d in initializer.dims] tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type) data[initializer.name] = initializer + tensors[initializer.name].set_weight() for input in model.graph.input: dims = _take_shape_dim(input.type.tensor_type.shape) @@ -53,12 +54,14 @@ class OnnxStub: tensors[input.name] = self.handler.tensor( dims, input.type.tensor_type.elem_type ) + tensors[input.name].set_input() for output in model.graph.output: dims = _take_shape_dim(output.type.tensor_type.shape) tensors[output.name] = self.handler.tensor( dims, output.type.tensor_type.elem_type ) + tensors[output.name].set_output() node_name = [] @@ -666,19 +669,6 @@ class OnnxStub: # update the node_list node_list = list(set(node_name) - set(new_node_name)) - ################################ - # Set tensor type - ################################ - for initializer in model.graph.initializer: - tensors[initializer.name].set_weight() - - for input in model.graph.input: - tensors[input.name].set_input() - - for output in model.graph.output: - tensors[output.name].set_output() - - ################################ # Allocate memory space for data ################################