From e5ca66db66d09f2ef1e144c20b1223e77a8e472d Mon Sep 17 00:00:00 2001
From: kilinchange
Date: Wed, 27 Dec 2023 15:28:05 +0800
Subject: [PATCH] feat: support int8 llama

---
 examples/distributed/parallel_opt.py | 33 ++++++++++++++++++++++++---------
 src/kernels/cuda/matmul_integer.cu   |  4 ++--
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py
index 28e3cbca..e088039a 100644
--- a/examples/distributed/parallel_opt.py
+++ b/examples/distributed/parallel_opt.py
@@ -137,11 +137,19 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             place[node.output[0]] = Shard(list(perm).index(plc.dim))
 
     def shard_node(node: NodeProto):
-        if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]:
+        if node.op_type in [
+            "Relu",
+            "Tanh",
+            "Softmax",
+            "Cast",
+            "DequantizeLinear",
+            "DynamicQuantizeLinear",
+            "Sigmoid",
+        ]:
             place[node.output[0]] = place[node.input[0]]
         elif node.op_type in ["Where"]:
             place[node.output[0]] = place[node.input[1]]
-        if node.op_type in {"Add", "Mul", "Div", "Max"}:
+        if node.op_type in {"Add", "Mul", "Div", "Max", "Pow"}:
             shard_binary(node)
         elif node.op_type == "Reshape":
             shard_reshape(node)
@@ -149,7 +157,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             shard_transpose(node)
         elif node.op_type == "Split":
             shard_split(node)
-        elif node.op_type == "MatMul":
+        elif node.op_type in {"MatMul", "MatMulInteger"}:
             assert (
                 place[node.input[0]] == place[node.input[1]]
             ), f"{place[node.input[0]]} != {place[node.input[1]]}"
@@ -173,17 +181,22 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
     for index, node in enumerate(model.graph.node):
         nodes.append(node)
         # linear
-        if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
-            input in data for input in node.input
-        ):
+        if (
+            node.op_type == "MatMul"
+            or node.op_type == "Gemm"
+            or node.op_type == "MatMulInteger"
+        ) and any(input in data for input in node.input):
             # FIXME(constroy): the last MatMul should not be sharded as TP.
             if (
                 node.output[0] in output
                 or (
                     index + 1 < len(model.graph.node)
-                    and model.graph.node[index + 1].output[0]
+                    and model.graph.node[index + 1].output[0] in output
+                )
+                or (
+                    index + 2 < len(model.graph.node)
+                    and model.graph.node[index + 2].output[0] in output
                 )
-                in output
             ):
                 continue
             groups = 1
@@ -209,7 +222,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             )
             place[node.output[0]] = Replicate()
             node.output[0] = new_name
-            if len(node.input) > 2:  # split bias to add
+            if (
+                len(node.input) > 2 and node.op_type != "MatMulInteger"
+            ):  # split bias to add
                 prev = nodes[-1]
                 new_name = prev.output[0] + "_no_bias"
                 place[new_name] = place[node.output[0]]
diff --git a/src/kernels/cuda/matmul_integer.cu b/src/kernels/cuda/matmul_integer.cu
index 9e56e88f..c458a02e 100644
--- a/src/kernels/cuda/matmul_integer.cu
+++ b/src/kernels/cuda/matmul_integer.cu
@@ -29,7 +29,7 @@ __global__ void _subB_kernel(void *a, void *b, int size, int k, int n,
     int index = threadIdx.x + blockIdx.x * blockDim.x;
     int stride = blockDim.x * gridDim.x;
     for (int i = index; i < size; i += stride) {
-        int j = delta * (i / k) + (i % n);
+        int j = delta * ((i / k) + (i % n));
         ((int8_t *)a)[i] = ((int8_t *)a)[i] - ((int8_t *)b)[j];
     }
 }
@@ -39,7 +39,7 @@ __global__ void _subB_u8_kernel(void *a, void *b, int size, int k, int n,
     int index = threadIdx.x + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;
     for (int i = index; i < size; i += stride) {
-        int j = delta * (i / k) + (i % n);
+        int j = delta * ((i / k) + (i % n));
         auto aData = static_cast<int>(((uint8_t *)a)[i]);
         auto bData = static_cast<int>(((uint8_t *)b)[j]);
         ((int8_t *)a)[i] = static_cast<int8_t>(aData - bData);
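
A minimal usage sketch for the change above, assuming a quantized llama ONNX export: only the parallel_model(model, tp_world_size, tp_rank) signature comes from the hunk headers in this patch; the file names, the reload-per-rank loop, and the assumption that the function returns the per-rank ModelProto are illustrative.

import onnx

from parallel_opt import parallel_model  # examples/distributed/parallel_opt.py

TP_WORLD_SIZE = 2  # assumed number of tensor-parallel ranks

for tp_rank in range(TP_WORLD_SIZE):
    # Reload per rank: parallel_model rewrites node outputs in place
    # (e.g. node.output[0] = new_name in the diff above).
    model = onnx.load("llama_int8.onnx")  # hypothetical int8 llama export
    shard = parallel_model(model, tp_world_size=TP_WORLD_SIZE, tp_rank=tp_rank)
    onnx.save(shard, f"llama_int8.rank{tp_rank}.onnx")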