forked from jiuyuan/InfiniTensor
Compare commits
1 Commits
master
...
dist/graph
Author | SHA1 | Date |
---|---|---|
panzezhong | 6c4dd7b28b |
|
@ -80,16 +80,18 @@ def start_worker(
|
|||
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, data_type: str
|
||||
):
|
||||
dist_name = name + "_dist"
|
||||
model = parallel_model(model, world_size, rank)
|
||||
extern_path = f"./{dist_name}_rank{rank}.pb"
|
||||
if os.path.exists(extern_path):
|
||||
os.remove(extern_path)
|
||||
onnx.save_model(
|
||||
model,
|
||||
f"./{dist_name}_rank{rank}.onnx",
|
||||
save_as_external_data=True,
|
||||
location=extern_path,
|
||||
)
|
||||
#model = parallel_model(model, world_size, rank)
|
||||
graph = parallel_model(model, world_size, rank)
|
||||
|
||||
# extern_path = f"./{dist_name}_rank{rank}.pb"
|
||||
# if os.path.exists(extern_path):
|
||||
# os.remove(extern_path)
|
||||
# onnx.save_model(
|
||||
# model,
|
||||
# f"./{dist_name}_rank{rank}.onnx",
|
||||
# save_as_external_data=True,
|
||||
# location=extern_path,
|
||||
# )
|
||||
#infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||
runtime = backend.CudaRuntime(local_rank)
|
||||
# print("init comm")
|
||||
|
@ -98,7 +100,8 @@ def start_worker(
|
|||
world_size,
|
||||
rank,
|
||||
)
|
||||
run_and_compare(name, model, runtime, data_type)
|
||||
#run_and_compare(name, model, runtime, data_type)
|
||||
run_and_compare(name, graph, runtime, data_type)
|
||||
|
||||
|
||||
def start_single(name, model, data_type):
|
||||
|
|
|
@ -243,6 +243,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
tt = output.type.tensor_type
|
||||
if tt.HasField("shape"):
|
||||
tt.ClearField("shape")
|
||||
model = helper.make_model(graph)
|
||||
#model = helper.make_model(graph)
|
||||
#model = onnx.shape_inference.infer_shapes(model)
|
||||
return model
|
||||
#return model
|
||||
return graph
|
||||
|
|
|
@ -57,7 +57,7 @@ def get_model(modelname):
|
|||
return model, voc_size
|
||||
|
||||
def run_pytorch(torch_model, voc_size, batchsize, len):
|
||||
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
|
||||
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int64)
|
||||
np.save("test_inputs", data)
|
||||
inputs = torch.from_numpy(data).to("cuda")
|
||||
torch_model = torch_model.to("cuda")
|
||||
|
@ -177,7 +177,7 @@ def main():
|
|||
if export_path is not None:
|
||||
filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
|
||||
path = os.path.join(export_path, filename)
|
||||
param = torch.zeros((batchsize, seqlen), dtype=torch.int)
|
||||
param = torch.zeros((batchsize, seqlen), dtype=torch.int64)
|
||||
export_onnx(model, param, path, True)
|
||||
|
||||
if data_type == "fp16":
|
||||
|
|
|
@ -6,6 +6,7 @@ from onnx import (
|
|||
AttributeProto,
|
||||
TensorShapeProto,
|
||||
ValueInfoProto,
|
||||
GraphProto,
|
||||
)
|
||||
from onnx.helper import (
|
||||
make_node,
|
||||
|
@ -40,22 +41,27 @@ class OnnxStub:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model: ModelProto,
|
||||
model: ModelProto | GraphProto,
|
||||
runtime,
|
||||
use_naive_allocator: bool = False,
|
||||
matmul_compute_type: str = "default",
|
||||
):
|
||||
# We use some user-defined operators for distributed inference
|
||||
try:
|
||||
# onnx simplifier performs inplace simplify
|
||||
model_simp, check = simplify(copy.deepcopy(model))
|
||||
if check:
|
||||
model = model_simp
|
||||
except ValidationError:
|
||||
pass
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
if isinstance(model, ModelProto):
|
||||
# We use some user-defined operators for distributed inference
|
||||
try:
|
||||
# onnx simplifier performs inplace simplify
|
||||
model_simp, check = simplify(copy.deepcopy(model))
|
||||
if check:
|
||||
model = model_simp
|
||||
except ValidationError:
|
||||
pass
|
||||
except RuntimeError:
|
||||
pass
|
||||
graph = model.graph
|
||||
elif isinstance(model, GraphProto):
|
||||
graph = model
|
||||
else:
|
||||
raise RuntimeError("Expect onnx ModelProto or GraphProto as input.")
|
||||
self.inputs: Dict[str, backend.Tensor] = {}
|
||||
self.outputs: Dict[str, backend.Tensor] = {}
|
||||
self.tensors: Dict[str, backend.Tensor] = {}
|
||||
|
@ -67,10 +73,12 @@ class OnnxStub:
|
|||
# except:
|
||||
# warnings.warn("infer_shapes failed.")
|
||||
self.handler = backend.GraphHandler(runtime)
|
||||
self.init_graph(graph, matmul_compute_type)
|
||||
|
||||
def init_graph(self, graph: GraphProto, matmul_compute_type: str):
|
||||
# 处理重名和匿名算子
|
||||
names = {}
|
||||
for node in model.graph.node:
|
||||
for node in graph.node:
|
||||
if node.name == "":
|
||||
node.name = "missing_name(" + node.op_type + ")"
|
||||
if node.name in names:
|
||||
|
@ -80,11 +88,11 @@ class OnnxStub:
|
|||
names[node.name] = 0
|
||||
# 拓扑排序
|
||||
sorted_nodes = []
|
||||
known_edge = set(t.name for t in model.graph.input)
|
||||
known_edge.update(t.name for t in model.graph.initializer)
|
||||
while len(sorted_nodes) < len(model.graph.node):
|
||||
known_edge = set(t.name for t in graph.input)
|
||||
known_edge.update(t.name for t in graph.initializer)
|
||||
while len(sorted_nodes) < len(graph.node):
|
||||
updated = False
|
||||
for i, node in enumerate(model.graph.node):
|
||||
for i, node in enumerate(graph.node):
|
||||
if all(t in known_edge for t in node.input):
|
||||
node.name = str(len(sorted_nodes)) + "_" + node.name
|
||||
sorted_nodes.append(i)
|
||||
|
@ -98,13 +106,13 @@ class OnnxStub:
|
|||
tensors: Dict[str, backend.Tensor] = dict()
|
||||
data: Dict[str, TensorProto] = dict()
|
||||
|
||||
for initializer in model.graph.initializer:
|
||||
for initializer in graph.initializer:
|
||||
dims = [d for d in initializer.dims]
|
||||
tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
|
||||
data[initializer.name] = initializer
|
||||
tensors[initializer.name].set_weight()
|
||||
|
||||
for input in model.graph.input:
|
||||
for input in graph.input:
|
||||
dims = _take_shape_dim(input.type.tensor_type.shape)
|
||||
if input.name not in tensors.keys():
|
||||
tensors[input.name] = self.handler.tensor(
|
||||
|
@ -114,7 +122,7 @@ class OnnxStub:
|
|||
|
||||
|
||||
for node_idx in sorted_nodes:
|
||||
node = model.graph.node[node_idx]
|
||||
node = graph.node[node_idx]
|
||||
if node.op_type == "Conv":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
|
@ -994,7 +1002,7 @@ class OnnxStub:
|
|||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
for output in model.graph.output:
|
||||
for output in graph.output:
|
||||
tensors[output.name].set_output()
|
||||
################################
|
||||
# Allocate memory space for data
|
||||
|
@ -1007,7 +1015,7 @@ class OnnxStub:
|
|||
for name, obj in tensors.items():
|
||||
tensor = data.get(name)
|
||||
if tensor == None:
|
||||
if any(input.name == name for input in model.graph.input):
|
||||
if any(input.name == name for input in graph.input):
|
||||
self.inputs[name] = obj
|
||||
else:
|
||||
self.initializer[obj.fuid()] = tensor
|
||||
|
@ -1033,7 +1041,7 @@ class OnnxStub:
|
|||
for name, obj in tensors.items():
|
||||
self.tensors[name] = obj
|
||||
|
||||
for output in model.graph.output:
|
||||
for output in graph.output:
|
||||
self.outputs[output.name] = tensors[output.name]
|
||||
|
||||
def to_onnx(self, name: str) -> ModelProto:
|
||||
|
|
Loading…
Reference in New Issue