diff --git a/examples/distributed/cuda_launch.py b/examples/distributed/cuda_launch.py
index 0f48598a..887a34ee 100644
--- a/examples/distributed/cuda_launch.py
+++ b/examples/distributed/cuda_launch.py
@@ -10,9 +10,6 @@
 import numpy as np
 from parallel_opt import parallel_model

-os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
-
-
 def parse_args():
     parser = argparse.ArgumentParser(description="launch distributed infinitensor")
     parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
@@ -32,6 +29,9 @@ def parse_args():
         action="store_true",
         help="whether to generate the standard results.",
     )
+    parser.add_argument(
+        "--type", type=str, choices=["fp32", "fp16", "tf32"], default="fp32", help="data type"
+    )
     args = parser.parse_args()
     print("arg setting: ", args)
     return (
@@ -42,11 +42,12 @@ def parse_args():
         args.batch_size,
         args.length,
         args.gen_std,
+        args.type,
     )


-def run_model(model, runtime, inputs, n=10):
-    stub = OnnxStub(model, runtime)
+def run_model(model, runtime, inputs, n=10, data_type = "default"):
+    stub = OnnxStub(model, runtime, matmul_compute_type=data_type)
     for tensor, input in zip(stub.inputs.values(), inputs, strict=False):
         tensor.copyin_numpy(input)
     # stub.tune()
@@ -66,17 +67,17 @@ def run_model(model, runtime, inputs, n=10):
     return outputs


-def run_and_compare(name, model, runtime):
+def run_and_compare(name, model, runtime, data_type):
     input_ids = np.load(f"{name}_inputs.npy")
     position_ids = np.arange(input_ids.shape[-1])
     results = np.load(f"{name}_results.npy")
-    outputs = run_model(model, runtime, (input_ids, position_ids))
+    outputs = run_model(model, runtime, (input_ids, position_ids), data_type=data_type)
     print("outputs abs mean:", abs(outputs).mean())
     print("max abs diff:", abs(outputs - results).max())


 def start_worker(
-    name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
+    name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, data_type: str
 ):
     dist_name = name + "_dist"
     model = parallel_model(model, world_size, rank)
@@ -97,12 +98,12 @@ def start_worker(
         world_size,
         rank,
     )
-    run_and_compare(name, model, runtime)
+    run_and_compare(name, model, runtime, data_type)


-def start_single(name, model):
+def start_single(name, model, data_type):
     runtime = backend.CudaRuntime(0)
-    run_and_compare(name, model, runtime)
+    run_and_compare(name, model, runtime, data_type)


 def gen_standard(name, model, voc_size, bs, len):
@@ -117,8 +118,10 @@ def gen_standard(name, model, voc_size, bs, len):


 def main():
-    nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
-
+    nnodes, nproc_per_node, name, model_path, bs, length, gen_std, data_type = parse_args()
+    data_type = "default" if data_type == "fp32" else data_type
+    if data_type != "tf32":
+        os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
     model = onnx.load(model_path)

     # generate standart output
@@ -132,7 +135,7 @@ def main():
     # run single process.
     # use standalone process to isolate cuda.
     print("run model by single GPU.")
-    p = mp.Process(target=start_single, args=(name, model))
+    p = mp.Process(target=start_single, args=(name, model, data_type))
     p.start()
     p.join()

@@ -142,7 +145,7 @@ def main():
     workers = [
         mp.Process(
             target=start_worker,
-            args=(name, world_size, rank, rank % nproc_per_node, model),
+            args=(name, world_size, rank, rank % nproc_per_node, model, data_type),
         )
         for rank in range(world_size)
     ]
diff --git a/examples/distributed/run_pytorch.py b/examples/distributed/run_pytorch.py
index ea2af234..77a8e620 100644
--- a/examples/distributed/run_pytorch.py
+++ b/examples/distributed/run_pytorch.py
@@ -10,8 +10,6 @@ import os
 from onnx.external_data_helper import convert_model_to_external_data
 from onnxsim import simplify

-torch.backends.cuda.matmul.allow_tf32 = False
-torch.backends.cudnn.allow_tf32 = False
 def parse_args():
     parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
     parser.add_argument(
@@ -27,14 +25,17 @@ def parse_args():
         const="./",
         help="whether and where to export onnx file",
     )
-    args = parser.parse_args()
+    parser.add_argument(
+        "--type", type=str, choices=["fp32", "fp16", "tf32"], default="fp32", help="data type"
+    )
+    args = parser.parse_args()
     print("arg setting: ", args)
     return (
         args.model,
         args.batch_size,
         args.length,
-        args.export_onnx
+        args.export_onnx,
+        args.type,
     )
@@ -81,7 +82,7 @@ def run_pytorch(torch_model, voc_size, batchsize, len):
     print("outputs abs mean:", abs(np.array(outputs)).mean())
     print(f"average time: {avg_time}")
     torch.cuda.memory.empty_cache()
-    np.save("test_results", np.array(outputs))
+    np.save("test_results", np.array(outputs, dtype=np.float32))
     print("Save input & output as test_inputs.npy and test_results.npy")


@@ -164,7 +165,14 @@ def add_value_info_for_constants(model : onnx.ModelProto):


 def main():
-    modelname, batchsize, seqlen, export_path = parse_args()
+    torch.backends.cuda.matmul.allow_tf32 = False
+    torch.backends.cudnn.allow_tf32 = False
+    modelname, batchsize, seqlen, export_path, data_type = parse_args()
+    if data_type == "tf32":
+        torch.backends.cuda.matmul.allow_tf32 = True
+    else:
+        os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
+
     model, voc_size = get_model(modelname)
     if export_path is not None:
         filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
@@ -172,6 +180,8 @@ def main():
         param = torch.zeros((batchsize, seqlen), dtype=torch.int)
         export_onnx(model, param, path, True)

+    if data_type == "fp16":
+        model = model.half()
     run_pytorch(model, voc_size, batchsize, seqlen)

 if __name__ == "__main__":