fix(front): 将stub改为可以接收GraphProto作为输入,消除分布式脚本保存额外的onnx文件, 采用int64作为index输入类型

This commit is contained in:
panzezhong 2024-04-07 17:15:40 +08:00
parent d1de3ab5c2
commit 6c4dd7b28b
4 changed files with 50 additions and 38 deletions

View File

@ -80,16 +80,18 @@ def start_worker(
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, data_type: str
):
dist_name = name + "_dist"
model = parallel_model(model, world_size, rank)
extern_path = f"./{dist_name}_rank{rank}.pb"
if os.path.exists(extern_path):
os.remove(extern_path)
onnx.save_model(
model,
f"./{dist_name}_rank{rank}.onnx",
save_as_external_data=True,
location=extern_path,
)
#model = parallel_model(model, world_size, rank)
graph = parallel_model(model, world_size, rank)
# extern_path = f"./{dist_name}_rank{rank}.pb"
# if os.path.exists(extern_path):
# os.remove(extern_path)
# onnx.save_model(
# model,
# f"./{dist_name}_rank{rank}.onnx",
# save_as_external_data=True,
# location=extern_path,
# )
#infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
runtime = backend.CudaRuntime(local_rank)
# print("init comm")
@ -98,7 +100,8 @@ def start_worker(
world_size,
rank,
)
run_and_compare(name, model, runtime, data_type)
#run_and_compare(name, model, runtime, data_type)
run_and_compare(name, graph, runtime, data_type)
def start_single(name, model, data_type):

View File

@ -243,6 +243,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
tt = output.type.tensor_type
if tt.HasField("shape"):
tt.ClearField("shape")
model = helper.make_model(graph)
#model = helper.make_model(graph)
#model = onnx.shape_inference.infer_shapes(model)
return model
#return model
return graph

View File

@ -57,7 +57,7 @@ def get_model(modelname):
return model, voc_size
def run_pytorch(torch_model, voc_size, batchsize, len):
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int64)
np.save("test_inputs", data)
inputs = torch.from_numpy(data).to("cuda")
torch_model = torch_model.to("cuda")
@ -177,7 +177,7 @@ def main():
if export_path is not None:
filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
path = os.path.join(export_path, filename)
param = torch.zeros((batchsize, seqlen), dtype=torch.int)
param = torch.zeros((batchsize, seqlen), dtype=torch.int64)
export_onnx(model, param, path, True)
if data_type == "fp16":

View File

@ -6,6 +6,7 @@ from onnx import (
AttributeProto,
TensorShapeProto,
ValueInfoProto,
GraphProto,
)
from onnx.helper import (
make_node,
@ -40,22 +41,27 @@ class OnnxStub:
def __init__(
self,
model: ModelProto,
model: ModelProto | GraphProto,
runtime,
use_naive_allocator: bool = False,
matmul_compute_type: str = "default",
):
# We use some user-defined operators for distributed inference
try:
# onnx simplifier performs inplace simplify
model_simp, check = simplify(copy.deepcopy(model))
if check:
model = model_simp
except ValidationError:
pass
except RuntimeError:
pass
if isinstance(model, ModelProto):
# We use some user-defined operators for distributed inference
try:
# onnx simplifier performs inplace simplify
model_simp, check = simplify(copy.deepcopy(model))
if check:
model = model_simp
except ValidationError:
pass
except RuntimeError:
pass
graph = model.graph
elif isinstance(model, GraphProto):
graph = model
else:
raise RuntimeError("Expect onnx ModelProto or GraphProto as input.")
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
self.tensors: Dict[str, backend.Tensor] = {}
@ -67,10 +73,12 @@ class OnnxStub:
# except:
# warnings.warn("infer_shapes failed.")
self.handler = backend.GraphHandler(runtime)
self.init_graph(graph, matmul_compute_type)
def init_graph(self, graph: GraphProto, matmul_compute_type: str):
# 处理重名和匿名算子
names = {}
for node in model.graph.node:
for node in graph.node:
if node.name == "":
node.name = "missing_name(" + node.op_type + ")"
if node.name in names:
@ -80,11 +88,11 @@ class OnnxStub:
names[node.name] = 0
# 拓扑排序
sorted_nodes = []
known_edge = set(t.name for t in model.graph.input)
known_edge.update(t.name for t in model.graph.initializer)
while len(sorted_nodes) < len(model.graph.node):
known_edge = set(t.name for t in graph.input)
known_edge.update(t.name for t in graph.initializer)
while len(sorted_nodes) < len(graph.node):
updated = False
for i, node in enumerate(model.graph.node):
for i, node in enumerate(graph.node):
if all(t in known_edge for t in node.input):
node.name = str(len(sorted_nodes)) + "_" + node.name
sorted_nodes.append(i)
@ -98,13 +106,13 @@ class OnnxStub:
tensors: Dict[str, backend.Tensor] = dict()
data: Dict[str, TensorProto] = dict()
for initializer in model.graph.initializer:
for initializer in graph.initializer:
dims = [d for d in initializer.dims]
tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
data[initializer.name] = initializer
tensors[initializer.name].set_weight()
for input in model.graph.input:
for input in graph.input:
dims = _take_shape_dim(input.type.tensor_type.shape)
if input.name not in tensors.keys():
tensors[input.name] = self.handler.tensor(
@ -114,7 +122,7 @@ class OnnxStub:
for node_idx in sorted_nodes:
node = model.graph.node[node_idx]
node = graph.node[node_idx]
if node.op_type == "Conv":
attributes = _parse_attribute(
node,
@ -994,7 +1002,7 @@ class OnnxStub:
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
for output in model.graph.output:
for output in graph.output:
tensors[output.name].set_output()
################################
# Allocate memory space for data
@ -1007,7 +1015,7 @@ class OnnxStub:
for name, obj in tensors.items():
tensor = data.get(name)
if tensor == None:
if any(input.name == name for input in model.graph.input):
if any(input.name == name for input in graph.input):
self.inputs[name] = obj
else:
self.initializer[obj.fuid()] = tensor
@ -1033,7 +1041,7 @@ class OnnxStub:
for name, obj in tensors.items():
self.tensors[name] = obj
for output in model.graph.output:
for output in graph.output:
self.outputs[output.name] = tensors[output.name]
def to_onnx(self, name: str) -> ModelProto: