From e66f1c042152eee27e37469b8c32140fef675e1b Mon Sep 17 00:00:00 2001
From: zhangyunze
Date: Thu, 14 Dec 2023 18:01:03 +0800
Subject: [PATCH] fix: fix dist code to support fp16

---
 examples/distributed/parallel_opt.py | 15 ++++++++++++---
 src/kernels/cuda/all_reduce.cc       | 11 ++++++++---
 src/kernels/cuda/element_wise.cc     |  2 +-
 3 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py
index 3ddf2ead..a11819cc 100644
--- a/examples/distributed/parallel_opt.py
+++ b/examples/distributed/parallel_opt.py
@@ -137,7 +137,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         place[node.output[0]] = Shard(list(perm).index(plc.dim))
 
     def shard_node(node: NodeProto):
-        if node.op_type in ["Relu", "Tanh", "Softmax"]:
+        if node.name == "/model/layers.1/Add_1":
+            print(place[node.input[0]].is_shard(), place[node.input[1]].is_shard())
+        if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]:
             place[node.output[0]] = place[node.input[0]]
         elif node.op_type in ["Where"]:
             place[node.output[0]] = place[node.input[1]]
@@ -152,7 +154,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         elif node.op_type == "MatMul":
             assert (
                 place[node.input[0]] == place[node.input[1]]
-            ), f"{place[node.input[0]]} != {place[node.input[1]]}"
+            ), f"{place[node.input[0]]} != {place[node.input[1]]}, {place[node.input[0]].is_shard()}, {place[node.input[1]].is_shard()}"
             place[node.output[0]] = place[node.input[0]]
         elif node.op_type == "Concat":
             shard_concat(node)
@@ -177,7 +179,14 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             input in data for input in node.input
         ):
             # FIXME(constroy): the last MatMul should not be sharded as TP.
-            if node.output[0] in output:
+            if (
+                node.output[0] in output
+                or (
+                    index + 1 < len(model.graph.node)
+                    and model.graph.node[index + 1].output[0]
+                )
+                in output
+            ):
                 continue
             groups = 1
             # If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups
diff --git a/src/kernels/cuda/all_reduce.cc b/src/kernels/cuda/all_reduce.cc
index 12e6557f..16cd6306 100644
--- a/src/kernels/cuda/all_reduce.cc
+++ b/src/kernels/cuda/all_reduce.cc
@@ -13,15 +13,20 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
         void *input = op->getInputs(0)->getRawDataPtr<void *>();
         void *output = op->getOutput()->getRawDataPtr<void *>();
-        IT_ASSERT(op->getDType() == DataType::Float32);
+        ncclDataType_t ncclType = ncclFloat;
+        if (op->getDType() == DataType::Float16) {
+            ncclType = ncclFloat16;
+        } else if (op->getDType() == DataType::Int8) {
+            ncclType = ncclInt8;
+        }
         size_t count = op->getInputs(0)->size();
 
         ncclComm_t comm =
             dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
                 .getNcclComm();
         // TODO: Using default stream 0 for now.
-        checkNcclError(ncclAllReduce(input, output, count, ncclFloat,
-                                     getRedOp(), comm, 0));
+        checkNcclError(
+            ncclAllReduce(input, output, count, ncclType, getRedOp(), comm, 0));
     }
 
     virtual ncclRedOp_t getRedOp() const = 0;
diff --git a/src/kernels/cuda/element_wise.cc b/src/kernels/cuda/element_wise.cc
index 18dd967c..4a16de29 100644
--- a/src/kernels/cuda/element_wise.cc
+++ b/src/kernels/cuda/element_wise.cc
@@ -65,7 +65,7 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
         cudnnOpTensorDescriptor_t opDesc;
         checkCudnnError(cudnnCreateOpTensorDescriptor(&opDesc));
         checkCudnnError(cudnnSetOpTensorDescriptor(
-            opDesc, getOpType(), cudnnDataType, CUDNN_NOT_PROPAGATE_NAN));
+            opDesc, getOpType(), CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));
         auto [aAlpha, bAlpha, beta] = getAlphBeta();
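
Note on the all_reduce.cc hunk: the patch picks the NCCL element type inline from the operator's dtype. The lines below are a minimal, self-contained sketch of that dispatch for reference only; "Dtype" and "toNcclType" are illustrative stand-ins, not names from this patch or from InfiniTensor.

    #include <nccl.h>

    // Illustrative stand-in for the framework's DataType enum.
    enum class Dtype { Float32, Float16, Int8 };

    // Map an element type to the NCCL enum, defaulting to float32
    // exactly as the pre-patch code hard-coded ncclFloat.
    inline ncclDataType_t toNcclType(Dtype dtype) {
        switch (dtype) {
        case Dtype::Float16:
            return ncclFloat16;
        case Dtype::Int8:
            return ncclInt8;
        default:
            return ncclFloat;
        }
    }

The element_wise.cc hunk passes CUDNN_DATA_FLOAT as the op-tensor compute type instead of the tensors' own cudnnDataType; per the cuDNN documentation, cudnnOpTensor expects a float compute type when the operand tensors are half or single precision, so keeping the descriptor at float is what allows the fp16 path to work.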