forked from jiuyuan/InfiniTensor
feat: support int8 llama
This commit is contained in:
parent
85de28ef1e
commit
e5ca66db66
|
@ -137,11 +137,19 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
place[node.output[0]] = Shard(list(perm).index(plc.dim))
|
||||
|
||||
def shard_node(node: NodeProto):
|
||||
if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]:
|
||||
if node.op_type in [
|
||||
"Relu",
|
||||
"Tanh",
|
||||
"Softmax",
|
||||
"Cast",
|
||||
"DequantizeLinear",
|
||||
"DynamicQuantizeLinear",
|
||||
"Sigmoid",
|
||||
]:
|
||||
place[node.output[0]] = place[node.input[0]]
|
||||
elif node.op_type in ["Where"]:
|
||||
place[node.output[0]] = place[node.input[1]]
|
||||
if node.op_type in {"Add", "Mul", "Div", "Max"}:
|
||||
if node.op_type in {"Add", "Mul", "Div", "Max", "Pow"}:
|
||||
shard_binary(node)
|
||||
elif node.op_type == "Reshape":
|
||||
shard_reshape(node)
|
||||
|
@ -149,7 +157,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
shard_transpose(node)
|
||||
elif node.op_type == "Split":
|
||||
shard_split(node)
|
||||
elif node.op_type == "MatMul":
|
||||
elif node.op_type in {"MatMul", "MatMulInteger"}:
|
||||
assert (
|
||||
place[node.input[0]] == place[node.input[1]]
|
||||
), f"{place[node.input[0]]} != {place[node.input[1]]}"
|
||||
|
@ -173,17 +181,22 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
for index, node in enumerate(model.graph.node):
|
||||
nodes.append(node)
|
||||
# linear
|
||||
if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
|
||||
input in data for input in node.input
|
||||
):
|
||||
if (
|
||||
node.op_type == "MatMul"
|
||||
or node.op_type == "Gemm"
|
||||
or node.op_type == "MatMulInteger"
|
||||
) and any(input in data for input in node.input):
|
||||
# FIXME(constroy): the last MatMul should not be sharded as TP.
|
||||
if (
|
||||
node.output[0] in output
|
||||
or (
|
||||
index + 1 < len(model.graph.node)
|
||||
and model.graph.node[index + 1].output[0]
|
||||
and model.graph.node[index + 1].output[0] in output
|
||||
)
|
||||
or (
|
||||
index + 2 < len(model.graph.node)
|
||||
and model.graph.node[index + 2].output[0] in output
|
||||
)
|
||||
in output
|
||||
):
|
||||
continue
|
||||
groups = 1
|
||||
|
@ -209,7 +222,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
)
|
||||
place[node.output[0]] = Replicate()
|
||||
node.output[0] = new_name
|
||||
if len(node.input) > 2: # split bias to add
|
||||
if (
|
||||
len(node.input) > 2 and node.op_type != "MatMulInteger"
|
||||
): # split bias to add
|
||||
prev = nodes[-1]
|
||||
new_name = prev.output[0] + "_no_bias"
|
||||
place[new_name] = place[node.output[0]]
|
||||
|
|
|
@ -29,7 +29,7 @@ __global__ void _subB_kernel(void *a, void *b, int size, int k, int n,
|
|||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i = index; i < size; i += stride) {
|
||||
int j = delta * (i / k) + (i % n);
|
||||
int j = delta * ((i / k) + (i % n));
|
||||
((int8_t *)a)[i] = ((int8_t *)a)[i] - ((int8_t *)b)[j];
|
||||
}
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ __global__ void _subB_u8_kernel(void *a, void *b, int size, int k, int n,
|
|||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i = index; i < size; i += stride) {
|
||||
int j = delta * (i / k) + (i % n);
|
||||
int j = delta * ((i / k) + (i % n));
|
||||
auto aData = static_cast<int16_t>(((uint8_t *)a)[i]);
|
||||
auto bData = static_cast<int16_t>(((uint8_t *)b)[j]);
|
||||
((int8_t *)a)[i] = static_cast<int8_t>(aData - bData);
|
||||
|
|
Loading…
Reference in New Issue