# Forked from jiuyuan/InfiniTensor (104 lines, 3.5 KiB, Python).
from typing import Any, Dict, Optional

import onnx
from onnx import (
    ModelProto,
    TensorProto,
    NodeProto,
    AttributeProto,
)
from onnx import helper, numpy_helper
|
|
|
|
|
|
def parse_attribute(
    node: NodeProto, attrs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Overlay the attributes set on ``node`` onto a dict of defaults.

    Only attributes whose name is already a key of ``attrs`` are read; any
    other attribute present on the node is ignored.

    Args:
        node: ONNX node whose ``attribute`` list is scanned.
        attrs: Mapping from attribute name to default value. It is updated
            in place and returned. When omitted, a fresh empty dict is
            created for the call (so no attribute is picked up).

    Returns:
        ``attrs``, with each default replaced by the value stored on the
        node for INT, INTS, FLOAT, STRING and TENSOR attributes. Values are
        taken from the protobuf fields as-is, without conversion.

    Raises:
        AssertionError: If a requested attribute has any other type.
    """
    # A ``dict()`` default would be created once and shared by every call
    # that omits ``attrs``; build a new one per call instead.
    if attrs is None:
        attrs = {}

    # Protobuf field holding the payload for each supported attribute type.
    field_by_type = {
        AttributeProto.INT: "i",
        AttributeProto.INTS: "ints",
        AttributeProto.FLOAT: "f",
        AttributeProto.STRING: "s",
        AttributeProto.TENSOR: "t",
    }

    for attr in node.attribute:
        if attr.name not in attrs:
            continue
        field = field_by_type.get(attr.type)
        if field is None:
            # Raised explicitly (rather than ``assert False``) so the check
            # is not stripped under ``python -O``; the exception type is
            # kept for callers that already handle it.
            raise AssertionError(
                "Unsupported Attribute Type: {}".format(attr.type)
            )
        attrs[attr.name] = getattr(attr, field)
    return attrs
|
|
|
|
|
|
def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
    """Build the tensor-parallel copy of ``model`` owned by rank ``tp_rank``.

    For every ``Gemm`` node, the operands found among the graph initializers
    are replaced by this rank's slice, the node output is renamed with a
    ``:sharded(<dim>)`` suffix, and an ``AllGather`` node followed by a
    ``Concat`` node is inserted right after it so that the original output
    name is produced again. All other nodes are passed through unchanged.

    Args:
        model: Source model. It is not modified; ``Gemm`` nodes are copied
            before being rewritten, so the same model can be passed again
            for another rank.
        tp_world_size: Number of ranks each sharded axis is split across.
        tp_rank: Index of the rank whose slice is kept.

    Returns:
        A new ``ModelProto`` whose graph is named ``<name>_<tp_rank>``, with
        the source model's opset imports and the result of ONNX shape
        inference applied.

    Raises:
        AssertionError: If the axis being sharded is not divisible by
            ``tp_world_size``.
        KeyError: If a ``Gemm`` node has a third input that is not a graph
            initializer.
    """
    data = {init.name: init for init in model.graph.initializer}

    def shard_tensor(tensor: TensorProto, dim: int):
        """Return this rank's slice of ``tensor`` along axis ``dim``."""
        array = numpy_helper.to_array(tensor)
        # Clamp to the last axis, e.g. a 1-D bias asked to split on axis 1.
        if dim >= array.ndim:
            dim = array.ndim - 1
        if array.shape[dim] % tp_world_size != 0:
            # Explicit raise so the check survives ``python -O``.
            raise AssertionError(
                f"cannot split axis {dim} of {tensor.name!r} (size "
                f"{array.shape[dim]}) into {tp_world_size} equal parts"
            )
        seg = array.shape[dim] // tp_world_size
        # Slice along ``dim`` itself; a plain ``array[a:b]`` would always
        # cut axis 0 regardless of which axis was validated above.
        index = [slice(None)] * array.ndim
        index[dim] = slice(tp_rank * seg, (tp_rank + 1) * seg)
        return numpy_helper.from_array(
            array[tuple(index)], name=tensor.name + f":sharded({dim})"
        )

    def shard_gemm(node: NodeProto):
        """Shard the initializer operands of ``node`` in place.

        Returns the output axis along which the result is sharded.
        """
        attrs = parse_attribute(
            node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
        )
        trans = [attrs["transA"], attrs["transB"]]
        dim = 0
        for i, (name, t) in enumerate(zip(node.input, trans)):
            if name in data:
                dim = i
                # A transposed operand stores the sharded axis on the
                # opposite dimension, hence the XOR with the trans flag.
                sharded = shard_tensor(data[name], dim ^ t)
                node.input[i] = sharded.name
                data[name] = sharded
        if len(node.input) > 2:
            name = node.input[2]
            sharded = shard_tensor(data[name], dim)
            node.input[2] = sharded.name
            data[name] = sharded

        node.output[0] += f":sharded({dim})"
        return dim

    # Build a fresh node list instead of inserting into a list while it is
    # being enumerated.
    nodes = []
    for original in model.graph.node:
        if original.op_type != "Gemm":
            nodes.append(original)
            continue

        # Rewrite a copy so the caller's model keeps its original names.
        node = NodeProto()
        node.CopyFrom(original)
        nodes.append(node)

        output = node.output[0]
        dim = shard_gemm(node)
        gathered = [node.output[0] + f".{rank}" for rank in range(tp_world_size)]
        # all_gather
        nodes.append(
            helper.make_node(
                op_type="AllGather",
                inputs=[node.output[0]],
                outputs=gathered,
                name=node.name + "/allgather",
                # domain="infini",  # shape inference fails for custom domain
            )
        )
        # concat
        nodes.append(
            helper.make_node(
                op_type="Concat",
                inputs=gathered,
                outputs=[output],
                name=node.name + "/concat",
                axis=dim,
            )
        )

    graph = helper.make_graph(
        nodes,
        model.graph.name + f"_{tp_rank}",
        model.graph.input,
        model.graph.output,
        data.values(),
        doc_string=model.graph.doc_string,
        value_info=model.graph.value_info,
    )
    # Keep the opsets the nodes were authored against; without this
    # make_model stamps the graph with the installed default opset only.
    opset_imports = list(model.opset_import)
    sharded_model = helper.make_model(graph, opset_imports=opset_imports or None)

    # infer_shapes returns a new model rather than updating its argument.
    return onnx.shape_inference.infer_shapes(sharded_model)
|