diff --git a/examples/distributed/bang_launch.py b/examples/distributed/bang_launch.py index 518935b5..18f366e6 100644 --- a/examples/distributed/bang_launch.py +++ b/examples/distributed/bang_launch.py @@ -30,6 +30,12 @@ def parse_args(): action="store_true", help="whether to generate the standard results.", ) + parser.add_argument( + "--run_torch", + default=False, + action="store_true", + help="whether to run model using PyTorch.", + ) args = parser.parse_args() print("arg setting: ", args) return ( @@ -40,6 +46,7 @@ def parse_args(): args.batch_size, args.length, args.gen_std, + args.run_torch ) @@ -153,12 +160,49 @@ def getDiff(base, test): f"Max relative difference: {max_relative_diff}") return max_absolute_diff, max_relative_diff +def run_pytorch(model, n=10): + from onnxsim import simplify + from onnx import version_converter + from onnx2torch import convert + import torch + import torch_mlu + + model, check = simplify(model) + target_version = 16 + converted_model = version_converter.convert_version(model, target_version) + device = torch.device("mlu") + torch_model = convert(converted_model) + torch_model.to(device) + torch_model.eval() + print("Run model using pytorch.") + + runtime = backend.BangRuntime(0) + stub = OnnxStub(model, runtime) + inputs = [] + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = torch.from_numpy(np.load(f"./data/input_{i}.npy")) + input.to(device) + inputs.append(input) + + torch.mlu.synchronize() + begin = time.time() + with torch.no_grad(): + for _ in range(n): + outputs = torch_model(*inputs) + torch.mlu.synchronize() + end = time.time() + + avg_time = (end - begin) / n + print(f"average time: {avg_time}") + return outputs + def main(): - nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args() + nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_torch = parse_args() model = onnx.load(model_path) + # generate standart output if gen_std: print("Generate inputs and outputs.") @@ -190,6 +234,9 @@ def main(): for w in workers: w.join() + + if run_torch: + run_pytorch(model) if __name__ == "__main__":