From feccd4f318d58d019508c60026de6f1fbee91bda Mon Sep 17 00:00:00 2001
From: constroy Li
Date: Mon, 30 Oct 2023 15:04:16 +0800
Subject: [PATCH] fix tensor parallel for llama (#159)

* fix Slice

* change default rounds of timeit to 10 to reduce time

* fix slice with large ends

* Reshape support Int64

* support position_ids as input

* skip last MatMul in Llama

* skip infer_shapes to parse large model

* update launch.py

* fix split_concat_kernel

* print more message in launch.py

* Reshape supports both Int32 and Int64

* try infer_shapes and warn about failure

* fix format

---------

Co-authored-by: whjthu
---
 examples/distributed/launch.py            | 43 +++++++++++++----------
 examples/distributed/parallel_opt.py      | 10 ++++--
 include/core/common.h                     |  4 +--
 pyinfinitensor/src/pyinfinitensor/onnx.py | 19 +++++++---
 src/kernels/cuda/matmul.cc                |  1 +
 src/kernels/cuda/reshape.cc               |  4 +++
 src/kernels/cuda/split_concat.cu          |  8 +++--
 src/operators/concat.cc                   |  6 ++--
 src/operators/slice.cc                    | 15 +++++---
 9 files changed, 70 insertions(+), 40 deletions(-)

diff --git a/examples/distributed/launch.py b/examples/distributed/launch.py
index 64930e6e..58f7efb3 100644
--- a/examples/distributed/launch.py
+++ b/examples/distributed/launch.py
@@ -5,6 +5,7 @@ import multiprocessing as mp
 from pyinfinitensor.onnx import OnnxStub, backend
 import onnx
 from onnx.external_data_helper import convert_model_to_external_data
+from onnx.shape_inference import infer_shapes_path
 import numpy as np
 from parallel_opt import parallel_model
 
@@ -44,16 +45,18 @@ def parse_args():
     )
 
 
-def run_model(model, runtime, inputs: np.array, n=20):
+def run_model(model, runtime, inputs, n=10):
     stub = OnnxStub(model, runtime)
-    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
-    stub.tune()
+    for tensor, input in zip(stub.inputs.values(), inputs):
+        tensor.copyin_numpy(input)
+    # stub.tune()
     stub.run()
     # get outputs
-    outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
+    outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
 
     # bench
-    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
+    for tensor, input in zip(stub.inputs.values(), inputs):
+        tensor.copyin_numpy(input)
     begin = time.time()
     for _ in range(n):
         stub.run()
@@ -64,13 +67,12 @@ def run_model(model, runtime, inputs: np.array, n=20):
 
 
 def run_and_compare(name, model, runtime):
-    data = np.load(f"{name}_inputs.npy")
+    input_ids = np.load(f"{name}_inputs.npy")
+    position_ids = np.arange(input_ids.shape[-1])
     results = np.load(f"{name}_results.npy")
-    outputs = run_model(model, runtime, data)
-    print("outputs sum:", outputs.sum())
-    print("max abs diff:", abs(outputs - results).max())
-    print("max rel diff:", abs((outputs - results) / results).max())
-    # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)
+    outputs = run_model(model, runtime, (input_ids, position_ids))
+    print("outputs abs mean:", abs(outputs).mean())
+    np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3)
 
 
 def start_worker(
@@ -81,14 +83,13 @@ def start_worker(
     extern_path = f"./{dist_name}_rank{rank}.pb"
     if os.path.exists(extern_path):
         os.remove(extern_path)
-    convert_model_to_external_data(
+    onnx.save_model(
         model,
-        all_tensors_to_one_file=True,
+        f"./{dist_name}_rank{rank}.onnx",
+        save_as_external_data=True,
         location=extern_path,
-        size_threshold=1024,
-        convert_attribute=False,
     )
-    onnx.save(model, f"./{dist_name}_rank{rank}.onnx")
+    infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
     runtime = backend.CudaRuntime(local_rank)
     # print("init comm")
     runtime.init_comm(
@@ -106,10 +107,12 @@ def start_single(name, model):
 
 def gen_standard(name, model, voc_size, bs, len):
     # generate standard results
-    data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
-    np.save(f"{name}_inputs", data)
+    input_ids = np.random.randint(0, voc_size, (bs, len))
+    position_ids = np.arange(len)
+    np.save(f"{name}_inputs", input_ids)
     runtime = backend.CudaRuntime(0)
-    outputs = run_model(model, runtime, data, 1)
+    outputs = run_model(model, runtime, (input_ids, position_ids), 1)
+    print("outputs abs mean:", abs(outputs).mean())
     np.save(f"{name}_results", outputs)
 
 
@@ -128,12 +131,14 @@ def main():
 
     # run single process.
     # use standalone process to isolate cuda.
+    print("run model by single GPU.")
    p = mp.Process(target=start_single, args=(name, model))
     p.start()
     p.join()
 
     # run distributed parallel.
     world_size = nnodes * nproc_per_node
+    print(f"run model by {world_size} GPU in parallel.")
     workers = [
         mp.Process(
             target=start_worker,
diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py
index 42465a69..3ddf2ead 100644
--- a/examples/distributed/parallel_opt.py
+++ b/examples/distributed/parallel_opt.py
@@ -11,6 +11,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
     vinfo = {info.name: info for info in model.graph.value_info}
     vinfo.update({info.name: info for info in model.graph.input})
     vinfo.update({info.name: info for info in model.graph.output})
+    output = {info.name: info for info in model.graph.output}
     place: Dict[str, Placement] = {}
     nodes: List[NodeProto] = []
 
@@ -56,7 +57,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             ndim = len(vinfo[output].type.tensor_type.shape.dim)
             out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
         place[node.output[0]] = out_plc
-    
+
     def shard_concat(node: NodeProto):
         # hack for kvcache
         in_plc = place[node.input[1]]
@@ -154,7 +155,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             ), f"{place[node.input[0]]} != {place[node.input[1]]}"
             place[node.output[0]] = place[node.input[0]]
         elif node.op_type == "Concat":
-            shard_concat(node)
+            shard_concat(node)
 
     def find_successor(op_type: str, idx: int, search_limit: int = 1):
         for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:
@@ -175,6 +176,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
             input in data for input in node.input
        ):
+            # FIXME(constroy): the last MatMul should not be sharded as TP.
+            if node.output[0] in output:
+                continue
             groups = 1
             # If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups
             split_node = find_successor("Split", index, search_limit=2)
@@ -218,7 +222,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
     new_input = []
     for info in model.graph.input:
         new_input.append(vinfo[info.name])
-    
+
     graph = helper.make_graph(
         nodes,
         model.graph.name + f"_{tp_rank}",
diff --git a/include/core/common.h b/include/core/common.h
index 749caff2..81e704f8 100644
--- a/include/core/common.h
+++ b/include/core/common.h
@@ -75,7 +75,7 @@ template <typename T> std::string vecToString(const std::vector<T> &vec) {
 
 double timeit(
     const std::function<void(void)> &func,
-    const std::function<void(void)> &sync = []() {}, int warmupRounds = 200,
-    int timingRounds = 200);
+    const std::function<void(void)> &sync = []() {}, int warmupRounds = 10,
+    int timingRounds = 10);
 
 } // namespace infini
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index d11fbb90..6d0da9f8 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -28,6 +28,7 @@ from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
 from functools import reduce
 from onnxsim import simplify
 import copy
+import warnings
 
 
 class OnnxStub:
@@ -48,7 +49,10 @@ class OnnxStub:
         self.inputs: Dict[str, backend.Tensor] = {}
         self.outputs: Dict[str, backend.Tensor] = {}
         self.initializer: Dict[int, TensorProto] = {}
-        model = infer_shapes(model)
+        try:
+            model = infer_shapes(model)
+        except:
+            warnings.warn("infer_shapes failed.")
         self.handler = backend.GraphHandler(runtime)
         tensors: Dict[str, backend.Tensor] = dict()
 
@@ -603,15 +607,20 @@ class OnnxStub:
                     != 0,
                 )
             elif node.op_type == "Slice":
+
+                def clamp(nums):
+                    MAX_INT = 0x7FFFFFFF
+                    return [min(x, MAX_INT) for x in nums]
+
                 tensors[node.output[0]] = self.handler.slice(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
-                    _parse_data(data[node.input[1]]),
-                    _parse_data(data[node.input[2]]),
-                    _parse_data(data[node.input[3]])
+                    clamp(_parse_data(data[node.input[1]])),
+                    clamp(_parse_data(data[node.input[2]])),
+                    clamp(_parse_data(data[node.input[3]]))
                     if len(node.input) > 3
                     else None,
-                    _parse_data(data[node.input[4]])
+                    clamp(_parse_data(data[node.input[4]]))
                     if len(node.input) > 4
                     else None,
                 )
diff --git a/src/kernels/cuda/matmul.cc b/src/kernels/cuda/matmul.cc
index 9cd4b0b3..2d457cbc 100644
--- a/src/kernels/cuda/matmul.cc
+++ b/src/kernels/cuda/matmul.cc
@@ -58,6 +58,7 @@ class matmulCublas : public Kernel {
         SmallArray inputShape, outputShape;
         int nDims = out->getRank();
         IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+        // FIXME(constroy): use size_t for outputsize.
         int outputsize = 1; // the length of the output vector after flatten
         int offset = nDims - inC->getRank();
         for (int i = 0; i < offset; ++i)
diff --git a/src/kernels/cuda/reshape.cc b/src/kernels/cuda/reshape.cc
index 77070c23..7be6aca8 100644
--- a/src/kernels/cuda/reshape.cc
+++ b/src/kernels/cuda/reshape.cc
@@ -13,6 +13,10 @@ class CopyCuda : public CudaKernelWithoutConfig {
 // reshape/flatten/identity all act as copying from input to output.
 REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Float32, CopyCuda,
                 "Reshape_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int64, CopyCuda,
+                "Reshape_CUDA_Int64");
+REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
+                "Reshape_CUDA_Int32");
 REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
                 "Flatten_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,
diff --git a/src/kernels/cuda/split_concat.cu b/src/kernels/cuda/split_concat.cu
index 73f29482..193501e0 100644
--- a/src/kernels/cuda/split_concat.cu
+++ b/src/kernels/cuda/split_concat.cu
@@ -51,13 +51,15 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
 
 namespace infini {
 
-// TODO: when dim=0, the operation can be executed in-place 
+// TODO: when dim=0, the operation can be executed in-place
 void split_concat_kernel(const ElementTensorMetadata &eleMeta,
                          const ComposedTensorMetadata &compMeta, int dim,
                          int batchSize, int nDims, bool isSplit) {
     dim3 blockSize = dim3(32 * 16);
-    // gridsize =n_elements / blockSize
-    int gridDimX = (eleMeta.nElements[0] - 1) / (32 * 16) + 1;
+    // gridsize = max_n_elements / blockSize
+    int max_n_elements =
+        *std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize);
+    int gridDimX = (max_n_elements - 1) / (32 * 16) + 1;
     // each y is a split among the batch
     dim3 gridSize(gridDimX, batchSize);
 
diff --git a/src/operators/concat.cc b/src/operators/concat.cc
index de836d58..95535233 100644
--- a/src/operators/concat.cc
+++ b/src/operators/concat.cc
@@ -2,10 +2,10 @@
 #include "utils/operator_utils.h"
 
 namespace infini {
-ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
-    : OperatorObj(OpType::Concat, inputs, {output}), dim(dim) {
+ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
+    : OperatorObj(OpType::Concat, inputs, {output}) {
     int rank = inputs[0]->getRank();
-    dim = get_real_axis(dim, rank);
+    dim = get_real_axis(_dim, rank);
     IT_ASSERT(checkValid(graph));
 }
 
diff --git a/src/operators/slice.cc b/src/operators/slice.cc
index 1ded2745..0db3b1a2 100644
--- a/src/operators/slice.cc
+++ b/src/operators/slice.cc
@@ -43,17 +43,22 @@ SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
 
     auto size = shape.size();
     this->axes.reserve(size);
-    for (size_t i = 0; i < size; ++i)
+    for (size_t i = 0; i < size; ++i) {
+        auto len = shape[i];
         if (auto _i = axes.find(i); _i != axes.end()) {
             auto __i = _i->second;
             auto start = starts[__i];
             auto end = ends[__i];
-            this->axes.push_back({start >= 0 ? start : start + shape[__i],
-                                  end >= 0 ? end : end + shape[__i],
-                                  steps[__i]});
+            if (start > len)
+                start = len;
+            if (end > len)
+                end = len;
+            this->axes.push_back({start >= 0 ? start : start + len,
+                                  end >= 0 ? end : end + len, steps[__i]});
         } else {
-            this->axes.push_back({0, shape[i], 1});
+            this->axes.push_back({0, len, 1});
         }
+    }
 
     IT_ASSERT(checkValid(graph));
 }
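
Illustrative note (not part of the patch): the onnx.py and slice.cc hunks above both guard
against ONNX Slice starts/ends that exporters encode as huge sentinels (e.g. INT64_MAX,
meaning "to the end of the axis"), which would overflow the backend's 32-bit indices.
Below is a minimal, self-contained Python sketch of the same clamping idea; clamp mirrors
the helper added in pyinfinitensor/src/pyinfinitensor/onnx.py, while toy_slice and
MAX_INT32 are hypothetical names used only for this example:

    import numpy as np

    MAX_INT32 = 0x7FFFFFFF  # largest value a 32-bit signed index can hold


    def clamp(nums):
        # Huge INT64 sentinels ("slice to the end") would overflow a 32-bit
        # backend index, so cap every value at INT32_MAX first.
        return [min(int(x), MAX_INT32) for x in nums]


    def toy_slice(array, starts, ends, axes):
        # Stand-in for a backend slice: additionally clip each start/end to
        # the axis length, as SliceObj now does for start > len or end > len.
        index = [slice(None)] * array.ndim
        for start, end, axis in zip(starts, ends, axes):
            length = array.shape[axis]
            index[axis] = slice(min(start, length), min(end, length))
        return array[tuple(index)]


    x = np.arange(12).reshape(3, 4)
    # An exporter may emit ends = 2**63 - 1 to mean "until the end of axis 1".
    print(toy_slice(x, clamp([0, 1]), clamp([2, 2**63 - 1]), [0, 1]))
    # -> [[1 2 3]
    #     [5 6 7]]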