From 4c321c8a91120a7277944a72b6631c3d03951a60 Mon Sep 17 00:00:00 2001
From: constroy Li
Date: Thu, 14 Sep 2023 14:19:45 +0800
Subject: [PATCH] tensor parallel for transformer (#125)

* add cmake bits about NCCL
* move example to examples/NNmodel
* impl NCCL communicator
* add comm related function to Runtime
* export runtime interface
* add launch.py
* use a unique name to distinguish the NCCL ID file
* add timeout to communicator init
* expose communicator obj from runtime obj, add unit test for nccl communicator
* reformat files
* Add allReduce operator and cuda nccl allReduce kernel
* impl model parallel for resnet
* add allGather nccl kernel and operator
* Add allreduce allgather operator tests, change allgather kernel to output list of tensor, fix shape infer, handle nullptr output
* fix format of onnx.py
* use concat following AllGather
* get tensor parallel for resnet
* fix format of graph_handler.cc
* change BUILD_DIST default to OFF
* polish code of communicator
* update .gitignore
* export min/max to python
* fix MatMul
* modify launch.py to run opt
* hack to treat ReduceSum as AllReduceSum
* throw exception in cuda error
* fix parallel_opt.py
* improve the error prompt and cuda error check
* fix GatherObj::GatherObj member init
* fix size calculation for scalar (rank = 0) tensor
* MatMul supports bias
* fix add bias for row parallel gemm
* add --gen_std to launch.py
* fix AllReduceNCCL
* update launch.py
* less log
* update parallel_opt
* update launch.py
* add __eq__ for Placement sub-classes
* less benchmark run
* fix placement infer for matmul
* fix vocabulary size
* fix Exception
* Add shard tensor with group to support gpt2
* Add find successor function to find split op at different depth
* recover CommunicatorObj
* improve error message
* optimize parallel_opt.py
* optimize launch.py
* recover docs for all_reduce and all_gather
* Fix API
* fix format

---------

Co-authored-by: panzezhong
Co-authored-by: Haojie Wang
---
 examples/distributed/launch.py            | 127 +++++++++----
 examples/distributed/parallel_opt.py      | 221 ++++++++++++++++++++++
 examples/distributed/placement.py         |  64 +++++++
 include/core/common.h                     |  12 +-
 include/cuda/cuda_common.h                |  25 +--
 include/utils/exception.h                 |  10 +
 include/utils/small_array.h               |   1 +
 pyinfinitensor/src/pyinfinitensor/onnx.py |  17 +-
 pyinfinitensor/tests/test_onnx.py         |  17 +-
 src/cuda/cuda_runtime.cc                  |  12 +-
 src/ffi/ffi_infinitensor.cc               |   2 +
 src/kernels/cuda/all_reduce.cc            |   2 +-
 src/kernels/cuda/matmul.cc                |  27 ++-
 src/operators/gather.cc                   |   4 +-
 src/utils/exception.cc                    |   3 +-
 15 files changed, 454 insertions(+), 90 deletions(-)
 create mode 100644 examples/distributed/parallel_opt.py
 create mode 100644 examples/distributed/placement.py

diff --git a/examples/distributed/launch.py b/examples/distributed/launch.py
index 362bde1c..64930e6e 100644
--- a/examples/distributed/launch.py
+++ b/examples/distributed/launch.py
@@ -4,8 +4,12 @@ import time
 import multiprocessing as mp
 from pyinfinitensor.onnx import OnnxStub, backend
 import onnx
+from onnx.external_data_helper import convert_model_to_external_data
 import numpy as np
-from parallel import parallel_model
+from parallel_opt import parallel_model
+
+
+os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
 
 
 def parse_args():
@@ -14,77 +18,126 @@ def parse_args():
     parser.add_argument(
         "--nproc_per_node", type=int, default=1, help="number of processes per node"
     )
+    parser.add_argument(
+        "--name", type=str, default="test", help="name of this instance."
+    )
     parser.add_argument(
         "--model", type=str, required=True, help="path to the ONNX model file."
     )
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
+    parser.add_argument("--length", type=int, default=1, help="sequence length.")
+    parser.add_argument(
+        "--gen_std",
+        action="store_true",
+        help="whether to generate the standard results.",
+    )
     args = parser.parse_args()
     print("arg setting: ", args)
-    return args.num_nodes, args.nproc_per_node, args.model
+    return (
+        args.num_nodes,
+        args.nproc_per_node,
+        args.name,
+        args.model,
+        args.batch_size,
+        args.length,
+        args.gen_std,
+    )
 
 
-def run_stub(stub: OnnxStub, inputs: np.array, n=100):
-    # warm up
-    next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist())
+def run_model(model, runtime, inputs: np.array, n=20):
+    stub = OnnxStub(model, runtime)
+    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
     stub.tune()
-    for _ in range(20):
-        stub.run()
+    stub.run()
+    # get outputs
     outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
     # bench
-    next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist())
+    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
     begin = time.time()
     for _ in range(n):
         stub.run()
     end = time.time()
-    outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
-    print("outputs sum:", outputs.sum())
-    # np.save("results", outputs)
-    results = np.load("results.npy")
-    print("max diff:", abs(outputs - results).max())
-    assert np.allclose(outputs, results, rtol=1e-6, atol=1e-6)
     avg_time = (end - begin) / n
-    return avg_time
+    print(f"average time: {avg_time}")
+    return outputs
+
+
+def run_and_compare(name, model, runtime):
+    data = np.load(f"{name}_inputs.npy")
+    results = np.load(f"{name}_results.npy")
+    outputs = run_model(model, runtime, data)
+    print("outputs sum:", outputs.sum())
+    print("max abs diff:", abs(outputs - results).max())
+    print("max rel diff:", abs((outputs - results) / results).max())
+    # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)
 
 
 def start_worker(
-    dist_name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
+    name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
 ):
-    print("start worker")
+    dist_name = name + "_dist"
+    model = parallel_model(model, world_size, rank)
+    extern_path = f"./{dist_name}_rank{rank}.pb"
+    if os.path.exists(extern_path):
+        os.remove(extern_path)
+    convert_model_to_external_data(
+        model,
+        all_tensors_to_one_file=True,
+        location=extern_path,
+        size_threshold=1024,
+        convert_attribute=False,
+    )
+    onnx.save(model, f"./{dist_name}_rank{rank}.onnx")
     runtime = backend.CudaRuntime(local_rank)
-    print("init comm")
+    # print("init comm")
     runtime.init_comm(
         dist_name,
         world_size,
         rank,
     )
-    model = parallel_model(model, world_size, rank)
-    onnx.save(model, f"dist_model_rank{rank}.onnx")
-    print("load model")
-    stub = OnnxStub(model, runtime)
-    data = np.load("inputs.npy")
-    print("run model")
-    avg_time = run_stub(stub, data)
-    print(f"average time: {avg_time}")
+    run_and_compare(name, model, runtime)
+
+
+def start_single(name, model):
+    runtime = backend.CudaRuntime(0)
+    run_and_compare(name, model, runtime)
+
+
+def gen_standard(name, model, voc_size, bs, len):
+    # generate standard results
+    data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
+    np.save(f"{name}_inputs", data)
+    runtime = backend.CudaRuntime(0)
+    outputs = run_model(model, runtime, data, 1)
+    np.save(f"{name}_results", outputs)
 
 
 def main():
-    nnodes, nproc_per_node, model_path = parse_args()
-    world_size = nnodes * nproc_per_node
+    nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
     model = onnx.load(model_path)
-    # generate standard results
-    # runtime = backend.CudaRuntime(0)
-    # stub = OnnxStub(model, runtime)
-    # data = np.random.randn(1, 3, 224, 224)
-    # np.save("inputs", data)
-    # run_stub(stub, data)
-    # del stub
-    dist_name = f"dist_{os.getpid()}"
+    # generate standard output
+    if gen_std:
+        print(f"generate standard data for {name}.")
+        # a small vocabulary size to fit all LLMs.
+        voc_size = 1000
+        gen_standard(name, model, voc_size, bs, length)
+        return
+
+    # run single process.
+    # use standalone process to isolate cuda.
+    p = mp.Process(target=start_single, args=(name, model))
+    p.start()
+    p.join()
+
+    # run distributed parallel.
+    world_size = nnodes * nproc_per_node
     workers = [
         mp.Process(
             target=start_worker,
-            args=(dist_name, world_size, rank, rank % nproc_per_node, model),
+            args=(name, world_size, rank, rank % nproc_per_node, model),
         )
         for rank in range(world_size)
     ]
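The flags above define a two-phase workflow: one run with --gen_std records a single-GPU reference, a second run benchmarks the parallel model and compares against it. A minimal sketch of driving that workflow follows; the model path "gpt2.onnx" and the instance name "gpt2" are placeholders, not part of this patch:

    # Sketch: drive launch.py end to end (only --model is required above).
    import subprocess

    args = ["--model", "gpt2.onnx", "--name", "gpt2",
            "--batch_size", "1", "--length", "8"]

    # Phase 1: one GPU; writes gpt2_inputs.npy and gpt2_results.npy.
    subprocess.run(["python", "launch.py", *args, "--gen_std"], check=True)

    # Phase 2: 2-way tensor parallel on one node; each rank saves
    # gpt2_dist_rank{r}.onnx and prints max abs/rel diff vs. the reference.
    subprocess.run(["python", "launch.py", *args, "--nproc_per_node", "2"],
                   check=True)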
diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py
new file mode 100644
index 00000000..c152f6be
--- /dev/null
+++ b/examples/distributed/parallel_opt.py
@@ -0,0 +1,221 @@
+import onnx
+from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto
+from onnx import helper, numpy_helper
+from typing import Dict, List
+from placement import Placement, Replicate, Shard, _Partial
+import numpy as np
+
+
+def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
+    data = {init.name: init for init in model.graph.initializer}
+    vinfo = {info.name: info for info in model.graph.value_info}
+    vinfo.update({info.name: info for info in model.graph.input})
+    vinfo.update({info.name: info for info in model.graph.output})
+    place: Dict[str, Placement] = {}
+    nodes: List[NodeProto] = []
+
+    def is_sharded(name: str):
+        return place[name].is_shard()
+
+    def shard_tensor(tensor: TensorProto, plc: Shard, groups: int = 1):
+        # print(f"shard {tensor.name} at dim {dim}")
+        assert plc.is_shard(), plc
+        ndim = len(tensor.dims)
+        if plc.dim < 0:
+            plc.dim += ndim
+        if tensor.dims[plc.dim] == 1:  # broadcast dim, no need to shard.
+            return tensor
+        array = numpy_helper.to_array(tensor)
+        assert array.shape[plc.dim] % tp_world_size == 0, array.shape[plc.dim]
+        dims = list(tensor.dims)
+        dims.insert(plc.dim, groups)
+        dims[plc.dim + 1] //= groups
+        array = array.reshape(dims)
+        seg = array.shape[plc.dim + 1] // tp_world_size
+        array = array.take(
+            indices=range(tp_rank * seg, (tp_rank + 1) * seg), axis=plc.dim + 1
+        )
+        dims = list(tensor.dims)
+        dims[plc.dim] //= tp_world_size
+        array = array.reshape(dims)
+        tensor = numpy_helper.from_array(array, name=tensor.name)
+        place[tensor.name] = plc
+        return tensor
+
+    def shard_gemm(node: NodeProto, groups: int = 1):
+        # print("gemm", node.name)
+        in_plc = place[node.input[0]]
+        w_plc = Shard(-1) if in_plc.is_replicate() else Shard(0)
+        transB = next((attr.i for attr in node.attribute if attr.name == "transB"), 0)
+        if transB:
+            w_plc.dim = ~w_plc.dim
+        input = node.input[1]
+        data[input] = shard_tensor(data[input], w_plc, groups)
+
+        output = node.output[0]
+        ndim = len(vinfo[output].type.tensor_type.shape.dim)
+        out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
+        place[node.output[0]] = out_plc
+
+    def shard_binary(node: NodeProto, groups: int = 1):
+        # print("binary", node.name, node.input[0], place[node.input[0]])
+        a = node.input[0]
+        b = node.input[1]
+        if a in data:
+            a, b = b, a
+        place[node.output[0]] = place[a]
+        if is_sharded(a) and b in data and len(data[b].dims) == 1:  # broadcast
+            data[b] = shard_tensor(data[b], Shard(0), groups)
+
+    def shard_reshape(node: NodeProto):
+        # print("reshape", node.name, node.input[0], place[node.input[0]])
+        if not is_sharded(node.input[0]):
+            return
+        in_plc = place[node.input[0]]
+        s_dim = -1
+        in_dims = [d.dim_value for d in vinfo[node.input[0]].type.tensor_type.shape.dim]
+        tensor = data[node.input[1]]
+        out_dims = numpy_helper.to_array(tensor).copy()
+        if len(in_dims) == 3 and len(out_dims) == 4:
+            if in_plc.dim == 0:
+                s_dim = 1
+            elif in_plc.dim == 2:
+                s_dim = 2
+        if len(in_dims) == 4 and len(out_dims) == 3:
+            if in_plc.dim == 1:
+                s_dim = 0
+            elif in_plc.dim == 2:
+                s_dim = 2
+        if len(in_dims) == 2 and len(out_dims) == 3:
+            if in_plc.dim == 1:
+                s_dim = 2
+        if len(in_dims) == 4 and len(out_dims) == 2:
+            if in_plc.dim == 1:
+                s_dim = 0
+            elif in_plc.dim == 2:
+                s_dim = 1
+        if len(in_dims) == 3 and len(out_dims) == 2:
+            if in_plc.dim == 1:
+                s_dim = 0
+            elif in_plc.dim == 2:
+                s_dim = 1
+
+        assert s_dim != -1
+        assert out_dims[s_dim] % tp_world_size == 0, out_dims
+        out_dims[s_dim] //= tp_world_size
+        # if ONNX uses the same tensor for multiple Reshape Nodes, then rename it to distinguish from others.
+        # node.input[1] = node.output[0] + "_shape"
+        data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1])
+        place[node.output[0]] = Shard(s_dim)
+
+    def shard_split(node: NodeProto):
+        if not is_sharded(node.input[0]):
+            return
+        in_plc = place[node.input[0]]
+        split_tensor = data[node.input[1]]
+        split = numpy_helper.to_array(split_tensor).copy()
+        split //= tp_world_size
+        data[node.input[1]] = numpy_helper.from_array(split, name=node.input[1])
+        for output in node.output:
+            place[output] = in_plc
+
+    def shard_transpose(node: NodeProto):
+        plc = place[node.input[0]]
+        if plc.is_shard():
+            perm = next(attr.ints for attr in node.attribute if attr.name == "perm")
+            place[node.output[0]] = Shard(list(perm).index(plc.dim))
+
+    def shard_node(node: NodeProto):
+        if node.op_type in ["Relu", "Tanh", "Softmax"]:
+            place[node.output[0]] = place[node.input[0]]
+        elif node.op_type in ["Where"]:
+            place[node.output[0]] = place[node.input[1]]
+        if node.op_type in {"Add", "Mul", "Div", "Max"}:
+            shard_binary(node)
+        elif node.op_type == "Reshape":
+            shard_reshape(node)
+        elif node.op_type == "Transpose":
+            shard_transpose(node)
+        elif node.op_type == "Split":
+            shard_split(node)
+        elif node.op_type == "MatMul":
+            assert (
+                place[node.input[0]] == place[node.input[1]]
+            ), f"{place[node.input[0]]} != {place[node.input[1]]}"
+            place[node.output[0]] = place[node.input[0]]
+
+    def find_successor(op_type: str, idx: int, search_limit: int = 1):
+        for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:
+            if node.op_type == op_type:
+                return node
+        return None
+
+    # all tensors are initially replicated.
+    for v in vinfo:
+        place[v] = Replicate()
+
+    for t in data:
+        place[t] = Replicate()
+
+    for index, node in enumerate(model.graph.node):
+        nodes.append(node)
+        # linear
+        if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
+            input in data for input in node.input
+        ):
+            groups = 1
+            # If the Gemm or Matmul is followed by a split, then the inputs are concatenated by groups
+            split_node = find_successor("Split", index, search_limit=2)
+            if split_node is not None:
+                groups = len(split_node.output)
+            shard_gemm(node, groups)
+            plc = place[node.output[0]]
+            if plc.is_partial():
+                new_name = node.output[0] + f":{plc}"
+                place[new_name] = place[node.output[0]]
+                # insert all_reduce
+                nodes.append(
+                    helper.make_node(
+                        op_type="ReduceSum",
+                        inputs=[new_name],
+                        outputs=[node.output[0]],
+                        name=node.name + "/all_reduce",
+                        noop_with_empty_axes=1,
+                        communicator=0,  # hack to treat ReduceSum as AllReduceSum
+                    )
+                )
+                place[node.output[0]] = Replicate()
+                node.output[0] = new_name
+            if len(node.input) > 2:  # split bias to add
+                prev = nodes[-1]
+                new_name = prev.output[0] + "_no_bias"
+                place[new_name] = place[node.output[0]]
+                bias = helper.make_node(
+                    op_type="Add",
+                    inputs=[new_name, node.input[2]],
+                    outputs=[prev.output[0]],
+                    name=node.name + "/bias",
+                )
+                node.input.pop()
+                prev.output[0] = new_name
+                shard_binary(bias, groups)
+                nodes.append(bias)
+            continue
+        shard_node(node)
+
+    graph = helper.make_graph(
+        nodes,
+        model.graph.name + f"_{tp_rank}",
+        model.graph.input,
+        model.graph.output,
+        data.values(),
+        doc_string=model.graph.doc_string,
+        # value_info=vinfo.values(),
+    )
+    for output in graph.output:
+        tt = output.type.tensor_type
+        if tt.HasField("shape"):
+            tt.ClearField("shape")
+    model = helper.make_model(graph)
+    model = onnx.shape_inference.infer_shapes(model)
+    return model
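shard_gemm above implements the standard column-parallel/row-parallel pairing: a Gemm whose input is replicated gets its weight sharded along the last dimension (output becomes Shard(ndim - 1)), and the following Gemm, whose input is now sharded, gets its weight sharded along dim 0, so its output is a _Partial sum that the inserted ReduceSum/AllReduceSum repairs. A numpy sketch of that algebra (illustration only, not runtime code):

    import numpy as np

    tp = 2
    x = np.random.randn(4, 8)    # replicated activation
    w1 = np.random.randn(8, 16)  # Shard(-1): column parallel
    w2 = np.random.randn(16, 8)  # Shard(0): row parallel

    w1_shards = np.split(w1, tp, axis=1)  # each rank keeps 16/tp columns
    w2_shards = np.split(w2, tp, axis=0)  # each rank keeps 16/tp rows

    # Rank r computes (x @ w1_r) @ w2_r with no communication in between;
    # the results are _Partial, and one all-reduce restores the full product.
    partials = [(x @ w1_shards[r]) @ w2_shards[r] for r in range(tp)]
    assert np.allclose(sum(partials), x @ w1 @ w2)

This is why only one collective per transformer block pair is needed: the intermediate sharded activation never has to be gathered.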
diff --git a/examples/distributed/placement.py b/examples/distributed/placement.py
new file mode 100644
index 00000000..634d4fe5
--- /dev/null
+++ b/examples/distributed/placement.py
@@ -0,0 +1,64 @@
+from typing import Optional
+
+
+class Placement:
+    # base class Placement type
+
+    # convenient utils to check for placement types
+    def is_shard(self, dim: Optional[int] = None) -> bool:
+        if dim is not None and isinstance(self, Shard):
+            return self.dim == dim
+        else:
+            return isinstance(self, Shard)
+
+    def is_replicate(self) -> bool:
+        return isinstance(self, Replicate)
+
+    def is_partial(self) -> bool:
+        return isinstance(self, _Partial)
+
+
+class Replicate(Placement):
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Replicate):
+            return False
+        return True
+
+    def __repr__(self) -> str:
+        """
+        machine readable representation of the Replicate placement
+        """
+        return "Replicate()"
+
+
+class Shard(Placement):
+    # shard placement, shard on a dim
+    def __init__(self, dim):
+        self.dim = dim
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Shard):
+            return False
+        return self.dim == other.dim
+
+    def __repr__(self) -> str:
+        """
+        machine readable representation of the Shard placement
+        """
+        return f"Shard(dim={self.dim})"
+
+
+class _Partial(Placement):
+    def __init__(self, reduce_op: str = "sum"):
+        self.reduce_op: str = reduce_op
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, _Partial):
+            return False
+        return self.reduce_op == other.reduce_op
+
+    def __repr__(self) -> str:
+        """
+        machine readable representation of the Partial placement
+        """
+        return f"_Partial(reduce_op={self.reduce_op})"
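The placement types are plain value objects; parallel_opt.py relies only on the is_* predicates, on __eq__ (when checking that both MatMul operands agree), and on __repr__ (when building the renamed partial outputs). For example:

    from placement import Replicate, Shard, _Partial

    assert Shard(1) == Shard(1) and Shard(1) != Shard(0)
    assert Shard(2).is_shard(dim=2) and not Shard(2).is_replicate()
    assert Replicate().is_replicate() and _Partial().is_partial()
    # __repr__ is what ends up in the renamed _Partial tensor names:
    assert f"y:{_Partial()}" == "y:_Partial(reduce_op=sum)"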
diff --git a/include/core/common.h b/include/core/common.h
index caa61084..749caff2 100644
--- a/include/core/common.h
+++ b/include/core/common.h
@@ -40,12 +40,12 @@ using HashType = uint64_t; // compatible with std::hash
 
 // Assert: conditions should have no side effect
 #define _IT_ASSERT_2(condition, info) \
-    (static_cast<bool>(condition) \
-         ? void(0) \
-         : throw ::infini::Exception( \
-               std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
-               "] Assertion failed (" + #condition + "): " + info))
-#define _IT_ASSERT_1(condition) _IT_ASSERT_2(condition, "");
+    static_cast<bool>(condition) \
+        ? void(0) \
+        : throw ::infini::Exception( \
+              std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
+              "] Assertion failed (" + #condition + "): " + info)
+#define _IT_ASSERT_1(condition) _IT_ASSERT_2(condition, "")
 #define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)
 
 #define IT_TODO_HALT() _IT_ASSERT_2(false, "Unimplemented")

diff --git a/include/cuda/cuda_common.h b/include/cuda/cuda_common.h
index dec9a40b..4eb75f27 100644
--- a/include/cuda/cuda_common.h
+++ b/include/cuda/cuda_common.h
@@ -6,16 +6,11 @@
 #include
 #include
 
-// TODO: replace with Exception (IT_ASSERT)
 #define checkCudaError(call) \
-    { \
-        auto err = call; \
-        if (cudaSuccess != err) { \
-            fprintf(stderr, "Cuda error in %s:%i : %s.\n", __FILE__, __LINE__, \
-                    cudaGetErrorString(err)); \
-            exit(EXIT_FAILURE); \
-        } \
-    }
+    if (auto err = call; err != cudaSuccess) \
+    throw ::infini::Exception(std::string("[") + __FILE__ + ":" + \
+                              std::to_string(__LINE__) + "] CUDA error (" + \
+                              #call + "): " + cudaGetErrorString(err))
 
 #define checkCUresult(call) \
     { \
@@ -39,14 +34,10 @@
     }
 
 #define checkCudnnError(call) \
-    { \
-        auto err = call; \
-        if (CUDNN_STATUS_SUCCESS != err) { \
-            fprintf(stderr, "cuDNN error in %s:%i : %s.\n", __FILE__, \
-                    __LINE__, cudnnGetErrorString(err)); \
-            exit(EXIT_FAILURE); \
-        } \
-    }
+    if (auto err = call; err != CUDNN_STATUS_SUCCESS) \
+    throw ::infini::Exception(std::string("[") + __FILE__ + ":" + \
+                              std::to_string(__LINE__) + "] cuDNN error (" + \
+                              #call + "): " + cudnnGetErrorString(err))
 
 #define checkCurandError(call) \
     { \

diff --git a/include/utils/exception.h b/include/utils/exception.h
index f2a7dc94..d7bb4331 100644
--- a/include/utils/exception.h
+++ b/include/utils/exception.h
@@ -5,8 +5,18 @@
 namespace infini {
 
 class Exception : public std::runtime_error {
+  protected:
+    std::string info;
+
   public:
     Exception(const std::string &msg);
+
+    Exception &operator<<(const std::string &str) {
+        info += str;
+        return *this;
+    }
+
+    const char *what() const noexcept override { return info.c_str(); }
 };
 
 } // namespace infini

diff --git a/include/utils/small_array.h b/include/utils/small_array.h
index d0e29a09..3ea93279 100644
--- a/include/utils/small_array.h
+++ b/include/utils/small_array.h
@@ -1,3 +1,4 @@
+#pragma once
 namespace infini {
 
 #define SMALL_ARRAY_SIZE 8

diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index b0a433ba..bade70f8 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -591,6 +591,13 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                     next((attr.i for attr in node.attribute if attr.name == "to")),
                 )
+            elif node.op_type == "ReduceSum":
+                # ReduceSum is only implemented as allReduceSum.
+                assert any(attr.name == "communicator" for attr in node.attribute)
+                tensors[node.output[0]] = self.handler.allReduceSum(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "AllReduceSum":
                 tensors[node.output[0]] = self.handler.allReduceSum(
                     tensors[node.input[0]],
@@ -631,13 +638,9 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     next(
-                        (
-                            attr.i
-                            for attr in node.attribute
-                            if attr.name == "root"
-                        ),
-                        0,
-                    ),
+                        (attr.i for attr in node.attribute if attr.name == "root"),
+                        0,
+                    ),
                 )
             elif node.op_type == "Expand":
                 shape = _parse_data(data[node.input[1]])

diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 884bd874..c5ff13ee 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -329,7 +329,7 @@ class TestStringMethods(unittest.TestCase):
                 [pads_data],
             )
         )
-    
+
     def test_allReduceSum(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
@@ -349,7 +349,7 @@ class TestStringMethods(unittest.TestCase):
         graph = make_graph([allReduceProd], "allReduceProd", [input], [output])
         model = make_model(graph)
         from_onnx(model, backend.cpu_runtime())
-    
+
     def test_allReduceMin(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
@@ -379,14 +379,12 @@ class TestStringMethods(unittest.TestCase):
         graph = make_graph([allReduceAvg], "allReduceAvg", [input], [output])
         model = make_model(graph)
         from_onnx(model, backend.cpu_runtime())
-    
+
     def test_split(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
-        split = make_node(
-            "Split", ["input"], ["output"], name="split", axis=0
-        )
+        split = make_node("Split", ["input"], ["output"], name="split", axis=0)
         make_and_import_model(make_graph([split], "split", [input], []))
-    
+
     def test_allBroadcast(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
@@ -461,7 +459,7 @@ class TestStringMethods(unittest.TestCase):
         make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
 
     def test_copyin(self):
-        dims = [2,3,5,4]
+        dims = [2, 3, 5, 4]
         np_array = np.random.random(dims).astype(np.float32)
         handler = backend.GraphHandler(backend.cpu_runtime())
         tensor1 = handler.tensor(dims, TensorProto.FLOAT)
@@ -487,7 +485,7 @@ class TestStringMethods(unittest.TestCase):
         self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))
 
     def test_to_numpy(self):
-        dims = [2,3,5,4]
+        dims = [2, 3, 5, 4]
         np_array = np.random.random(dims).astype(np.float32)
         handler = backend.GraphHandler(backend.cpu_runtime())
         tensor1 = handler.tensor(dims, TensorProto.FLOAT)
@@ -508,5 +506,6 @@ class TestStringMethods(unittest.TestCase):
         array1 = np.array(tensor1, copy=False)
         self.assertTrue(np.array_equal(array1, np_array))
 
+
 if __name__ == "__main__":
     unittest.main()
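On the wire, the AllReduce hack is nothing more than a ReduceSum node carrying a non-standard "communicator" attribute; the importer above keys on that attribute rather than on the op type alone. A sketch of the node parallel_opt.py emits (tensor and node names are illustrative):

    from onnx import helper

    node = helper.make_node(
        op_type="ReduceSum",
        inputs=["y:_Partial(reduce_op=sum)"],
        outputs=["y"],
        name="matmul0/all_reduce",
        noop_with_empty_axes=1,
        communicator=0,  # custom attribute; only this importer reads it
    )
    assert any(attr.name == "communicator" for attr in node.attribute)

Because no axes input is given and noop_with_empty_axes is 1, a standard ONNX runtime would treat this node as an identity, while this backend lowers it to an NCCL all-reduce.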
diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc
index 927b1f0d..0676646a 100644
--- a/src/cuda/cuda_runtime.cc
+++ b/src/cuda/cuda_runtime.cc
@@ -8,7 +8,6 @@
 #include "operators/conv.h"
 #include "operators/matmul.h"
 
-#ifdef DEBUG_MODE
 void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
     cudaError_t kernelError = cudaGetLastError();
     if (kernelError != cudaSuccess) {
@@ -18,7 +17,6 @@ void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
         exit(EXIT_FAILURE);
     }
 }
-#endif
 
 namespace infini {
 
@@ -38,10 +36,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
         } else {
             kernel->compute(op, this);
         }
-
-#ifdef DEBUG_MODE
-        CHECK_CUDA_KERNEL_ERROR(op);
-#endif
+        checkCudaError(cudaGetLastError()) << op->toString();
     }
 }
 
@@ -78,9 +73,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
             opCnt[op->getOpType()]++;
         }
-
-#ifdef DEBUG_MODE
-        CHECK_CUDA_KERNEL_ERROR(op);
-#endif
+        checkCudaError(cudaGetLastError()) << op->toString();
     }
 }
 
@@ -103,6 +96,7 @@ void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) {
     IT_ASSERT(worldSize > 0);
     IT_ASSERT(rank >= 0);
     IT_ASSERT(rank < worldSize);
+    IT_ASSERT(!comm) << "communicator is already initialized.";
 #ifdef INFINI_USE_NCCL
     comm = std::make_unique(name, worldSize, rank);
 #else

diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index f6af18ec..5cc9717f 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -421,6 +421,8 @@ void init_graph_builder(py::module &m) {
         .def("mul", &Handler::mul, policy::move)
         .def("div", &Handler::div, policy::move)
         .def("pow", &Handler::pow, policy::move)
+        .def("min", &Handler::min, policy::move)
+        .def("max", &Handler::max, policy::move)
         .def("relu", &Handler::relu, policy::move)
         .def("sigmoid", &Handler::sigmoid, policy::move)
         .def("tanh", &Handler::tanh, policy::move)

diff --git a/src/kernels/cuda/all_reduce.cc b/src/kernels/cuda/all_reduce.cc
index 2728b5e2..ef60b991 100644
--- a/src/kernels/cuda/all_reduce.cc
+++ b/src/kernels/cuda/all_reduce.cc
@@ -14,7 +14,7 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
         void *input = op->getInputs(0)->getRawDataPtr<void *>();
         void *output = op->getOutput()->getRawDataPtr<void *>();
         IT_ASSERT(op->getDType() == DataType::Float32);
-        size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize();
+        size_t count = op->getInputs(0)->size();
 
         ncclComm_t comm =
             dynamic_cast(context->getCommunicator())
diff --git a/src/kernels/cuda/matmul.cc b/src/kernels/cuda/matmul.cc
index a2b55e04..9cd4b0b3 100644
--- a/src/kernels/cuda/matmul.cc
+++ b/src/kernels/cuda/matmul.cc
@@ -1,6 +1,8 @@
 #include "operators/matmul.h"
 #include "core/kernel.h"
+#include "cuda/cuda_expand.h"
 #include "cuda/cuda_runtime.h"
+#include "utils/small_array.h"
 
 namespace infini {
 
@@ -46,7 +48,30 @@ class matmulCublas : public Kernel {
         auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
         const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
                   ldc = n;
-        const float alpha = 1.f, beta = 0.f;
+        float alpha = 1.f, beta = 0.f;
+        if (op->numInputs() == 2) { // no bias
+            beta = 0.f;
+        } else { // broadcast bias to output
+            beta = 1.f;
+            auto inC = op->getInputs(2);
+            auto out = op->getOutput();
+            SmallArray inputShape, outputShape;
+            int nDims = out->getRank();
+            IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+            int outputsize = 1; // the length of the output vector after flattening
+            int offset = nDims - inC->getRank();
+            for (int i = 0; i < offset; ++i)
+                inputShape.data[i] = 1;
+            for (int i = 0; i < nDims; ++i) {
+                outputShape.data[i] = out->getDims()[i];
+                outputsize *= outputShape.data[i];
+                if (i >= offset)
+                    inputShape.data[i] = inC->getDims()[i - offset];
+            }
+            expandKernel(inC->getRawDataPtr(),
+                         out->getRawDataPtr(), nDims, outputsize,
+                         inputShape, outputShape);
+        }
         // TODO: use compute type
         cublasStatus_t stat;
         if (b > 1) {

diff --git a/src/operators/gather.cc b/src/operators/gather.cc
index f615faf7..96493323 100644
--- a/src/operators/gather.cc
+++ b/src/operators/gather.cc
@@ -6,7 +6,7 @@ GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices,
                      Tensor output, int axis)
     : OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) {
     int rank = input->getRank();
-    axis = get_real_axis(axis, rank);
+    this->axis = get_real_axis(axis, rank);
     IT_ASSERT(checkValid(graph));
 }
 
@@ -25,7 +25,7 @@ optional> GatherObj::inferShape(const TensorVec &inputs) const {
 vector GatherObj::inferDataType(const TensorVec &inputs) const {
     IT_ASSERT(inputs.size() == 2);
     auto index_dtype = inputs[1]->getDType();
-    IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64)
+    IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64);
     return {inputs[0]->getDType()};
 }

diff --git a/src/utils/exception.cc b/src/utils/exception.cc
index 2121b3a7..a9dd87fb 100644
--- a/src/utils/exception.cc
+++ b/src/utils/exception.cc
@@ -9,7 +9,8 @@ namespace backward_trace = backward;
 backward_trace::SignalHandling sh;
 
 namespace infini {
-Exception::Exception(const std::string &msg) : std::runtime_error(msg) {
+Exception::Exception(const std::string &msg)
+    : std::runtime_error(msg), info(msg) {
     backward_trace::StackTrace st;
     st.load_here(32);
    backward_trace::Printer p;
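One subtlety in shard_tensor worth calling out: the groups argument exists for fused weights such as GPT-2's QKV projection, a Gemm whose output immediately feeds a Split. Slicing the fused weight naively would hand rank 0 nothing but Q columns; reshaping by group first keeps a slice of every group on every rank. A numpy re-derivation with illustrative shapes (not the patch's code):

    import numpy as np

    tp, groups, in_dim, out_dim = 2, 3, 8, 12
    seg = out_dim // groups // tp                  # columns kept per group
    w = np.arange(in_dim * out_dim).reshape(in_dim, out_dim)  # fused [Q|K|V]

    def shard(rank):
        g = w.reshape(in_dim, groups, out_dim // groups)  # split columns by group
        return g[:, :, rank * seg : (rank + 1) * seg].reshape(
            in_dim, out_dim // tp
        )

    for rank in range(tp):
        # After the downstream Split, rank r's Q block matches the matching
        # column slice of the unsharded Q.
        q_rank = np.split(shard(rank), groups, axis=1)[0]
        q_full = w[:, : out_dim // groups]
        assert np.array_equal(q_rank, q_full[:, rank * seg : (rank + 1) * seg])

This mirrors the main loop above choosing groups = len(split_node.output) via find_successor, so the per-rank output still splits cleanly into per-head Q, K, and V shards.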