feat: support int8 llama

This commit is contained in:
kilinchange 2023-12-27 15:28:05 +08:00
parent 85de28ef1e
commit e5ca66db66
2 changed files with 26 additions and 11 deletions

View File

@ -137,11 +137,19 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
place[node.output[0]] = Shard(list(perm).index(plc.dim))
def shard_node(node: NodeProto):
if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]:
if node.op_type in [
"Relu",
"Tanh",
"Softmax",
"Cast",
"DequantizeLinear",
"DynamicQuantizeLinear",
"Sigmoid",
]:
place[node.output[0]] = place[node.input[0]]
elif node.op_type in ["Where"]:
place[node.output[0]] = place[node.input[1]]
if node.op_type in {"Add", "Mul", "Div", "Max"}:
if node.op_type in {"Add", "Mul", "Div", "Max", "Pow"}:
shard_binary(node)
elif node.op_type == "Reshape":
shard_reshape(node)
@ -149,7 +157,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
shard_transpose(node)
elif node.op_type == "Split":
shard_split(node)
elif node.op_type == "MatMul":
elif node.op_type in {"MatMul", "MatMulInteger"}:
assert (
place[node.input[0]] == place[node.input[1]]
), f"{place[node.input[0]]} != {place[node.input[1]]}"
@ -173,17 +181,22 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
for index, node in enumerate(model.graph.node):
nodes.append(node)
# linear
if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
input in data for input in node.input
):
if (
node.op_type == "MatMul"
or node.op_type == "Gemm"
or node.op_type == "MatMulInteger"
) and any(input in data for input in node.input):
# FIXME(constroy): the last MatMul should not be sharded as TP.
if (
node.output[0] in output
or (
index + 1 < len(model.graph.node)
and model.graph.node[index + 1].output[0]
and model.graph.node[index + 1].output[0] in output
)
or (
index + 2 < len(model.graph.node)
and model.graph.node[index + 2].output[0] in output
)
in output
):
continue
groups = 1
@ -209,7 +222,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
)
place[node.output[0]] = Replicate()
node.output[0] = new_name
if len(node.input) > 2: # split bias to add
if (
len(node.input) > 2 and node.op_type != "MatMulInteger"
): # split bias to add
prev = nodes[-1]
new_name = prev.output[0] + "_no_bias"
place[new_name] = place[node.output[0]]

View File

@ -29,7 +29,7 @@ __global__ void _subB_kernel(void *a, void *b, int size, int k, int n,
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < size; i += stride) {
int j = delta * (i / k) + (i % n);
int j = delta * ((i / k) + (i % n));
((int8_t *)a)[i] = ((int8_t *)a)[i] - ((int8_t *)b)[j];
}
}
@ -39,7 +39,7 @@ __global__ void _subB_u8_kernel(void *a, void *b, int size, int k, int n,
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < size; i += stride) {
int j = delta * (i / k) + (i % n);
int j = delta * ((i / k) + (i % n));
auto aData = static_cast<int16_t>(((uint8_t *)a)[i]);
auto bData = static_cast<int16_t>(((uint8_t *)b)[j]);
((int8_t *)a)[i] = static_cast<int8_t>(aData - bData);