fix: fix dist code to support fp16

zhangyunze 2023-12-14 18:01:03 +08:00
parent ff98241db7
commit e66f1c0421
3 changed files with 21 additions and 7 deletions


@@ -137,7 +137,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
place[node.output[0]] = Shard(list(perm).index(plc.dim))
def shard_node(node: NodeProto):
if node.op_type in ["Relu", "Tanh", "Softmax"]:
if node.name == "/model/layers.1/Add_1":
print(place[node.input[0]].is_shard(), place[node.input[1]].is_shard())
if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]:
place[node.output[0]] = place[node.input[0]]
elif node.op_type in ["Where"]:
place[node.output[0]] = place[node.input[1]]
@@ -152,7 +154,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
elif node.op_type == "MatMul":
assert (
place[node.input[0]] == place[node.input[1]]
), f"{place[node.input[0]]} != {place[node.input[1]]}"
), f"{place[node.input[0]]} != {place[node.input[1]]}, {place[node.input[0]].is_shard()}, {place[node.input[1]].is_shard()}"
place[node.output[0]] = place[node.input[0]]
elif node.op_type == "Concat":
shard_concat(node)
@@ -177,7 +179,14 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
input in data for input in node.input
):
# FIXME(constroy): the last MatMul should not be sharded as TP.
if node.output[0] in output:
if (
node.output[0] in output
or (
index + 1 < len(model.graph.node)
and model.graph.node[index + 1].output[0]
)
in output
):
continue
groups = 1
# If the Gemm or Matmul is followed by a split, then the inputs are concatenated by groups


@@ -13,15 +13,20 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
void *input = op->getInputs(0)->getRawDataPtr<void *>();
void *output = op->getOutput()->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32);
ncclDataType_t ncclType = ncclFloat;
if (op->getDType() == DataType::Float16) {
ncclType = ncclFloat16;
} else if (op->getDType() == DataType::Int8) {
ncclType = ncclInt8;
}
size_t count = op->getInputs(0)->size();
ncclComm_t comm =
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
.getNcclComm();
// TODO: Using default stream 0 for now.
checkNcclError(ncclAllReduce(input, output, count, ncclFloat,
getRedOp(), comm, 0));
checkNcclError(
ncclAllReduce(input, output, count, ncclType, getRedOp(), comm, 0));
}
virtual ncclRedOp_t getRedOp() const = 0;
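
The change above replaces the hardcoded ncclFloat with an ncclDataType_t chosen from the operator's dtype. A minimal sketch of that choice as a standalone helper (the helper name is illustrative, not the repository's API; DataType refers to the framework enum already used in the diff):

// Sketch only: pick the NCCL element type from the tensor dtype instead of
// hardcoding fp32. Extend with further cases as more dtypes are needed.
static ncclDataType_t toNcclType(const DataType &dt) {
    if (dt == DataType::Float16)
        return ncclFloat16; // fp16 all-reduce
    if (dt == DataType::Int8)
        return ncclInt8;
    return ncclFloat;       // default: fp32
}

With such a helper the kernel body reduces to a single call, e.g. ncclAllReduce(input, output, count, toNcclType(op->getDType()), getRedOp(), comm, 0).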


@@ -65,7 +65,7 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
cudnnOpTensorDescriptor_t opDesc;
checkCudnnError(cudnnCreateOpTensorDescriptor(&opDesc));
checkCudnnError(cudnnSetOpTensorDescriptor(
opDesc, getOpType(), cudnnDataType, CUDNN_NOT_PROPAGATE_NAN));
opDesc, getOpType(), CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));
auto [aAlpha, bAlpha, beta] = getAlphBeta();
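
The swap from cudnnDataType to CUDNN_DATA_FLOAT follows cuDNN's rule for cudnnOpTensor: with float or half tensor descriptors, the op descriptor's compute type (opTensorCompType) must be CUDNN_DATA_FLOAT, so passing the tensor's own dtype fails for fp16 inputs. Below is a minimal sketch of an fp16 element-wise add under that rule; handle, aDesc/bDesc/cDesc and the data pointers are assumed to be set up elsewhere with CUDNN_DATA_HALF tensors, and the snippet is illustrative rather than the kernel's actual code.

// Sketch: half-precision tensors, float compute type for the op descriptor.
cudnnOpTensorDescriptor_t opDesc;
checkCudnnError(cudnnCreateOpTensorDescriptor(&opDesc));
checkCudnnError(cudnnSetOpTensorDescriptor(
    opDesc, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));
// With a float compute type, the scaling factors are passed as float.
float alpha1 = 1.f, alpha2 = 1.f, beta = 0.f;
checkCudnnError(cudnnOpTensor(handle, opDesc, &alpha1, aDesc, aData,
                              &alpha2, bDesc, bData, &beta, cDesc, cData));
checkCudnnError(cudnnDestroyOpTensorDescriptor(opDesc));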