From 985d0dee5f90cf1e90c49067bd8155ff714fcacc Mon Sep 17 00:00:00 2001
From: zhangyue <138768300+zhangyue207@users.noreply.github.com>
Date: Tue, 23 Apr 2024 15:46:25 +0800
Subject: [PATCH] Kunlun dist op (#225)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* kunlun dist inference fix

* kunlun distributed

* add Kunlun distributed scripts and fix issues encountered when running llama

* set -j8

* format

* move run_pytorch.py into cuda/

* update notes

---------

Co-authored-by: weijie01
Co-authored-by: wanghailu
Co-authored-by: Haojie Wang
---
 examples/distributed/__init__.py              |   0
 .../distributed/{ => bang}/bang_launch.py     |   0
 .../distributed/{ => cuda}/cuda_launch.py     |   0
 .../distributed/{ => cuda}/launch_kvcache.py  |   0
 .../distributed/{ => cuda}/run_pytorch.py     |   0
 examples/distributed/kunlun/export_onnx.sh    |  14 +
 examples/distributed/kunlun/kunlun_launch.py  | 280 ++++++++++++++++++
 examples/distributed/kunlun/launch.sh         |  36 +++
 examples/distributed/kunlun/llama_launch.sh   |  35 +++
 examples/distributed/kunlun/run_pytorch.py    | 245 +++++++++++++++
 examples/distributed/launch_kunlun.py         | 213 -------------
 examples/distributed/parallel_opt.py          |   1 -
 include/kunlun/kunlun_runtime.h               |   4 +-
 include/kunlun/xccl_communicator.h            |   4 +-
 pyinfinitensor/src/pyinfinitensor/onnx.py     |   2 +-
 src/kernels/kunlun/element_wise.cc            |  25 +-
 16 files changed, 622 insertions(+), 237 deletions(-)
 create mode 100644 examples/distributed/__init__.py
 rename examples/distributed/{ => bang}/bang_launch.py (100%)
 rename examples/distributed/{ => cuda}/cuda_launch.py (100%)
 rename examples/distributed/{ => cuda}/launch_kvcache.py (100%)
 rename examples/distributed/{ => cuda}/run_pytorch.py (100%)
 create mode 100644 examples/distributed/kunlun/export_onnx.sh
 create mode 100644 examples/distributed/kunlun/kunlun_launch.py
 create mode 100644 examples/distributed/kunlun/launch.sh
 create mode 100644 examples/distributed/kunlun/llama_launch.sh
 create mode 100644 examples/distributed/kunlun/run_pytorch.py
 delete mode 100644 examples/distributed/launch_kunlun.py

diff --git a/examples/distributed/__init__.py b/examples/distributed/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/distributed/bang_launch.py b/examples/distributed/bang/bang_launch.py
similarity index 100%
rename from examples/distributed/bang_launch.py
rename to examples/distributed/bang/bang_launch.py
diff --git a/examples/distributed/cuda_launch.py b/examples/distributed/cuda/cuda_launch.py
similarity index 100%
rename from examples/distributed/cuda_launch.py
rename to examples/distributed/cuda/cuda_launch.py
diff --git a/examples/distributed/launch_kvcache.py b/examples/distributed/cuda/launch_kvcache.py
similarity index 100%
rename from examples/distributed/launch_kvcache.py
rename to examples/distributed/cuda/launch_kvcache.py
diff --git a/examples/distributed/run_pytorch.py b/examples/distributed/cuda/run_pytorch.py
similarity index 100%
rename from examples/distributed/run_pytorch.py
rename to examples/distributed/cuda/run_pytorch.py
diff --git a/examples/distributed/kunlun/export_onnx.sh b/examples/distributed/kunlun/export_onnx.sh
new file mode 100644
index 00000000..aa104c55
--- /dev/null
+++ b/examples/distributed/kunlun/export_onnx.sh
@@ -0,0 +1,14 @@
+export HF_ENDPOINT=https://hf-mirror.com
+
+models=("bert" "gpt2" "llama")
+batch_size=(1 32)
+seq_len=(100 500)
+nproc=(1 2 4)
+
+for model in "${models[@]}"; do
+    for bs in "${batch_size[@]}"; do
+        for len in "${seq_len[@]}"; do
+            python run_pytorch.py --model "$model" --batch_size "$bs" 
--length "$len" --export_onnx ../models/"$model" --export_only + done + done +done diff --git a/examples/distributed/kunlun/kunlun_launch.py b/examples/distributed/kunlun/kunlun_launch.py new file mode 100644 index 00000000..9126fa19 --- /dev/null +++ b/examples/distributed/kunlun/kunlun_launch.py @@ -0,0 +1,280 @@ +import sys +sys.path.append('../') + +import argparse +import os +import time +import multiprocessing as mp +from pyinfinitensor.onnx import OnnxStub, backend +import onnx +from onnx.external_data_helper import convert_model_to_external_data +from onnx.shape_inference import infer_shapes_path +import numpy as np +from parallel_opt import parallel_model +from functools import wraps + + +def parse_args(): + parser = argparse.ArgumentParser(description="launch distributed infinitensor") + parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes") + parser.add_argument( + "--nproc_per_node", type=int, default=2, help="number of processes per node" + ) + parser.add_argument( + "--name", type=str, choices=["gpt2", "bert", "llama"], help="name of model." + ) + parser.add_argument( + "--model", type=str, default="", help="path to the ONNX model file." + ) + parser.add_argument( + "--gen_std", + default=False, + action="store_true", + help="whether to generate the standard results.", + ) + parser.add_argument( + "--run_single", + default=False, + action="store_true", + help="whether run model with single process with standard inputs" + ) + parser.add_argument( + "--input_dir", + default="./", + help="path to save model input data" + ) + parser.add_argument( + "--result_dir", + default="./", + help="path to save model standard output" + ) + parser.add_argument( + "--internal_model_dir", + default="./", + help="path to save internal onnx model for parallel run" + ) + args = parser.parse_args() + + # check path, mkdir if not exist + check_exists(args.input_dir) + check_exists(args.result_dir) + check_exists(args.internal_model_dir) + + print("arg setting: ", args) + return ( + args.num_nodes, + args.nproc_per_node, + args.name, + args.model, + args.gen_std, + args.run_single, + args.input_dir, + args.result_dir, + args.internal_model_dir + ) + + +""" +utils function for this scripts +""" +def check_exists(path: str): + if not os.path.exists(path): + os.makedirs(path) + +def np_assert(base, test, rtol=1e-2, atol=1e-1): + # np.testing.assert_allclose(test, base, rtol, atol) + print("max abs diff:", abs(base - test).max()) + + +""" +Perf wrapper, run function n times +then average +""" +def perf_it(n): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + # warmup + for _ in range(n): + func(*args, **kwargs) + + t_total = 0 + for _ in range(n): + t0 = time.time() + func(*args, **kwargs) + t1 = time.time() + t_total += t1 - t0 + avg_time = (t_total) / n + print(f"Avg runtime of {n} time is {avg_time:.6f} seconds") + return avg_time + return wrapper + return decorator + + +""" +Run InfiniTensor model with Standard input +check=True: check with standard output gen by pytorch +perf=True: run n times to get avg time +""" +def run_model(task_name, + model, + runtime, + world_size=1, + rank=0, + n=10, + check=True, + perf=True): + + stub = OnnxStub(model, runtime, + use_naive_allocator=True \ + if task_name == "llama" else False) + + # load in Onnx model inputs + def load_inputs(stub: OnnxStub): + # check exists + inputs = [] + for i, (name, tensor) in enumerate(stub.inputs.items()): + input_path = os.path.join(input_dir, \ + f"{task_name}_input_{i}.npy") + 
print(input_path)
+            if os.path.exists(input_path):
+                input = np.load(input_path)
+            else:
+                raise KeyError(f"the {i}-th input of the model does not exist")
+            # check shape
+            if all(x == y for x, y in zip(input.shape, tensor.shape())):
+                tensor.copyin_numpy(input)
+            else:
+                tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
+
+    load_inputs(stub)
+    # stub.tune()
+    stub.run()
+    time.sleep(0.01)
+    output = next(stub.outputs.values().__iter__()).copyout_numpy()
+
+    # check output results against the standard output
+    if check:
+        st_output_path = os.path.join(result_dir, \
+                            f"{task_name}_output.npy")
+        assert os.path.exists(st_output_path), \
+            "standard output does not exist"
+        st_output = np.load(st_output_path)
+        if np.isnan(output).any():
+            print("Nan in output")
+            exit()
+        np_assert(st_output, output)
+
+    # perf
+    if perf:
+        @perf_it(n)
+        def perf_infinitensor(stub: OnnxStub):
+            stub.run()
+        perf_infinitensor(stub)
+
+    return output
+
+
+"""
+Start one worker of the parallel run
+"""
+def start_worker(name: str,
+                 world_size: int,
+                 rank: int,
+                 local_rank: int,
+                 model: onnx.ModelProto):
+
+    dist_name = name + "_dist"
+    # partition the ONNX model into world_size parts
+    model = parallel_model(model, world_size, rank)
+    onnx.save(model, os.path.join(internal_model_dir, \
+        f"{dist_name}_rank{rank}.onnx"), save_as_external_data=True)
+    runtime = backend.KUNLUNRuntime(local_rank)
+    # print("init comm")
+    runtime.init_comm(
+        dist_name,
+        world_size,
+        rank,
+    )
+    run_model(name, model, runtime, world_size, rank)
+
+
+"""
+Generate standard inputs/outputs with
+a single-card run
+"""
+def gen_standard(task_name: str, model: onnx.ModelProto):
+    runtime = backend.KUNLUNRuntime(0)
+    stub = OnnxStub(model, runtime)
+    position_id = 0
+    # generate random inputs for the model
+    for i, (name, tensor) in enumerate(stub.inputs.items()):
+        input = tensor.copyout_numpy()
+        if np.issubdtype(input.dtype, np.integer):
+            if input.size == 1:
+                input = np.random.randint(0, 2, size=input.shape, dtype=input.dtype)
+            else:
+                input = np.random.randint(0, 2, size=input.shape, dtype=input.dtype)
+        elif input.dtype == np.bool_:
+            input = np.random.randint(0, 2, size=input.shape) > 0
+        else:
+            if i == 0:
+                input = np.ones(input.shape).astype(input.dtype)
+                position_id = input.shape[-1] - 1
+            else:
+                input = np.random.rand(*input.shape).astype(input.dtype)
+        tensor.copyin_numpy(input)
+        np.save(os.path.join(input_dir, \
+            f"{task_name}_input_{i}.npy"), input)
+    stub.run()
+    # print(stub.outputs)
+    output = next(stub.outputs.values().__iter__()).copyout_numpy()
+    if np.isnan(output).any():
+        print("Nan in output")
+        exit()
+    np.save(os.path.join(result_dir, f"{task_name}_output.npy"), output)
+
+
+def main():
+
+    global input_dir, result_dir, internal_model_dir
+
+    nnodes, nproc_per_node, task_name, \
+        model_path, gen_std, run_single, \
+        input_dir, result_dir, internal_model_dir = parse_args()
+
+    # load the input ONNX model
+    model = onnx.load(model_path)
+
+    # generate standard outputs
+    if gen_std:
+        print("Generate inputs and outputs.")
+        gen_standard(task_name, model)
+        return
+
+    if run_single:
+        print("Run model on a single card.")
+        runtime = backend.KUNLUNRuntime(0)
+        run_model(task_name, model, runtime)
+        return
+
+    # run distributed parallel.
+ world_size = nnodes * nproc_per_node + print(f"Run model by {world_size} GPU in parallel.") + workers = [ + mp.Process( + target=start_worker, + args=(task_name, world_size, rank, rank % nproc_per_node, model), + ) + for rank in range(world_size) + ] + + for w in workers: + w.start() + + for w in workers: + w.join() + + +if __name__ == "__main__": + main() diff --git a/examples/distributed/kunlun/launch.sh b/examples/distributed/kunlun/launch.sh new file mode 100644 index 00000000..f0bb81be --- /dev/null +++ b/examples/distributed/kunlun/launch.sh @@ -0,0 +1,36 @@ +export HF_ENDPOINT=https://hf-mirror.com + +# models=("bert" "gpt2" "llama") +models=("bert" "gpt2") +batch_size=(1 32) +seq_len=(100 500) +nproc=(1 2 4) + +results_dir="results" + +if [ -d "$results_dir" ]; then + echo "directory ./$results_dir exists" +else + mkdir -p "$results_dir" + echo "mkdir $results_dir, logs saved there" +fi + + +for model in "${models[@]}"; do + for bs in "${batch_size[@]}"; do + for len in "${seq_len[@]}"; do + # run pytorch model + echo "Run pytorch $model with batch_size=$bs length=$len ." + python run_pytorch.py --model "$model" --batch_size "$bs" --length "$len" #> results/"$model"_"$bs"_"$len"_pytorch + for n in "${nproc[@]}"; do + # run infinitensor + echo "Run $n parallel infinitensor "$model" with batch_size=$bs and length=$len ." + python kunlun_launch.py --name "$model" --model ../models/"$model"/"$model"_"$bs"_"$len".onnx --nproc_per_node=$n # >> results/"$model"_"$bs"_"$len"_infini + # delete internal files + find ./ -type f -name "*.onnx" -delete + find ./ -type f -name "*.pb" -delete + done + find ./ -type f -name "*.npy" -delete + done + done +done diff --git a/examples/distributed/kunlun/llama_launch.sh b/examples/distributed/kunlun/llama_launch.sh new file mode 100644 index 00000000..2500eea8 --- /dev/null +++ b/examples/distributed/kunlun/llama_launch.sh @@ -0,0 +1,35 @@ +export HF_ENDPOINT=https://hf-mirror.com + +# models=("bert" "gpt2" "llama") +models=("llama") +batch_size=(1 ) +seq_len=(100 500) +nproc=(1 2 4) + +results_dir="results" + +if [ -d "$results_dir" ]; then + echo "directory ./$results_dir exists" +else + mkdir -p "$results_dir" + echo "mkdir $results_dir, logs saved there" +fi + + +for model in "${models[@]}"; do + for bs in "${batch_size[@]}"; do + for len in "${seq_len[@]}"; do + echo "Run pytorch llama with batch_size="$bs" and length="$len"" + python run_pytorch.py --model "$model" --batch_size "$bs" --length "$len" + for n in "${nproc[@]}"; do + # run pytorch model + echo "Run infinitensor llama with batch_size="$bs" and length="$len" and nproc="$n"." 
+ python kunlun_launch.py --name llama --model ../models/llama/llama_"$bs"_"$len"_fp32.onnx --nproc_per_node=$n + # delete internal files + find ./ -type f -name "*.onnx" -delete + find ./ -type f -name "*0c" -delete + done + find ./ -type f -name "*.npy" -delete + done + done +done diff --git a/examples/distributed/kunlun/run_pytorch.py b/examples/distributed/kunlun/run_pytorch.py new file mode 100644 index 00000000..5cdf627d --- /dev/null +++ b/examples/distributed/kunlun/run_pytorch.py @@ -0,0 +1,245 @@ +import argparse +import torch +from transformers import BertModel, BertConfig +from transformers import GPT2Model, GPT2Config +from transformers import OPTModel, OPTConfig +from transformers import LlamaModel, LlamaConfig +import time +import numpy as np +import onnx +import os +import sys +from onnx.external_data_helper import convert_model_to_external_data +from onnxsim import simplify + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False +def parse_args(): + parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.") + parser.add_argument( + "--model", type=str, choices=["gpt2", "bert", "opt", "llama"], required=True, help="model type" + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=1, help="sequence length.") + parser.add_argument( + "--export_onnx", + type=str, + nargs="?", + default=None, + const="./", + help="whether and where to export onnx file", + ) + parser.add_argument( + "--input_dir", + type=str, + default="./", + help="path to save pytorch model input data" + ) + parser.add_argument( + "--result_dir", + type=str, + default="./", + help="path to save pytorch model output data" + ) + parser.add_argument( + "--export_only", + action="store_true" + ) + args = parser.parse_args() + print("arg setting: ", args) + return ( + args.model, + args.batch_size, + args.length, + args.export_onnx, + args.input_dir, + args.result_dir, + args.export_only + ) + + +def get_model(modelname): + if modelname == "bert": + model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini + voc_size = BertConfig().vocab_size + elif modelname == "gpt2": + model = GPT2Model.from_pretrained("gpt2") + voc_size = GPT2Config().vocab_size + elif modelname == "opt": + model = OPTModel.from_pretrained("./opt-125m") + voc_size = OPTConfig().vocab_size + elif modelname == "llama": + model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf") + voc_size = LlamaConfig().vocab_size + else : + raise KeyError(modelname) + + model = model.eval() + return model, voc_size + +def run_pytorch(torch_model, voc_size, batchsize, len, model_name): + data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32) + np.save(os.path.join(input_dir, f"{model_name}_input_0.npy"), data) + inputs = torch.from_numpy(data).to("cuda") + torch_model = torch_model.to("cuda") + + n_iter = 10 + with torch.no_grad(): + for _ in range(10): + outputs = torch_model(inputs) + torch.cuda.synchronize() + begin = time.time() + with torch.no_grad(): + for _ in range(n_iter): + torch.cuda.synchronize() + outputs = torch_model(inputs) + # + torch.cuda.synchronize() + torch.cuda.synchronize() + end = time.time() + + avg_time = (end - begin) / n_iter + outputs = outputs.last_hidden_state.to("cpu") + print("outputs abs mean:", abs(np.array(outputs)).mean()) + print(f"average time: {avg_time}") + 
torch.cuda.memory.empty_cache() + np.save(os.path.join(result_dir, f"{model_name}_output.npy"), \ + np.array(outputs)) + print(f"Save input & output as {model_name}_input_0.npy and {model_name}_output.npy") + + +def export_onnx(model_name, model, data, path, extern=False): + # torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True) + + if model_name != "llama": + onnx_model = onnx.load(path) + onnx_model, check = simplify(onnx_model, + skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer']) + # skipped_optimizers=['fuse_qkv']) + assert check + add_value_info_for_constants(onnx_model) + onnx_model = onnx.shape_inference.infer_shapes(onnx_model) + if extern: + extern_path = path.replace('.onnx', '.pb') + if os.path.exists(extern_path): + os.remove(extern_path) + convert_model_to_external_data( + onnx_model, + all_tensors_to_one_file=True, + location=extern_path.split("/")[-1], + size_threshold=1024, + convert_attribute=False, + ) + onnx.save(onnx_model, path) + else: + sys.path.append("onnxsim_large_model") + from onnx_utils import set_onnx_input_shape + from compress_model import SIZE_1MB, compress_onnx_model, uncompress_onnx_model + + in_model_path = path + out_model_path = in_model_path[:-5] + ".sim.onnx" + + onnx_model = onnx.load(in_model_path) + print(f"load model from {in_model_path} success") + + size_th_bytes = 1024 * 1024 + onnx_model, removed_inits = compress_onnx_model(onnx_model, size_th_bytes=size_th_bytes) + print("compress model success") + + onnx_model = set_onnx_input_shape(onnx_model, "") + tensor_size_threshold = f"1024KB" + skipped_optimizers = [] + skipped_optimizers.append("eliminate_duplicate_initializer") + onnx_model, check = simplify(onnx_model, skipped_optimizers=skipped_optimizers, + tensor_size_threshold=tensor_size_threshold) + if not check: + raise ValueError(f"simplify compressed model {in_model_path} failed") + + print(f"simplify model success") + + onnx_model = uncompress_onnx_model(onnx_model, removed_inits) + print(f"uncompress model success") + + add_value_info_for_constants(onnx_model) + + onnx.save(onnx_model, out_model_path, save_as_external_data=True) + + +def add_value_info_for_constants(model : onnx.ModelProto): + """ + Currently onnx.shape_inference doesn't use the shape of initializers, so add + that info explicitly as ValueInfoProtos. + Mutates the model. + Args: + model: The ModelProto to update. 
+ """ + # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs + if model.ir_version < 4: + return + + def add_const_value_infos_to_graph(graph : onnx.GraphProto): + inputs = {i.name for i in graph.input} + existing_info = {vi.name: vi for vi in graph.value_info} + for init in graph.initializer: + # Check it really is a constant, not an input + if init.name in inputs: + continue + + # The details we want to add + elem_type = init.data_type + shape = init.dims + + # Get existing or create new value info for this constant + vi = existing_info.get(init.name) + if vi is None: + vi = graph.value_info.add() + vi.name = init.name + + # Even though it would be weird, we will not overwrite info even if it doesn't match + tt = vi.type.tensor_type + if tt.elem_type == onnx.TensorProto.UNDEFINED: + tt.elem_type = elem_type + if not tt.HasField("shape"): + # Ensure we set an empty list if the const is scalar (zero dims) + tt.shape.dim.extend([]) + for dim in shape: + tt.shape.dim.add().dim_value = dim + + # Handle subgraphs + for node in graph.node: + for attr in node.attribute: + # Ref attrs refer to other attrs, so we don't need to do anything + if attr.ref_attr_name != "": + continue + + if attr.type == onnx.AttributeProto.GRAPH: + add_const_value_infos_to_graph(attr.g) + if attr.type == onnx.AttributeProto.GRAPHS: + for g in attr.graphs: + add_const_value_infos_to_graph(g) + + + return add_const_value_infos_to_graph(model.graph) + + +def main(): + global input_dir, result_dir + + modelname, batchsize, seqlen, \ + export_path, input_dir, result_dir, export_only = parse_args() + + model, voc_size = get_model(modelname) # pytorch model + + if export_path is not None: + os.makedirs(export_path, exist_ok=True) + filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen) + path = os.path.join(export_path, filename) + param = torch.zeros((batchsize, seqlen), dtype=torch.int) + export_onnx(modelname, model, param, path, True) # export pytorch model to onnx model + if export_only: + return + + run_pytorch(model, voc_size, batchsize, seqlen, modelname) + +if __name__ == "__main__": + main() diff --git a/examples/distributed/launch_kunlun.py b/examples/distributed/launch_kunlun.py deleted file mode 100644 index e8c1a0ab..00000000 --- a/examples/distributed/launch_kunlun.py +++ /dev/null @@ -1,213 +0,0 @@ -import argparse -import os -import time -import multiprocessing as mp -from pyinfinitensor.onnx import OnnxStub, backend -import onnx -from onnx.external_data_helper import convert_model_to_external_data -from onnx.shape_inference import infer_shapes_path -import numpy as np -from parallel_opt import parallel_model - -st_input_dir = "standard/inputs/" -st_output_dir = "standard/outputs/" - -def parse_args(): - parser = argparse.ArgumentParser(description="launch distributed infinitensor") - parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes") - parser.add_argument( - "--nproc_per_node", type=int, default=2, help="number of processes per node" - ) - parser.add_argument( - "--name", type=str, default="test", help="name of this instance." - ) - parser.add_argument( - "--model", type=str, default="/data1/shared/panzezhong/llama/fp32/my_llama_fp32.sim.onnx", help="path to the ONNX model file." 
- ) - parser.add_argument("--batch_size", type=int, default=1, help="batch size.") - parser.add_argument("--length", type=int, default=1, help="sequence length.") - parser.add_argument( - "--gen_std", - default=False, - action="store_true", - help="whether to generate the standard results.", - ) - parser.add_argument( - "--run_single", - default=False, - action="store_true", - help="whether run model with single process with standard inputs" - ) - args = parser.parse_args() - print("arg setting: ", args) - return ( - args.num_nodes, - args.nproc_per_node, - args.name, - args.model, - args.batch_size, - args.length, - args.gen_std, - args.run_single - ) - - -def run_model(model, runtime, world_size=1, rank=0, n=10): - stub = OnnxStub(model, runtime) - load_inputs(stub, world_size, rank) - # stub.tune() - stub.run() - # get outputs - time.sleep(0.01) - outputs = next(stub.outputs.values().__iter__()).copyout_numpy() - - # bench - begin = time.time() - for _ in range(n): - stub.run() - end = time.time() - avg_time = (end - begin) / n - print(f"average time: {avg_time}") - return outputs - - - -def run_and_compare(name, model, runtime, world_size=1, rank = 0): - results = np.load(os.path.join(st_output_dir,f"output.npy")) - outputs = run_model(model, runtime, world_size, rank) - print(outputs[:100]) - if np.isnan(outputs).any(): - print("Nan in output") - print("answer argmax:", np.argmax(results)) - print("output argmax:", np.argmax(outputs)) - #np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3) - getDiff(results, outputs) - - -def start_worker( - name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto -): - dist_name = name + "_dist" - model = parallel_model(model, world_size, rank) - extern_path = f"./{dist_name}_rank{rank}.pb" - if os.path.exists(extern_path): - os.remove(extern_path) - onnx.save_model( - model, - f"./{dist_name}_rank{rank}.onnx", - save_as_external_data=True, - location=extern_path, - ) - infer_shapes_path(f"./{dist_name}_rank{rank}.onnx") - runtime = backend.KUNLUNRuntime(local_rank) - # print("init comm") - runtime.init_comm( - dist_name, - world_size, - rank, - ) - run_and_compare(name, model, runtime, world_size, rank) - - -def start_single(name, model): - runtime = backend.KUNLUNRuntime(0) - run_and_compare(name, model, runtime) - - -def generate_input_output(model): - runtime = backend.KUNLUNRuntime(0) - stub = OnnxStub(model, runtime) - position_id = 0 - for i, (name, tensor) in enumerate(stub.inputs.items()): - input = tensor.copyout_numpy() - if np.issubdtype(input.dtype, np.integer): - if input.size == 1: - # input = np.array([position_id]) - input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) - else: - input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) - elif input.dtype == np.bool_: - input = np.random.randint(0,2,size=input.shape) > 0 - else: - if i == 0: - input = np.ones(input.shape).astype(input.dtype) - position_id = input.shape[-1] - 1 - else: - input = np.random.rand(*input.shape).astype(input.dtype) - tensor.copyin_numpy(input) - np.save(os.path.join(st_input_dir, f"input_{i}"), input) - stub.run() - # print(stub.outputs) - time.sleep(0.01) - output = next(stub.outputs.values().__iter__()).copyout_numpy() - print(output[:100]) - if np.isnan(output).any(): - print("Nan in output") - np.save(os.path.join(st_output_dir, f"output"), output) - - -def load_inputs(stub, world_size=1, rank=0): - for i, (name, tensor) in enumerate(stub.inputs.items()): - input = 
np.load(os.path.join(st_input_dir, f"input_{i}.npy")) - if all(x == y for x,y in zip(input.shape,tensor.shape())): - tensor.copyin_numpy(input) - else: - tensor.copyin_numpy(np.hsplit(input, world_size)[rank]) - - -def getDiff(base, test): - absolute_diff = np.abs(np.subtract(base, test)) - max_absolute_diff = np.max(absolute_diff) - - baseCopy = base.astype(np.float64).ravel() - testCopy = test.astype(np.float64).ravel() - upValue = np.sum(np.abs(baseCopy - testCopy)) - downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9) - max_relative_diff = upValue / downValue - print(f"Max absolute difference: {max_absolute_diff}\nMax relative difference: {max_relative_diff}") - - return max_absolute_diff, max_relative_diff - - -def main(): - nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_single = parse_args() - - model = onnx.load(model_path) - - # generate standart output - if gen_std: - print("Generate inputs and outputs.") - p = mp.Process(target=generate_input_output, args=[model]) - p.start() - p.join() - return - - # # run single process. - # # use standalone process to isolate cuda. - if run_single: - print("run model by single GPU.") - p = mp.Process(target=start_single, args=(name, model)) - p.start() - p.join() - return - - # run distributed parallel. - world_size = nnodes * nproc_per_node - print(f"run model by {world_size} GPU in parallel.") - workers = [ - mp.Process( - target=start_worker, - args=(name, world_size, rank, rank % nproc_per_node, model), - ) - for rank in range(world_size) - ] - - for w in workers: - w.start() - - for w in workers: - w.join() - - -if __name__ == "__main__": - main() diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py index 804e48c6..4985f3aa 100644 --- a/examples/distributed/parallel_opt.py +++ b/examples/distributed/parallel_opt.py @@ -110,7 +110,6 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): s_dim = 0 elif in_plc.dim == 2: s_dim = 1 - assert s_dim != -1 assert out_dims[s_dim] % tp_world_size == 0, out_dims out_dims[s_dim] //= tp_world_size diff --git a/include/kunlun/kunlun_runtime.h b/include/kunlun/kunlun_runtime.h index 0c175158..8defcf9d 100644 --- a/include/kunlun/kunlun_runtime.h +++ b/include/kunlun/kunlun_runtime.h @@ -21,7 +21,7 @@ class KUNLUNRuntimeObj : public RuntimeObj { ctx = xdnn::create_context(); // 10GB for Longformer // size_t longformerNum = 3lu * (1 << 30); - size_t workspaceSize = 3llu << 30; // 3 GB + size_t workspaceSize = 2llu << 30; // 2 GB KUNLUNPtr wkspacePtr = alloc(workspaceSize); workspace = make_ref>(wkspacePtr, workspaceSize); @@ -42,7 +42,7 @@ class KUNLUNRuntimeObj : public RuntimeObj { KUNLUNPtr alloc(size_t size) override { void *ptr; checkKUNLUNError( - xpu_malloc_ex((void **)&ptr, size, XPUMemoryKind::XPU_MEM_MAIN)); + xpu_malloc((void **)&ptr, size, XPUMemoryKind::XPU_MEM_HBM)); return ptr; } void dealloc(void *ptr) override { xpu_free(ptr); } diff --git a/include/kunlun/xccl_communicator.h b/include/kunlun/xccl_communicator.h index 6e9c31d0..5d995aa4 100644 --- a/include/kunlun/xccl_communicator.h +++ b/include/kunlun/xccl_communicator.h @@ -34,8 +34,8 @@ class XcclCommunicatorObj final : public CommunicatorObj { auto begin = std::chrono::steady_clock::now(); while (!std::filesystem::exists(filePath)) { auto now = std::chrono::steady_clock::now(); - _IT_ASSERT_2(now < begin + std::chrono::seconds(10), - "time limit (10s) exceeded."); + _IT_ASSERT_2(now < begin + std::chrono::seconds(100), + "time limit (100s) 
exceeded."); std::this_thread::sleep_for(std::chrono::milliseconds(100)); } std::ifstream ifs(filePath, std::ios::binary); diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index fc1e0bbc..3727e63a 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -967,7 +967,7 @@ class OnnxStub: tensors[node.input[0]], tensors.get(node.output[0]), ) - elif node.op_type == "Constant": + elif node.op_type in ["Constant", "ConstantOfShape"]: output_name = node.output[0] attributes = _parse_attribute(node) tensor = attributes["value"] diff --git a/src/kernels/kunlun/element_wise.cc b/src/kernels/kunlun/element_wise.cc index 665ea56a..287231a1 100644 --- a/src/kernels/kunlun/element_wise.cc +++ b/src/kernels/kunlun/element_wise.cc @@ -97,11 +97,14 @@ class DivXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bSize = op->getInputs(1)->size(); auto bDim = op->getInputs(1)->getDims(); - auto dtype = op->getDType(); + // op input a, b is scalar while aDim and b Dim is empty if (bDim.size() == 0) { bDim.push_back(1); } + if (aDim.size() == 0) { + aDim.push_back(1); + } if (aSize == bSize) { // Do ElementWise Sub with no broadcast @@ -109,23 +112,9 @@ class DivXdnn : public KUNLUNKernelWithoutConfig { (float *)aData, (float *)bData, (float *)cData, aSize)); } else { - // Do broadcast div - Shape aligned = infer_broadcast(aDim, bDim); - if (aligned == aDim) { - // BData need to be broadcasted - checkKUNLUNError(xdnn::broadcast_div( - context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, aDim, bDim)); - } else { - // Use workspace to broadcast aData - KUNLUNPtr wks = context->getWorkspace(bSize * dtype.getSize()); - checkKUNLUNError(xdnn::broadcast( - context->KUNLUNHandle(), (float *)aData, (float *)wks, aDim, - bDim)); - checkKUNLUNError(xdnn::div(context->KUNLUNHandle(), - (float *)wks, (float *)bData, - (float *)cData, bSize)); - } + checkKUNLUNError(xdnn::broadcast_div( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, aDim, bDim)); } return; }