From 32a13b77607eba50ceedf6fd747689f0b506501e Mon Sep 17 00:00:00 2001 From: weijie01 Date: Tue, 2 Apr 2024 17:15:08 +0800 Subject: [PATCH] kunlun distributed --- CMakeLists.txt | 4 +- cmake/FindXCCL.cmake | 2 +- examples/distributed/kunlun/kunlun_launch.py | 278 ++++++++++++++++++ examples/distributed/kunlun/kunlun_launch2.py | 215 ++++++++++++++ examples/distributed/kunlun/run_pytorch.py | 208 +++++++++++++ examples/distributed/kunlun/run_pytorch.sh | 17 ++ include/core/kernel.h | 2 +- include/core/perf_engine.h | 2 +- src/kernels/kunlun/all_reduce.cc | 17 +- 9 files changed, 732 insertions(+), 13 deletions(-) create mode 100644 examples/distributed/kunlun/kunlun_launch.py create mode 100644 examples/distributed/kunlun/kunlun_launch2.py create mode 100644 examples/distributed/kunlun/run_pytorch.py create mode 100644 examples/distributed/kunlun/run_pytorch.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index ccacac1c..1dbb007b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -285,8 +285,8 @@ if(USE_KUNLUN) message(STATUS "KUNLUN_HOME: ${KUNLUN_HOME}") include_directories("${KUNLUN_HOME}/include/") - find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/so/") - find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/so/") + find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/lib64/") + find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/lib64/") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror") if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH})) diff --git a/cmake/FindXCCL.cmake b/cmake/FindXCCL.cmake index fa4b3b00..28dd1f14 100644 --- a/cmake/FindXCCL.cmake +++ b/cmake/FindXCCL.cmake @@ -9,7 +9,7 @@ find_path(XCCL_INCLUDE_DIRS # ${XCCL_INCLUDE_DIR} HINTS XCCL_INCLUDE_DIR) find_library(XCCL_LIBRARIES # ${XCCL_LIB_DIR} - NAMES so/libbkcl.so + NAMES lib64/libbkcl.so HINTS XCCL_LIB_DIR) message(STATUS "XCCL_INCLUDE_DIRS: ${XCCL_INCLUDE_DIRS}") diff --git a/examples/distributed/kunlun/kunlun_launch.py b/examples/distributed/kunlun/kunlun_launch.py new file mode 100644 index 00000000..7e69e4b3 --- /dev/null +++ b/examples/distributed/kunlun/kunlun_launch.py @@ -0,0 +1,278 @@ +import sys +sys.path.append('../') + +import argparse +import os +import time +import multiprocessing as mp +from pyinfinitensor.onnx import OnnxStub, backend +import onnx +from onnx.external_data_helper import convert_model_to_external_data +from onnx.shape_inference import infer_shapes_path +import numpy as np +from parallel_opt import parallel_model +from functools import wraps + + +def parse_args(): + parser = argparse.ArgumentParser(description="launch distributed infinitensor") + parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes") + parser.add_argument( + "--nproc_per_node", type=int, default=2, help="number of processes per node" + ) + parser.add_argument( + "--name", type=str, default="test", help="name of this instance." + ) + parser.add_argument( + "--model", type=str, default="", help="path to the ONNX model file." 
+ ) + parser.add_argument( + "--gen_std", + default=False, + action="store_true", + help="whether to generate the standard results.", + ) + parser.add_argument( + "--run_single", + default=False, + action="store_true", + help="whether run model with single process with standard inputs" + ) + parser.add_argument( + "--input_dir", + default="./", + help="path to save model input data" + ) + parser.add_argument( + "--result_dir", + default="./", + help="path to save model standard output" + ) + parser.add_argument( + "--internal_model_dir", + default="./", + help="path to save internal onnx model for parallel run" + ) + args = parser.parse_args() + + # check path, mkdir if not exist + check_exists(args.input_dir) + check_exists(args.result_dir) + check_exists(args.internal_model_dir) + + print("arg setting: ", args) + return ( + args.num_nodes, + args.nproc_per_node, + args.name, + args.model, + args.gen_std, + args.run_single, + args.input_dir, + args.result_dir, + args.internal_model_dir + ) + + +""" +utils function for this scripts +""" +def check_exists(path: str): + if not os.path.exists(path): + os.makedirs(path) + +def np_assert(base, test, rtol=1e-2, atol=1e-1): + # np.testing.assert_allclose(test, base, rtol, atol) + print("max abs diff:", abs(base - test).max()) + + +""" +Perf wrapper, run function n times +then average +""" +def perf_it(n): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + # warmup + for _ in range(n): + func(*args, **kwargs) + + t_total = 0 + for _ in range(n): + t0 = time.time() + func(*args, **kwargs) + t1 = time.time() + t_total += t1 - t0 + avg_time = (t_total) / n + print(f"Avg runtime of {n} time is {avg_time:.6f} seconds") + return avg_time + return wrapper + return decorator + + +""" +Run InfiniTensor model with Standard input +check=True: check with standard output gen by pytorch +perf=True: run n times to get avg time +""" +def run_model(task_name, + model, + runtime, + world_size=1, + rank=0, + n=10, + check=True, + perf=True): + + stub = OnnxStub(model, runtime) + + # load in Onnx model inputs + def load_inputs(stub: OnnxStub): + # check exists + inputs = [] + for i, (name, tensor) in enumerate(stub.inputs.items()): + input_path = os.path.join(input_dir, \ + f"{task_name}_input_{i}.npy") + print(input_path) + if os.path.exists(input_path): + input = np.load(input_path) + else : + raise KeyError(f"{i} th input of model not exists") + # check shape + if all(x == y for x,y in zip(input.shape, tensor.shape())): + tensor.copyin_numpy(input) + else: + tensor.copyin_numpy(np.hsplit(input, world_size)[rank]) + + load_inputs(stub) + # stub.tune() + stub.run() + time.sleep(0.01) + output = next(stub.outputs.values().__iter__()).copyout_numpy() + + # check output results with standard output + if check: + st_output_path = os.path.join(result_dir, \ + f"{task_name}_output.npy") + assert os.path.exists(st_output_path) , \ + "standard output not exists" + st_output = np.load(st_output_path) + if np.isnan(output).any(): + print("Nan in output") + exit() + np_assert(st_output, output) + + # perf + if perf: + @perf_it(n) + def perf_infinitensor(stub: OnnxStub): + stub.run() + perf_infinitensor(stub) + + return output + + +""" +Start a worker in Parallel +""" +def start_worker(name: str, + world_size: int, + rank: int, + local_rank: int, + model: onnx.ModelProto): + + dist_name = name + "_dist" + # partial a onnx model to world_size part + model = parallel_model(model, world_size, rank) + onnx.save(model, os.path.join(internal_model_dir, \ + 
f"{dist_name}_rank{rank}.onnx")) + runtime = backend.KUNLUNRuntime(local_rank) + # print("init comm") + runtime.init_comm( + dist_name, + world_size, + rank, + ) + run_model(name, model, runtime, world_size, rank) + + +""" +generate standard input/output with +sigle card run +""" +def gen_stardard(task_name: str, model: onnx.ModelProto): + runtime = backend.KUNLUNRuntime(0) + stub = OnnxStub(model, runtime) + position_id = 0 + # generate random input for model + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = tensor.copyout_numpy() + if np.issubdtype(input.dtype, np.integer): + if input.size == 1: + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + else: + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + elif input.dtype == np.bool_: + input = np.random.randint(0,2,size=input.shape) > 0 + else: + if i == 0: + input = np.ones(input.shape).astype(input.dtype) + position_id = input.shape[-1] - 1 + else: + input = np.random.rand(*input.shape).astype(input.dtype) + tensor.copyin_numpy(input) + np.save(os.path.join(input_dir, \ + f"{task_name}_input_{i}.npy"), input) + stub.run() + # print(stub.outputs) + output = next(stub.outputs.values().__iter__()).copyout_numpy() + if np.isnan(output).any(): + print("Nan in output") + exit() + np.save(os.path.join(result_dir, f"{task_name}_output.npy"), output) + + +def main(): + + global input_dir, result_dir, internal_model_dir + + nnodes, nproc_per_node, task_name, \ + model_path, gen_std, run_single, \ + input_dir, result_dir, internal_model_dir = parse_args() + + # load input onnx model + model = onnx.load(model_path) + + # generate standart output + if gen_std: + print("Generate inputs and outputs.") + gen_stardard(task_name, model) + return + + if run_single: + print("Run model by one GPU card.") + runtime = backend.KUNLUNRuntime(0) + run_model(task_name, model, runtime) + return + + # run distributed parallel. + world_size = nnodes * nproc_per_node + print(f"Run model by {world_size} GPU in parallel.") + workers = [ + mp.Process( + target=start_worker, + args=(task_name, world_size, rank, rank % nproc_per_node, model), + ) + for rank in range(world_size) + ] + + for w in workers: + w.start() + + for w in workers: + w.join() + + +if __name__ == "__main__": + main() diff --git a/examples/distributed/kunlun/kunlun_launch2.py b/examples/distributed/kunlun/kunlun_launch2.py new file mode 100644 index 00000000..aa433529 --- /dev/null +++ b/examples/distributed/kunlun/kunlun_launch2.py @@ -0,0 +1,215 @@ +import sys +sys.path.append("../") +import argparse +import os +import time +import multiprocessing as mp +from pyinfinitensor.onnx import OnnxStub, backend +import onnx +from onnx.external_data_helper import convert_model_to_external_data +from onnx.shape_inference import infer_shapes_path +import numpy as np +from parallel_opt import parallel_model + +st_input_dir = ".cache/input/" +st_output_dir = ".cache/output/" + +def parse_args(): + parser = argparse.ArgumentParser(description="launch distributed infinitensor") + parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes") + parser.add_argument( + "--nproc_per_node", type=int, default=2, help="number of processes per node" + ) + parser.add_argument( + "--name", type=str, default="test", help="name of this instance." + ) + parser.add_argument( + "--model", type=str, default="/data1/shared/panzezhong/llama/fp32/my_llama_fp32.sim.onnx", help="path to the ONNX model file." 
+ ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=1, help="sequence length.") + parser.add_argument( + "--gen_std", + default=False, + action="store_true", + help="whether to generate the standard results.", + ) + parser.add_argument( + "--run_single", + default=False, + action="store_true", + help="whether run model with single process with standard inputs" + ) + args = parser.parse_args() + print("arg setting: ", args) + return ( + args.num_nodes, + args.nproc_per_node, + args.name, + args.model, + args.batch_size, + args.length, + args.gen_std, + args.run_single + ) + + +def run_model(model, runtime, world_size=1, rank=0, n=10): + stub = OnnxStub(model, runtime) + load_inputs(stub, world_size, rank) + # stub.tune() + stub.run() + # get outputs + time.sleep(0.01) + outputs = next(stub.outputs.values().__iter__()).copyout_numpy() + + # bench + begin = time.time() + for _ in range(n): + stub.run() + end = time.time() + avg_time = (end - begin) / n + print(f"average time: {avg_time}") + return outputs + + + +def run_and_compare(name, model, runtime, world_size=1, rank = 0): + results = np.load(os.path.join(st_output_dir, "test_output.npy")) + outputs = run_model(model, runtime, world_size, rank) + print(outputs[:100]) + if np.isnan(outputs).any(): + print("Nan in output") + print("answer argmax:", np.argmax(results)) + print("output argmax:", np.argmax(outputs)) + #np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3) + getDiff(results, outputs) + + +def start_worker( + name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto +): + dist_name = name + "_dist" + model = parallel_model(model, world_size, rank) + extern_path = f"./{dist_name}_rank{rank}.pb" + if os.path.exists(extern_path): + os.remove(extern_path) + onnx.save_model( + model, + f"./{dist_name}_rank{rank}.onnx", + save_as_external_data=True, + location=extern_path, + ) + infer_shapes_path(f"./{dist_name}_rank{rank}.onnx") + runtime = backend.KUNLUNRuntime(local_rank) + # print("init comm") + runtime.init_comm( + dist_name, + world_size, + rank, + ) + run_and_compare(name, model, runtime, world_size, rank) + + +def start_single(name, model): + runtime = backend.KUNLUNRuntime(0) + run_and_compare(name, model, runtime) + + +def generate_input_output(model): + runtime = backend.KUNLUNRuntime(0) + stub = OnnxStub(model, runtime) + position_id = 0 + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = tensor.copyout_numpy() + if np.issubdtype(input.dtype, np.integer): + if input.size == 1: + # input = np.array([position_id]) + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + else: + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + elif input.dtype == np.bool_: + input = np.random.randint(0,2,size=input.shape) > 0 + else: + if i == 0: + input = np.ones(input.shape).astype(input.dtype) + position_id = input.shape[-1] - 1 + else: + input = np.random.rand(*input.shape).astype(input.dtype) + tensor.copyin_numpy(input) + np.save(os.path.join(st_input_dir, f"input_{i}"), input) + stub.run() + # print(stub.outputs) + time.sleep(0.01) + output = next(stub.outputs.values().__iter__()).copyout_numpy() + print(output[:100]) + if np.isnan(output).any(): + print("Nan in output") + np.save(os.path.join(st_output_dir, f"output"), output) + + +def load_inputs(stub, world_size=1, rank=0): + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = 
np.load(os.path.join(st_input_dir, f"test_input_{name}.npy")) + if all(x == y for x,y in zip(input.shape,tensor.shape())): + tensor.copyin_numpy(input) + else: + tensor.copyin_numpy(np.hsplit(input, world_size)[rank]) + + +def getDiff(base, test): + absolute_diff = np.abs(np.subtract(base, test)) + max_absolute_diff = np.max(absolute_diff) + + baseCopy = base.astype(np.float64).ravel() + testCopy = test.astype(np.float64).ravel() + upValue = np.sum(np.abs(baseCopy - testCopy)) + downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9) + max_relative_diff = upValue / downValue + print(f"Max absolute difference: {max_absolute_diff}\nMax relative difference: {max_relative_diff}") + + return max_absolute_diff, max_relative_diff + + +def main(): + nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_single = parse_args() + + model = onnx.load(model_path) + + # generate standart output + if gen_std: + print("Generate inputs and outputs.") + p = mp.Process(target=generate_input_output, args=[model]) + p.start() + p.join() + return + + # # run single process. + # # use standalone process to isolate cuda. + if run_single: + print("run model by single GPU.") + p = mp.Process(target=start_single, args=(name, model)) + p.start() + p.join() + return + + # run distributed parallel. + world_size = nnodes * nproc_per_node + print(f"run model by {world_size} GPU in parallel.") + workers = [ + mp.Process( + target=start_worker, + args=(name, world_size, rank, rank % nproc_per_node, model), + ) + for rank in range(world_size) + ] + + for w in workers: + w.start() + + for w in workers: + w.join() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/distributed/kunlun/run_pytorch.py b/examples/distributed/kunlun/run_pytorch.py new file mode 100644 index 00000000..77e58b60 --- /dev/null +++ b/examples/distributed/kunlun/run_pytorch.py @@ -0,0 +1,208 @@ +import argparse +import torch +from transformers import BertModel, BertConfig +from transformers import GPT2Model, GPT2Config +from transformers import OPTModel, OPTConfig +from transformers import LlamaModel, LlamaConfig +import time +import numpy as np +import onnx +import os +from onnx.external_data_helper import convert_model_to_external_data +from onnxsim import simplify + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False +def parse_args(): + parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.") + parser.add_argument( + "--model", type=str, choices=["gpt2", "bert", "opt", "llama"], required=True, help="model type" + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=1, help="sequence length.") + parser.add_argument( + "--export_onnx", + type=str, + nargs="?", + default=None, + const="./", + help="whether and where to export onnx file", + ) + parser.add_argument( + "--input_dir", + type=str, + default="./", + help="path to save pytorch model input data" + ) + parser.add_argument( + "--result_dir", + type=str, + default="./", + help="path to save pytorch model output data" + ) + args = parser.parse_args() + print("arg setting: ", args) + return ( + args.model, + args.batch_size, + args.length, + args.export_onnx, + args.input_dir, + args.result_dir + ) + + +def get_model(modelname): + if modelname == "bert": + model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini + 
voc_size = BertConfig().vocab_size + elif modelname == "gpt2": + model = GPT2Model.from_pretrained("gpt2") + voc_size = GPT2Config().vocab_size + elif modelname == "opt": + model = OPTModel.from_pretrained("./opt-125m") + voc_size = OPTConfig().vocab_size + elif modelname == "llama": + model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf") + voc_size = LlamaConfig().vocab_size + else : + raise KeyError(modelname) + + model = model.eval() + return model, voc_size + +def run_pytorch(torch_model, voc_size, batchsize, len, model_name): + data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32) + np.save(os.path.join(input_dir, f"{model_name}_input_0.npy"), data) + inputs = torch.from_numpy(data).to("cuda") + torch_model = torch_model.to("cuda") + + n_iter = 20 + with torch.no_grad(): + for _ in range(10): + outputs = torch_model(inputs) + torch.cuda.synchronize() + begin = time.time() + with torch.no_grad(): + for _ in range(n_iter): + torch.cuda.synchronize() + outputs = torch_model(inputs) + # + torch.cuda.synchronize() + torch.cuda.synchronize() + end = time.time() + + avg_time = (end - begin) / n_iter + outputs = outputs.last_hidden_state.to("cpu") + print("outputs abs mean:", abs(np.array(outputs)).mean()) + print(f"average time: {avg_time}") + torch.cuda.memory.empty_cache() + np.save(os.path.join(result_dir, f"{model_name}_output.npy"), \ + np.array(outputs)) + print(f"Save input & output as {model_name}_input_0.npy and {model_name}_output.npy") + + +def export_onnx(model_name, model, data, path, extern=False): + torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True) + onnx_model = onnx.load(path) + # onnx_model, check = simplify(onnx_model, + # skip_shape_inference=True, + # skipped_optimizers=['eliminate_duplicate_initializer']) + if model_name == "gpt2": + onnx_model, check = simplify(onnx_model, + skip_shape_inference=True, + skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer']) + else : + onnx_model, check = simplify(onnx_model, + skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer']) + assert check + add_value_info_for_constants(onnx_model) + onnx_model = onnx.shape_inference.infer_shapes(onnx_model) + if extern: + extern_path = path.replace('.onnx', '.pb') + if os.path.exists(extern_path): + os.remove(extern_path) + convert_model_to_external_data( + onnx_model, + all_tensors_to_one_file=True, + location=extern_path.split("/")[-1], + size_threshold=1024, + convert_attribute=False, + ) + onnx.save(onnx_model, path) + +def add_value_info_for_constants(model : onnx.ModelProto): + """ + Currently onnx.shape_inference doesn't use the shape of initializers, so add + that info explicitly as ValueInfoProtos. + Mutates the model. + Args: + model: The ModelProto to update. 
+ """ + # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs + if model.ir_version < 4: + return + + def add_const_value_infos_to_graph(graph : onnx.GraphProto): + inputs = {i.name for i in graph.input} + existing_info = {vi.name: vi for vi in graph.value_info} + for init in graph.initializer: + # Check it really is a constant, not an input + if init.name in inputs: + continue + + # The details we want to add + elem_type = init.data_type + shape = init.dims + + # Get existing or create new value info for this constant + vi = existing_info.get(init.name) + if vi is None: + vi = graph.value_info.add() + vi.name = init.name + + # Even though it would be weird, we will not overwrite info even if it doesn't match + tt = vi.type.tensor_type + if tt.elem_type == onnx.TensorProto.UNDEFINED: + tt.elem_type = elem_type + if not tt.HasField("shape"): + # Ensure we set an empty list if the const is scalar (zero dims) + tt.shape.dim.extend([]) + for dim in shape: + tt.shape.dim.add().dim_value = dim + + # Handle subgraphs + for node in graph.node: + for attr in node.attribute: + # Ref attrs refer to other attrs, so we don't need to do anything + if attr.ref_attr_name != "": + continue + + if attr.type == onnx.AttributeProto.GRAPH: + add_const_value_infos_to_graph(attr.g) + if attr.type == onnx.AttributeProto.GRAPHS: + for g in attr.graphs: + add_const_value_infos_to_graph(g) + + + return add_const_value_infos_to_graph(model.graph) + + +def main(): + global input_dir, result_dir + + modelname, batchsize, seqlen, \ + export_path, input_dir, result_dir = parse_args() + + model, voc_size = get_model(modelname) # pytorch model + + if export_path is not None: + filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen) + path = os.path.join(export_path, filename) + param = torch.zeros((batchsize, seqlen), dtype=torch.int) + export_onnx(modelname, model, param, path, True) # export pytorch model to onnx model + + run_pytorch(model, voc_size, batchsize, seqlen, modelname) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/distributed/kunlun/run_pytorch.sh b/examples/distributed/kunlun/run_pytorch.sh new file mode 100644 index 00000000..b806ea03 --- /dev/null +++ b/examples/distributed/kunlun/run_pytorch.sh @@ -0,0 +1,17 @@ +export HF_ENDPOINT=https://hf-mirror.com + +models=("bert" "gpt2") +batch_size=(1 32) +seq_len=(100 500) +nproc=(1 2 4) + +for model in "${models[@]}"; do + for bs in "${batch_size[@]}"; do + for len in "${seq_len[@]}"; do + python -m xacc run_pytorch.py --model "$model" --batch_size "$bs" --length "$len" --export_onnx ../models/"$model" > results/"$model"_"$bs"_"$len" + for n in "${nproc[@]}"; do + python kunlun_launch.py --name "$model" --model ../models/"$model"/"$model"_"$bs"_"$len".onnx --nproc_per_node=$n >> results/"$model"_"$bs"_"$len" + done + done + done +done diff --git a/include/core/kernel.h b/include/core/kernel.h index 76189599..2e0d384a 100644 --- a/include/core/kernel.h +++ b/include/core/kernel.h @@ -5,8 +5,8 @@ #include "utils/operator_utils.h" #include #include -using json = nlohmann::json; namespace infini { +using json = nlohmann::json; class RuntimeObj; // Forward declaration for Kernel::compute diff --git a/include/core/perf_engine.h b/include/core/perf_engine.h index fb65da34..99a5916b 100644 --- a/include/core/perf_engine.h +++ b/include/core/perf_engine.h @@ -2,8 +2,8 @@ #include "core/graph.h" #include "core/kernel.h" #include -using json = nlohmann::json; namespace infini { +using json = 
nlohmann::json; class PerfEngine { public: diff --git a/src/kernels/kunlun/all_reduce.cc b/src/kernels/kunlun/all_reduce.cc index bbbe13a5..1b5c5985 100644 --- a/src/kernels/kunlun/all_reduce.cc +++ b/src/kernels/kunlun/all_reduce.cc @@ -20,14 +20,15 @@ class AllReduceXCCL : public KUNLUNKernelWithoutConfig { BKCLContext_t comm = dynamic_cast(context->getCommunicator()) .getXcclComm(); - double t = timeit( - [&]() { - checkXcclError(bkcl_all_reduce(comm, input, output, count, - BKCLDataType::BKCL_FLOAT, - getRedOp(), 0)); - }, - [&]() { context->sync(); }); - std::cout << "Time consuming for " << op->getInputs(0)->size() << " size is " << t << std::endl; + // double t = timeit( + // [&]() { + checkXcclError(bkcl_all_reduce(comm, input, output, count, + BKCLDataType::BKCL_FLOAT, getRedOp(), + 0)); + // }, + // [&]() { context->sync(); }); + // std::cout << "Time consuming for " << op->getInputs(0)->size() << " + // size is " << t << std::endl; } virtual BKCLOp getRedOp() const = 0; };
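
A minimal way to exercise the new example scripts, following the conventions in run_pytorch.sh above (the ../models/gpt2 location and the `python -m xacc` wrapper are taken from that script and may need adjusting for other environments):

  # Export gpt2 to ONNX and write the PyTorch reference tensors
  # (gpt2_input_0.npy, gpt2_output.npy) into the current directory.
  python -m xacc run_pytorch.py --model gpt2 --batch_size 1 --length 100 --export_onnx ../models/gpt2

  # Run the same ONNX model on 2 Kunlun devices and compare against that
  # reference; kunlun_launch.py's default --input_dir/--result_dir ("./")
  # already point at the files written above, so no extra flags are needed.
  python kunlun_launch.py --name gpt2 --model ../models/gpt2/gpt2_1_100.onnx --nproc_per_node=2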