forked from jiuyuan/InfiniTensor
fix: fix dist code to support fp16
This commit is contained in:
parent ff98241db7
commit e66f1c0421
@@ -137,7 +137,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         place[node.output[0]] = Shard(list(perm).index(plc.dim))
 
     def shard_node(node: NodeProto):
-        if node.op_type in ["Relu", "Tanh", "Softmax"]:
+        if node.name == "/model/layers.1/Add_1":
+            print(place[node.input[0]].is_shard(), place[node.input[1]].is_shard())
+        if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]:
             place[node.output[0]] = place[node.input[0]]
         elif node.op_type in ["Where"]:
             place[node.output[0]] = place[node.input[1]]
@@ -152,7 +154,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         elif node.op_type == "MatMul":
             assert (
                 place[node.input[0]] == place[node.input[1]]
-            ), f"{place[node.input[0]]} != {place[node.input[1]]}"
+            ), f"{place[node.input[0]]} != {place[node.input[1]]}, {place[node.input[0]].is_shard()}, {place[node.input[1]].is_shard()}"
             place[node.output[0]] = place[node.input[0]]
         elif node.op_type == "Concat":
             shard_concat(node)
@@ -177,7 +179,14 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
            input in data for input in node.input
        ):
            # FIXME(constroy): the last MatMul should not be sharded as TP.
-            if node.output[0] in output:
+            if (
+                node.output[0] in output
+                or (
+                    index + 1 < len(model.graph.node)
+                    and model.graph.node[index + 1].output[0]
+                )
+                in output
+            ):
                continue
            groups = 1
            # If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups
@@ -13,15 +13,20 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
         void *input = op->getInputs(0)->getRawDataPtr<void *>();
         void *output = op->getOutput()->getRawDataPtr<void *>();
-        IT_ASSERT(op->getDType() == DataType::Float32);
+        ncclDataType_t ncclType = ncclFloat;
+        if (op->getDType() == DataType::Float16) {
+            ncclType = ncclFloat16;
+        } else if (op->getDType() == DataType::Int8) {
+            ncclType = ncclInt8;
+        }
         size_t count = op->getInputs(0)->size();
 
         ncclComm_t comm =
             dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
                 .getNcclComm();
         // TODO: Using default stream 0 for now.
-        checkNcclError(ncclAllReduce(input, output, count, ncclFloat,
-                                     getRedOp(), comm, 0));
+        checkNcclError(
+            ncclAllReduce(input, output, count, ncclType, getRedOp(), comm, 0));
     }
 
     virtual ncclRedOp_t getRedOp() const = 0;
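Note on the hunk above: the commit drops the Float32-only assertion and selects an ncclDataType_t from the operator's dtype before calling ncclAllReduce. A minimal standalone sketch of that mapping follows; the DType enum and toNcclType helper are illustrative names, not part of InfiniTensor's API, while ncclFloat, ncclFloat16, and ncclInt8 are NCCL's own enumerators.

// Sketch only: maps an abstract dtype to NCCL's ncclDataType_t, mirroring the
// branch added in the kernel above. DType and toNcclType are hypothetical.
#include <nccl.h>
#include <stdexcept>

enum class DType { Float32, Float16, Int8 };

inline ncclDataType_t toNcclType(DType dt) {
    switch (dt) {
    case DType::Float32:
        return ncclFloat;   // 32-bit float, the previously hard-coded default
    case DType::Float16:
        return ncclFloat16; // half precision, needed for the fp16 path
    case DType::Int8:
        return ncclInt8;
    default:
        throw std::runtime_error("dtype not supported by this all-reduce path");
    }
}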
@@ -65,7 +65,7 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
         cudnnOpTensorDescriptor_t opDesc;
         checkCudnnError(cudnnCreateOpTensorDescriptor(&opDesc));
         checkCudnnError(cudnnSetOpTensorDescriptor(
-            opDesc, getOpType(), cudnnDataType, CUDNN_NOT_PROPAGATE_NAN));
+            opDesc, getOpType(), CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));
 
         auto [aAlpha, bAlpha, beta] = getAlphBeta();
 
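Note on the hunk above: the op-tensor compute type is pinned to CUDNN_DATA_FLOAT instead of reusing the tensor's own cudnnDataType; the cudnnOpTensor documentation calls for a float compute type when the tensor data type is half, which is what lets the element-wise kernels accept fp16 data. A self-contained sketch of the descriptor setup under that assumption (not code from this commit):

// Sketch only: create a cudnnOpTensor descriptor for an add, keeping the
// compute type CUDNN_DATA_FLOAT so it can also be used with CUDNN_DATA_HALF
// tensor descriptors, as in the change above.
#include <cudnn.h>
#include <cstdio>

int main() {
    cudnnHandle_t handle;
    if (cudnnCreate(&handle) != CUDNN_STATUS_SUCCESS)
        return 1;

    cudnnOpTensorDescriptor_t opDesc;
    cudnnCreateOpTensorDescriptor(&opDesc);
    // Compute type stays CUDNN_DATA_FLOAT even when the operand tensors are fp16.
    cudnnStatus_t st = cudnnSetOpTensorDescriptor(
        opDesc, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN);
    std::printf("set op descriptor: %s\n", cudnnGetErrorString(st));

    cudnnDestroyOpTensorDescriptor(opDesc);
    cudnnDestroy(handle);
    return 0;
}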