forked from jiuyuan/InfiniTensor

feat: support int8 llama

parent 85de28ef1e
commit e5ca66db66
@@ -137,11 +137,19 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         place[node.output[0]] = Shard(list(perm).index(plc.dim))
 
     def shard_node(node: NodeProto):
-        if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]:
+        if node.op_type in [
+            "Relu",
+            "Tanh",
+            "Softmax",
+            "Cast",
+            "DequantizeLinear",
+            "DynamicQuantizeLinear",
+            "Sigmoid",
+        ]:
             place[node.output[0]] = place[node.input[0]]
         elif node.op_type in ["Where"]:
             place[node.output[0]] = place[node.input[1]]
-        if node.op_type in {"Add", "Mul", "Div", "Max"}:
+        if node.op_type in {"Add", "Mul", "Div", "Max", "Pow"}:
             shard_binary(node)
         elif node.op_type == "Reshape":
             shard_reshape(node)
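Note on the hunk above: every op added to this list is layout-preserving on its first input (DequantizeLinear and Sigmoid are elementwise; DynamicQuantizeLinear's first output keeps the input's shape), so the output can simply inherit the input's placement. A minimal sketch of that rule in isolation, using a plain dict in place of the file's placement machinery:

# Sketch only: placement propagation for layout-preserving ops.
PLACEMENT_PRESERVING = {
    "Relu", "Tanh", "Softmax", "Cast",
    "DequantizeLinear", "DynamicQuantizeLinear", "Sigmoid",
}

def propagate_placement(op_type: str, place: dict, inp: str, out: str) -> None:
    # place maps tensor name -> placement, e.g. "Shard(0)" or "Replicate";
    # a layout-preserving op never reshuffles data across ranks, so the
    # output keeps whatever placement the input already has.
    if op_type in PLACEMENT_PRESERVING:
        place[out] = place[inp]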
@@ -149,7 +157,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             shard_transpose(node)
         elif node.op_type == "Split":
             shard_split(node)
-        elif node.op_type == "MatMul":
+        elif node.op_type in {"MatMul", "MatMulInteger"}:
             assert (
                 place[node.input[0]] == place[node.input[1]]
             ), f"{place[node.input[0]]} != {place[node.input[1]]}"
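For context, ONNX's MatMulInteger takes uint8/int8 operands A and B plus optional a_zero_point/b_zero_point inputs and produces an int32 result, so the same "both operands share a placement" assertion used for MatMul carries over unchanged. A small sketch of such a node built with onnx.helper (tensor names are illustrative):

import onnx.helper as oh

# (A: uint8/int8, B: uint8/int8, a_zero_point?, b_zero_point?) -> Y: int32
node = oh.make_node(
    "MatMulInteger",
    inputs=["A_q", "B_q", "A_zero_point", "B_zero_point"],
    outputs=["Y_int32"],
)
print(node.op_type, list(node.input))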
@@ -173,17 +181,22 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
     for index, node in enumerate(model.graph.node):
         nodes.append(node)
         # linear
-        if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
-            input in data for input in node.input
-        ):
+        if (
+            node.op_type == "MatMul"
+            or node.op_type == "Gemm"
+            or node.op_type == "MatMulInteger"
+        ) and any(input in data for input in node.input):
             # FIXME(constroy): the last MatMul should not be sharded as TP.
             if (
                 node.output[0] in output
                 or (
                     index + 1 < len(model.graph.node)
-                    and model.graph.node[index + 1].output[0]
-                    in output
+                    and model.graph.node[index + 1].output[0] in output
                 )
+                or (
+                    index + 2 < len(model.graph.node)
+                    and model.graph.node[index + 2].output[0] in output
+                )
             ):
                 continue
             groups = 1
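The new index + 2 clause widens the look-ahead guarded by the FIXME: in an int8 graph the final matmul's result typically passes through an extra node (e.g. a DequantizeLinear) before reaching the graph output, so checking only the next node would miss it. A hypothetical helper equivalent to the inline condition:

def reaches_output_within(graph_nodes, index, output_names, window=2):
    # True if this node's first output, or that of one of the next
    # `window` nodes, is a graph output -- the case that should stay
    # un-sharded per the FIXME above.
    return any(
        index + step < len(graph_nodes)
        and graph_nodes[index + step].output[0] in output_names
        for step in range(window + 1)
    )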
@@ -209,7 +222,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             )
             place[node.output[0]] = Replicate()
             node.output[0] = new_name
-            if len(node.input) > 2:  # split bias to add
+            if (
+                len(node.input) > 2 and node.op_type != "MatMulInteger"
+            ):  # split bias to add
                 prev = nodes[-1]
                 new_name = prev.output[0] + "_no_bias"
                 place[new_name] = place[node.output[0]]
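The added op_type check matters because the bias-splitting rewrite assumes input[2] is a bias that can be re-attached as an Add after the sharded matmul; for MatMulInteger, any inputs beyond the first two are zero points, not a bias, and must stay on the op itself. Restated as a hypothetical predicate (not in the file):

from onnx import NodeProto

def has_splittable_bias(node: NodeProto) -> bool:
    # MatMul/Gemm: a third input is a bias we may split out into an Add.
    # MatMulInteger: extra inputs are a_zero_point/b_zero_point, so skip.
    return len(node.input) > 2 and node.op_type != "MatMulInteger"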
@@ -29,7 +29,7 @@ __global__ void _subB_kernel(void *a, void *b, int size, int k, int n,
     int index = threadIdx.x + blockIdx.x * blockDim.x;
     int stride = blockDim.x * gridDim.x;
     for (int i = index; i < size; i += stride) {
-        int j = delta * (i / k) + (i % n);
+        int j = delta * ((i / k) + (i % n));
         ((int8_t *)a)[i] = ((int8_t *)a)[i] - ((int8_t *)b)[j];
     }
 }
@@ -39,7 +39,7 @@ __global__ void _subB_u8_kernel(void *a, void *b, int size, int k, int n,
     int index = threadIdx.x + blockIdx.x * blockDim.x;
     int stride = blockDim.x * gridDim.x;
     for (int i = index; i < size; i += stride) {
-        int j = delta * (i / k) + (i % n);
+        int j = delta * ((i / k) + (i % n));
         auto aData = static_cast<int16_t>(((uint8_t *)a)[i]);
         auto bData = static_cast<int16_t>(((uint8_t *)b)[j]);
         ((int8_t *)a)[i] = static_cast<int8_t>(aData - bData);
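Both _subB_kernel and _subB_u8_kernel receive the same fix: delta now scales the whole index rather than just the row term. The two forms agree only when delta == 1; with delta == 0 the corrected index is a constant 0, which is what broadcasting a scalar zero point would require, while the old form kept striding by i % n. A quick arithmetic check of the two formulas (pure illustration, sizes chosen arbitrarily):

# Compare the old and fixed index computations from the kernels.
size, k, n = 12, 4, 3
for delta in (0, 1):
    j_old = [delta * (i // k) + (i % n) for i in range(size)]
    j_new = [delta * ((i // k) + (i % n)) for i in range(size)]
    # delta == 0: j_new is all zeros (scalar broadcast), j_old cycles 0,1,2.
    print(delta, j_old == j_new)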