Compare commits

...

1 commit

4 changed files with 50 additions and 38 deletions

View File

@@ -80,16 +80,18 @@ def start_worker(
     name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, data_type: str
 ):
     dist_name = name + "_dist"
-    model = parallel_model(model, world_size, rank)
-    extern_path = f"./{dist_name}_rank{rank}.pb"
-    if os.path.exists(extern_path):
-        os.remove(extern_path)
-    onnx.save_model(
-        model,
-        f"./{dist_name}_rank{rank}.onnx",
-        save_as_external_data=True,
-        location=extern_path,
-    )
+    #model = parallel_model(model, world_size, rank)
+    graph = parallel_model(model, world_size, rank)
+
+    # extern_path = f"./{dist_name}_rank{rank}.pb"
+    # if os.path.exists(extern_path):
+    #     os.remove(extern_path)
+    # onnx.save_model(
+    #     model,
+    #     f"./{dist_name}_rank{rank}.onnx",
+    #     save_as_external_data=True,
+    #     location=extern_path,
+    # )
     #infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
     runtime = backend.CudaRuntime(local_rank)
     # print("init comm")
@@ -98,7 +100,8 @@ def start_worker(
         world_size,
         rank,
     )
-    run_and_compare(name, model, runtime, data_type)
+    #run_and_compare(name, model, runtime, data_type)
+    run_and_compare(name, graph, runtime, data_type)


 def start_single(name, model, data_type):
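
Net effect of these two hunks: the worker no longer serializes the sliced model to an external-data .onnx file and reloads it; the GraphProto returned by parallel_model goes straight to the runtime. A minimal sketch of the new in-memory flow, relying on the OnnxStub changes later in this diff; the import paths and helper name here are assumptions, not taken from the diff:

    import onnx
    # Assumed module paths -- adjust to the actual package layout.
    from pyinfinitensor.onnx import OnnxStub
    from pyinfinitensor import backend
    from parallel_opt import parallel_model  # hypothetical module name

    def build_rank_stub(model: onnx.ModelProto, world_size: int, rank: int, local_rank: int):
        # parallel_model (changed in the next file) now returns a GraphProto.
        graph = parallel_model(model, world_size, rank)
        runtime = backend.CudaRuntime(local_rank)
        # The in-memory graph is handed directly to the stub -- no
        # save_model / load round-trip through external-data files.
        return OnnxStub(graph, runtime)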

View File

@@ -243,6 +243,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         tt = output.type.tensor_type
         if tt.HasField("shape"):
             tt.ClearField("shape")
-    model = helper.make_model(graph)
+    #model = helper.make_model(graph)
     #model = onnx.shape_inference.infer_shapes(model)
-    return model
+    #return model
+    return graph
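
Since parallel_model now returns the bare GraphProto, a caller that still needs a full ModelProto (for example to save the sliced model to disk, as the commented-out code in the test above used to do) can wrap it explicitly. A hedged sketch; the output filename is illustrative:

    import onnx
    from onnx import helper

    graph = parallel_model(model, tp_world_size=2, tp_rank=0)  # GraphProto
    # Wrap only at the point where serialization is actually needed.
    sliced_model = helper.make_model(graph)
    onnx.save_model(sliced_model, "model_dist_rank0.onnx")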

View File

@@ -57,7 +57,7 @@ def get_model(modelname):
     return model, voc_size

 def run_pytorch(torch_model, voc_size, batchsize, len):
-    data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
+    data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int64)
     np.save("test_inputs", data)
     inputs = torch.from_numpy(data).to("cuda")
     torch_model = torch_model.to("cuda")
@@ -177,7 +177,7 @@ def main():
     if export_path is not None:
         filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
         path = os.path.join(export_path, filename)
-        param = torch.zeros((batchsize, seqlen), dtype=torch.int)
+        param = torch.zeros((batchsize, seqlen), dtype=torch.int64)
         export_onnx(model, param, path, True)

     if data_type == "fp16":
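
Both changes widen the integer token inputs from 32 to 64 bits, so the saved NumPy test data and the dummy tensor traced during ONNX export carry the same dtype (the exporter bakes the dummy input's dtype into the exported graph). A small sketch of the invariant, with illustrative sizes:

    import numpy as np
    import torch

    batchsize, seqlen, voc_size = 1, 512, 32000  # illustrative values
    data = np.random.randint(0, voc_size, (batchsize, seqlen), dtype=np.int64)
    param = torch.zeros((batchsize, seqlen), dtype=torch.int64)
    # Both sides now agree on 64-bit indices.
    assert torch.from_numpy(data).dtype == param.dtype == torch.int64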

View File

@@ -6,6 +6,7 @@ from onnx import (
     AttributeProto,
     TensorShapeProto,
     ValueInfoProto,
+    GraphProto,
 )
 from onnx.helper import (
     make_node,
@@ -40,22 +41,27 @@ class OnnxStub:
     def __init__(
         self,
-        model: ModelProto,
+        model: ModelProto | GraphProto,
         runtime,
         use_naive_allocator: bool = False,
         matmul_compute_type: str = "default",
     ):
-        # We use some user-defined operators for distributed inference
-        try:
-            # onnx simplifier performs inplace simplify
-            model_simp, check = simplify(copy.deepcopy(model))
-            if check:
-                model = model_simp
-        except ValidationError:
-            pass
-        except RuntimeError:
-            pass
+        if isinstance(model, ModelProto):
+            # We use some user-defined operators for distributed inference
+            try:
+                # onnx simplifier performs inplace simplify
+                model_simp, check = simplify(copy.deepcopy(model))
+                if check:
+                    model = model_simp
+            except ValidationError:
+                pass
+            except RuntimeError:
+                pass
+            graph = model.graph
+        elif isinstance(model, GraphProto):
+            graph = model
+        else:
+            raise RuntimeError("Expect onnx ModelProto or GraphProto as input.")

         self.inputs: Dict[str, backend.Tensor] = {}
         self.outputs: Dict[str, backend.Tensor] = {}
         self.tensors: Dict[str, backend.Tensor] = {}
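
With the widened signature, OnnxStub accepts either proto; the onnx-simplifier pass runs only on the ModelProto branch, since simplify() needs a full model. Worth noting: the `ModelProto | GraphProto` annotation is evaluated when the method is defined, so it needs Python 3.10+ unless `from __future__ import annotations` is in effect. A usage sketch, reusing the runtime and parallel_model names from the files above:

    import onnx

    # ModelProto path: simplified first, then parsed via model.graph.
    stub_a = OnnxStub(onnx.load("model.onnx"), runtime)
    # GraphProto path: e.g. the output of parallel_model(), used as-is.
    stub_b = OnnxStub(parallel_model(model, 2, 0), runtime)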
@@ -67,10 +73,12 @@ class OnnxStub:
         # except:
         #     warnings.warn("infer_shapes failed.")
         self.handler = backend.GraphHandler(runtime)
+        self.init_graph(graph, matmul_compute_type)

+    def init_graph(self, graph: GraphProto, matmul_compute_type: str):
         # Handle duplicate and anonymous operator names
         names = {}
-        for node in model.graph.node:
+        for node in graph.node:
             if node.name == "":
                 node.name = "missing_name(" + node.op_type + ")"
             if node.name in names:
@@ -80,11 +88,11 @@ class OnnxStub:
             names[node.name] = 0
         # Topological sort
         sorted_nodes = []
-        known_edge = set(t.name for t in model.graph.input)
-        known_edge.update(t.name for t in model.graph.initializer)
-        while len(sorted_nodes) < len(model.graph.node):
+        known_edge = set(t.name for t in graph.input)
+        known_edge.update(t.name for t in graph.initializer)
+        while len(sorted_nodes) < len(graph.node):
             updated = False
-            for i, node in enumerate(model.graph.node):
+            for i, node in enumerate(graph.node):
                 if all(t in known_edge for t in node.input):
                     node.name = str(len(sorted_nodes)) + "_" + node.name
                     sorted_nodes.append(i)
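
For reference, the loop above is a Kahn-style topological sort: a node is emitted once every one of its input names is a known edge, with the known set seeded from the graph inputs and initializers. A standalone sketch of the same idea (simplified, assuming an acyclic graph; names here are illustrative):

    def toposort(nodes, seed_names):
        # seed_names: names of graph inputs and initializers.
        known, order = set(seed_names), []
        remaining = list(range(len(nodes)))
        while remaining:
            progressed = False
            for i in list(remaining):
                if all(t in known for t in nodes[i].input):
                    order.append(i)
                    known.update(nodes[i].output)
                    remaining.remove(i)
                    progressed = True
            if not progressed:
                raise ValueError("cycle or dangling input in graph")
        return order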
@@ -98,13 +106,13 @@ class OnnxStub:
         tensors: Dict[str, backend.Tensor] = dict()
         data: Dict[str, TensorProto] = dict()

-        for initializer in model.graph.initializer:
+        for initializer in graph.initializer:
             dims = [d for d in initializer.dims]
             tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
             data[initializer.name] = initializer
             tensors[initializer.name].set_weight()

-        for input in model.graph.input:
+        for input in graph.input:
             dims = _take_shape_dim(input.type.tensor_type.shape)
             if input.name not in tensors.keys():
                 tensors[input.name] = self.handler.tensor(
@@ -114,7 +122,7 @@ class OnnxStub:
         for node_idx in sorted_nodes:
-            node = model.graph.node[node_idx]
+            node = graph.node[node_idx]
             if node.op_type == "Conv":
                 attributes = _parse_attribute(
                     node,
@@ -994,7 +1002,7 @@ class OnnxStub:
             else:
                 raise Exception('Unsupported operator "{}"'.format(node.op_type))

-        for output in model.graph.output:
+        for output in graph.output:
             tensors[output.name].set_output()

         ################################
         # Allocate memory space for data
@@ -1007,7 +1015,7 @@ class OnnxStub:
         for name, obj in tensors.items():
             tensor = data.get(name)
             if tensor == None:
-                if any(input.name == name for input in model.graph.input):
+                if any(input.name == name for input in graph.input):
                     self.inputs[name] = obj
             else:
                 self.initializer[obj.fuid()] = tensor
@@ -1033,7 +1041,7 @@ class OnnxStub:
         for name, obj in tensors.items():
             self.tensors[name] = obj

-        for output in model.graph.output:
+        for output in graph.output:
             self.outputs[output.name] = tensors[output.name]

     def to_onnx(self, name: str) -> ModelProto: