forked from jiuyuan/InfiniTensor
Compare commits
19 Commits
master
...
kvcache_ba
Author | SHA1 | Date |
---|---|---|
xiaonans | b0d030d0de | |
xiaonans | d000f9750c | |
xiaonans | 4a5b9572bb | |
xiaonans | 159642d6ae | |
xiaonans | c01e64db50 | |
xiaonans | eb3a2d123d | |
xiaonans | 4bdd33522b | |
xiaonans | 0740d26f43 | |
xiaonans | fc3d38f80e | |
xiaonans | d43364ac60 | |
xiaonans | db053e32a4 | |
xiaonans | 1e797d4ffe | |
xiaonans | 80412ae162 | |
xiaonans | 83be7fa373 | |
xiaonans | 0f1c04d864 | |
xiaonans | 936797b960 | |
xiaonans | 17bd98d453 | |
xiaonans | 8cc6af0a83 | |
xiaonans | c04910f118 |
|
@ -1 +1 @@
|
|||
Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77
|
||||
Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98
|
|
@ -0,0 +1,512 @@
|
|||
import os
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import onnx_graphsurgeon as gs
|
||||
import time
|
||||
import nvtx
|
||||
import argparse
|
||||
from mpi4py import MPI
|
||||
from pytrie import StringTrie
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from typing import (
|
||||
Dict,
|
||||
List,
|
||||
IO,
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
|
||||
parser.add_argument('--layer', dest='n_layers', type=int, default=48)
|
||||
parser.add_argument("--num_nodes", dest='num_nodes',
|
||||
type=int, default=1, help="number of nodes")
|
||||
parser.add_argument("--world_size", dest="world_size",
|
||||
type=int, default=1, help="")
|
||||
parser.add_argument("--nproc_per_node", dest="nproc_per_node",
|
||||
type=int, default=1, help="number of processes per node")
|
||||
parser.add_argument("--n_max_length", dest="n_max_length",
|
||||
type=int, default=1024, help="number of processes per node")
|
||||
parser.add_argument("--vocab_size", dest="vocab_size",
|
||||
type=int, default=119696, help="vocabulary size")
|
||||
parser.add_argument("--hidden_size", dest="hidden_size",
|
||||
type=int, default=4096, help="vocabulary size")
|
||||
parser.add_argument('--rank', dest='rank', type=int, default=0)
|
||||
parser.add_argument('--speedup', action='store_true')
|
||||
parser.add_argument('--no_cudagraph', action='store_true')
|
||||
parser.add_argument('--fp16', action='store_true')
|
||||
args = parser.parse_args()
|
||||
comm = MPI.COMM_WORLD
|
||||
args.rank = comm.Get_rank()
|
||||
args.nproc_per_node = comm.Get_size()
|
||||
args.world_size = args.num_nodes * args.nproc_per_node
|
||||
|
||||
ONNX_MODEL_PATH = "/data3/shared/xnsong/9G/dist/9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format(
|
||||
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
|
||||
|
||||
weight_path = "9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format(
|
||||
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
|
||||
|
||||
model_dir = "/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/cpm9g-11b-sft.pt"
|
||||
|
||||
@gs.Graph.register()
|
||||
def RMSNorm(self, a, b):
|
||||
return self.layer(op="RMSNorm", inputs=a, outputs=b)
|
||||
|
||||
@gs.Graph.register()
|
||||
def RoPE(self, a, b):
|
||||
return self.layer(op="RoPE", inputs=a, outputs=b)
|
||||
|
||||
@gs.Graph.register()
|
||||
def AttentionKVCache(self, a, b):
|
||||
return self.layer(op="AttentionKVCache", inputs=a, outputs=b)
|
||||
|
||||
def to_numpy(dict):
|
||||
ret = dict
|
||||
if args.fp16:
|
||||
ret = np.float16(ret)
|
||||
else:
|
||||
ret = np.float32(ret)
|
||||
return ret
|
||||
|
||||
def parallel(array, split='replicate'):
|
||||
if args.world_size > 1 and split == 'partial_column':
|
||||
return np.hsplit(array, args.world_size)[args.rank]
|
||||
elif args.world_size > 1 and split == 'partial_row':
|
||||
return np.vsplit(array, args.world_size)[args.rank]
|
||||
return array
|
||||
|
||||
|
||||
def generate_onnx(ONNX_MODEL_PATH):
|
||||
state_dict = torch.load(f'{model_dir}', map_location='cpu')
|
||||
new_state_dict = {name: param.cpu().numpy()
|
||||
for name, param in state_dict.items()
|
||||
}
|
||||
|
||||
operators = []
|
||||
graph = gs.Graph(nodes=operators)
|
||||
gather_input = gs.Variable(name="gather_input.0", dtype=np.int64, shape=(1,1))
|
||||
pos_input = gs.Variable(name="pos_input.0", dtype=np.int64, shape=(1,1))
|
||||
|
||||
embedding_weight = gs.Constant(name="embedding.weight", values=to_numpy(new_state_dict["input_embedding.weight"]))
|
||||
gather_output = gs.Variable(name="gather_output.0")
|
||||
gather = gs.Node(op="Gather", inputs=[embedding_weight, gather_input], outputs=[gather_output])
|
||||
operators.append(gather)
|
||||
input = gather_output
|
||||
|
||||
graph.inputs=[gather_input, pos_input]
|
||||
graph.outputs=[]
|
||||
|
||||
for i in tqdm(range(args.n_layers)):
|
||||
# global input
|
||||
attn_kcache_input = gs.Variable(name="/layers." + str(i) + "/attn/kcache_input", dtype=np.float32, shape=(1,32,1023,128))
|
||||
attn_vcache_input = gs.Variable(name="/layers." + str(i) + "/attn/vcache_input", dtype=np.float32, shape=(1,32,1023,128))
|
||||
graph.inputs.append(attn_kcache_input)
|
||||
graph.inputs.append(attn_vcache_input)
|
||||
|
||||
# weight
|
||||
layernorm_0_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.0/mul_weight",
|
||||
values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".self_att.layernorm_before_attention.weight"]))
|
||||
attn_qproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/qproj_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_q.weight"]))
|
||||
, 'partial_column'))
|
||||
attn_kproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/kproj_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_k.weight"]))
|
||||
, 'partial_column'))
|
||||
attn_vproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/vproj_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_v.weight"]))
|
||||
, 'partial_column'))
|
||||
attn_outmatmul_input = gs.Constant(name="/layers." + str(i) + "/attn/outmatmul_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.attention_out.weight"]))
|
||||
, 'partial_row'))
|
||||
|
||||
layernorm_1_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.1/mul_weight",
|
||||
values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".ffn.layernorm_before_ffn.weight"]))
|
||||
ffn_matmul_0_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_0_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_0.weight"]))
|
||||
, 'partial_column'))
|
||||
ffn_matmul_1_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_1_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_1.weight"]))
|
||||
, 'partial_column'))
|
||||
ffn_matmul_out_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_out_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_out.weight"]))
|
||||
, 'partial_row'))
|
||||
|
||||
attn_qrope_output = gs.Variable(name="/layers." + str(i) + "/attn/qrope_output")
|
||||
attn_krope_output = gs.Variable(name="/layers." + str(i) + "/attn/krope_output")
|
||||
attn_kvcache_output = gs.Variable(name="/layers." + str(i) + "/attn/kvcache_output")
|
||||
layernorm_0_mul_output_1 = gs.Variable(name="/layers." + str(i) + "/layernorm.0/mul_output_1")
|
||||
layernorm_1_mul_output_1 = gs.Variable(name="/layers." + str(i) + "/layernorm.1/mul_output_1")
|
||||
attn_qproj_output = gs.Variable(name="/layers." + str(i) + "/attn/qproj_output")
|
||||
attn_kproj_output = gs.Variable(name="/layers." + str(i) + "/attn/kproj_output")
|
||||
attn_vproj_output = gs.Variable(name="/layers." + str(i) + "/attn/vproj_output")
|
||||
attn_outmatmul_output = gs.Variable(name="/layers." + str(i) + "/attn/outmatmul_output")
|
||||
attn_outadd_output = gs.Variable(name="/layers." + str(i) + "/attn/outadd_output")
|
||||
ffn_matmul_0_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_0_output")
|
||||
ffn_silu_output = gs.Variable(name="/layers." + str(i) + "/ffn/silu_output")
|
||||
ffn_matmul_1_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_1_output")
|
||||
ffn_mul_output = gs.Variable(name="/layers." + str(i) + "/ffn/mul_output")
|
||||
ffn_matmul_out_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_out_output")
|
||||
ffn_add_output = gs.Variable(name="/layers." + str(i) + "/ffn/add_output")
|
||||
|
||||
graph.RMSNorm([input, layernorm_0_mul_weight], [layernorm_0_mul_output_1])
|
||||
attn_qproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_qproj_weight], outputs=[attn_qproj_output])
|
||||
operators.append(attn_qproj)
|
||||
attn_kproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_kproj_weight], outputs=[attn_kproj_output])
|
||||
operators.append(attn_kproj)
|
||||
attn_vproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_vproj_weight], outputs=[attn_vproj_output])
|
||||
operators.append(attn_vproj)
|
||||
graph.RoPE([pos_input, attn_qproj_output], [attn_qrope_output])
|
||||
graph.RoPE([pos_input, attn_kproj_output], [attn_krope_output])
|
||||
graph.AttentionKVCache([attn_kcache_input, attn_vcache_input, attn_qrope_output, attn_krope_output, attn_vproj_output, pos_input],[attn_kvcache_output])
|
||||
attn_outproj = gs.Node(op="MatMul", inputs=[attn_kvcache_output, attn_outmatmul_input], outputs=[attn_outmatmul_output])
|
||||
operators.append(attn_outproj)
|
||||
|
||||
attn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/attn/reducesum_output")
|
||||
if args.world_size > 1:
|
||||
reduce_sum = gs.Node(op="ReduceSum", name="/layers." + str(i) + "/attn/reducesum",
|
||||
inputs=[attn_outmatmul_output], outputs=[attn_reduce_sum_output],
|
||||
attrs={"noop_with_empty_axes":1, "communicator":0})
|
||||
graph.nodes.append(reduce_sum)
|
||||
|
||||
attn_outadd = gs.Node(op="Add", inputs=[input, attn_outmatmul_output if args.world_size == 1 else attn_reduce_sum_output], outputs=[attn_outadd_output])
|
||||
operators.append(attn_outadd)
|
||||
|
||||
graph.RMSNorm([attn_outadd_output, layernorm_1_mul_weight], [layernorm_1_mul_output_1])
|
||||
|
||||
ffn_matmul_0 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_0_input], outputs=[ffn_matmul_0_output])
|
||||
operators.append(ffn_matmul_0)
|
||||
ffn_silu = gs.Node(op="Silu", inputs=[ffn_matmul_0_output], outputs=[ffn_silu_output])
|
||||
operators.append(ffn_silu)
|
||||
ffn_matmul_1 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_1_input], outputs=[ffn_matmul_1_output])
|
||||
operators.append(ffn_matmul_1)
|
||||
ffn_mul = gs.Node(op="Mul", inputs=[ffn_silu_output, ffn_matmul_1_output], outputs=[ffn_mul_output])
|
||||
operators.append(ffn_mul)
|
||||
ffn_matmul_out = gs.Node(op="MatMul", inputs=[ffn_mul_output, ffn_matmul_out_input], outputs=[ffn_matmul_out_output])
|
||||
operators.append(ffn_matmul_out)
|
||||
|
||||
ffn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/ffn/reducesum_output")
|
||||
if args.world_size > 1:
|
||||
reduce_sum = gs.Node(op="ReduceSum", name="/layers." + str(i) + "/ffn/reducesum",
|
||||
inputs=[ffn_matmul_out_output], outputs=[ffn_reduce_sum_output],
|
||||
attrs={"noop_with_empty_axes":1, "communicator":0})
|
||||
graph.nodes.append(reduce_sum)
|
||||
|
||||
ffn_add = gs.Node(op="Add", inputs=[attn_outadd_output, ffn_matmul_out_output if args.world_size == 1 else ffn_reduce_sum_output], outputs=[ffn_add_output])
|
||||
operators.append(ffn_add)
|
||||
input = ffn_add_output
|
||||
|
||||
layernorm_mul_weight = gs.Constant(name="/output/layernorm/mul_weight", values=to_numpy(new_state_dict["encoder.output_layernorm.weight"]))
|
||||
layernorm_mul_output_1 = gs.Variable(name="/output/layernorm/mul_output_1")
|
||||
|
||||
graph.RMSNorm([input, layernorm_mul_weight], [layernorm_mul_output_1])
|
||||
|
||||
lm_head_weight = gs.Constant(name="/output/lm_head/weight", values=np.transpose(to_numpy(new_state_dict["lm_head.weight"])))
|
||||
lm_head_output = gs.Variable(name="/output/lm_head/output")
|
||||
lm_head = gs.Node(op="MatMul", inputs=[layernorm_mul_output_1, lm_head_weight], outputs=[lm_head_output])
|
||||
operators.append(lm_head)
|
||||
|
||||
if args.fp16:
|
||||
final_cast_output = gs.Variable(name="/output/cast/output", dtype=np.float32, shape=(1,1,args.vocab_size))
|
||||
final_cast = gs.Node(op="Cast", inputs=[lm_head_output], outputs=[final_cast_output])
|
||||
final_cast.attrs["to"] = np.float32
|
||||
operators.append(final_cast)
|
||||
graph.outputs.append(final_cast_output)
|
||||
else:
|
||||
lm_head_output.dtype=np.float32
|
||||
lm_head_output.shape=(1,1,args.vocab_size)
|
||||
graph.outputs.append(lm_head_output)
|
||||
|
||||
onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True, location=weight_path)
|
||||
return
|
||||
|
||||
|
||||
def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab: Dict[str, int] = {}
|
||||
|
||||
reader = io.TextIOWrapper(fp, encoding="utf-8")
|
||||
for token in reader.readlines():
|
||||
token = token.strip()
|
||||
if len(token) == 0:
|
||||
continue
|
||||
token = json.loads(token)
|
||||
vocab[token] = len(vocab)
|
||||
return vocab
|
||||
|
||||
|
||||
class CPM9GTokenizer(object):
|
||||
def __init__(self, path):
|
||||
self.unk_token = "<unk>"
|
||||
self.bos_token = "<s>"
|
||||
self.eos_token = "</s>"
|
||||
self.byte_list = ["<0x0{}>".format(hex(i).upper()[2:]) for i in range(0x10)] + [
|
||||
"<0x{}>".format(hex(i).upper()[2:]) for i in range(0x10, 0x100)
|
||||
]
|
||||
|
||||
self._special_token_set = set([self.unk_token, self.bos_token, self.eos_token] + self.byte_list)
|
||||
|
||||
all_tokens = load_vocab(io.FileIO(path, "rb"))
|
||||
|
||||
self.encoder: Dict[str, int] = {}
|
||||
self._special_encoder: Dict[str, int] = {}
|
||||
for token, token_id in all_tokens.items():
|
||||
if token in self._special_token_set:
|
||||
self._special_encoder[token] = token_id
|
||||
else:
|
||||
self.encoder[token] = token_id
|
||||
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self._byte_decoder = {self._special_encoder[token]: i for i, token in enumerate(self.byte_list)}
|
||||
|
||||
self._max_word_len = max([len(x) for x in self.encoder.keys()])
|
||||
|
||||
self._len_word_first = {}
|
||||
for x in self.encoder.keys():
|
||||
if not x[0] in self._len_word_first:
|
||||
self._len_word_first[x[0]] = 1
|
||||
if len(x) > self._len_word_first[x[0]]:
|
||||
self._len_word_first[x[0]] = len(x)
|
||||
self.tencoder = StringTrie(self.encoder)
|
||||
|
||||
def get_piece(self, text: str) -> str:
|
||||
if text[0] in self._len_word_first:
|
||||
text = text[: self._len_word_first[text[0]]]
|
||||
len_text = len(text)
|
||||
for i in range(len(text)):
|
||||
sub = text[: len_text - i]
|
||||
if sub in self.encoder:
|
||||
return sub
|
||||
return text[0]
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def eos_id(self):
|
||||
return self._special_encoder[self.eos_token]
|
||||
|
||||
@property
|
||||
def bos_id(self):
|
||||
return self._special_encoder[self.bos_token]
|
||||
|
||||
@property
|
||||
def unk_id(self):
|
||||
return self._special_encoder[self.unk_token]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self._special_encoder)
|
||||
|
||||
def tokenize(self, text: str) -> List[str]:
|
||||
output_tokens: List[str] = []
|
||||
st = 0
|
||||
while st < len(text):
|
||||
piece = self.get_piece(text[st:])
|
||||
output_tokens.append(piece)
|
||||
st += len(piece)
|
||||
return output_tokens
|
||||
|
||||
@staticmethod
|
||||
def escape(text: str) -> str:
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def unescape(text: str) -> str:
|
||||
return text
|
||||
|
||||
def encode(self, text: str, with_bos = True) -> List[int]:
|
||||
ret = []
|
||||
if with_bos:
|
||||
ret.append(self.bos_id)
|
||||
for x in self.tokenize(text):
|
||||
if x in self.encoder:
|
||||
ret.append(self.encoder[x])
|
||||
else:
|
||||
ret.extend(self._encode_unicode(x))
|
||||
return ret
|
||||
|
||||
def decode(self, tokens: List[int]):
|
||||
"""Decode ids into a string."""
|
||||
ret = []
|
||||
st = 0
|
||||
while st < len(tokens):
|
||||
if tokens[st] in self.decoder:
|
||||
ret.append(self.decoder[tokens[st]])
|
||||
st += 1
|
||||
elif tokens[st] in self._byte_decoder:
|
||||
first = self._byte_decoder[tokens[st]]
|
||||
length = 1 if first < 128 else len(re.search('^1+0', bin(first)[2:])[0])-1
|
||||
code = 0
|
||||
try:
|
||||
for j in range(length):
|
||||
code = code << 8 | self._byte_decoder[tokens[st + j]]
|
||||
code = int.to_bytes(code, length, "big").decode("utf-8")
|
||||
ret.append(code)
|
||||
except:
|
||||
pass
|
||||
st = st + length
|
||||
elif tokens[st] == self.eos_id:
|
||||
ret.append(self.eos_token)
|
||||
st += 1
|
||||
elif tokens[st] == self.bos_id:
|
||||
ret.append(self.bos_token)
|
||||
st += 1
|
||||
else:
|
||||
ret.append(self.unk_token)
|
||||
st += 1
|
||||
return "".join(ret)
|
||||
|
||||
def _encode_unicode(self, token):
|
||||
# wrap unicode encoding into a helper function
|
||||
ids = []
|
||||
utf8_id = token.encode("utf-8")
|
||||
for _id in utf8_id:
|
||||
ids.append(self._special_encoder[self.byte_list[_id]])
|
||||
return ids
|
||||
|
||||
def next_token(self, text):
|
||||
# fast next token matching
|
||||
token, token_id = self.tencoder.longest_prefix_item(text, (None, None))
|
||||
if token is None:
|
||||
token = text[0]
|
||||
token_ids = self._encode_unicode(token)
|
||||
else:
|
||||
token_ids = [token_id]
|
||||
return token, token_ids
|
||||
|
||||
|
||||
def start_worker(
|
||||
world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, query
|
||||
):
|
||||
model = onnx.load(ONNX_MODEL_PATH)
|
||||
runtime = backend.CudaRuntime(local_rank)
|
||||
if args.nproc_per_node > 1:
|
||||
runtime.init_comm(
|
||||
"9g",
|
||||
world_size,
|
||||
rank,
|
||||
)
|
||||
print("[{}] comm init.".format(rank))
|
||||
|
||||
stub = OnnxStub(model, runtime)
|
||||
print("[{}] stub init.".format(rank))
|
||||
|
||||
for i in range(10):
|
||||
if args.no_cudagraph:
|
||||
stub.run()
|
||||
else:
|
||||
stub.run_with_cudagraph()
|
||||
print("[{}] stub warmup.".format(rank))
|
||||
|
||||
tokenizer = CPM9GTokenizer("/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/vocabs.txt")
|
||||
query = tokenizer.encode(query)
|
||||
|
||||
output_tokens = []
|
||||
for i in range(len(query)):
|
||||
q = np.array(query[i])
|
||||
(list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist())
|
||||
pos = i * np.ones((args.batchsize, 1), dtype=np.int64)
|
||||
(list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist())
|
||||
|
||||
if args.no_cudagraph:
|
||||
stub.run()
|
||||
else:
|
||||
stub.run_with_cudagraph()
|
||||
|
||||
if i == len(query) - 1:
|
||||
output = np.array((list(stub.outputs.items()))[-1][1].copyout_float16()) if False \
|
||||
else np.array((list(stub.outputs.items()))[-1][1].copyout_float())
|
||||
q = np.argmax(output)
|
||||
output_tokens.append(q)
|
||||
|
||||
avg_time = 0
|
||||
count = 0
|
||||
while i < 1000:
|
||||
count = count + 1
|
||||
torch.cuda.synchronize()
|
||||
with nvtx.annotate("gen {}-th token".format(i), color="red"):
|
||||
i = i + 1
|
||||
(list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist())
|
||||
pos = i * np.ones((args.batchsize, 1), dtype=np.int64)
|
||||
(list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist())
|
||||
|
||||
t0 = time.time()
|
||||
if args.no_cudagraph:
|
||||
stub.run()
|
||||
else:
|
||||
stub.run_with_cudagraph()
|
||||
t1 = time.time()
|
||||
avg_time += t1 - t0
|
||||
|
||||
output = np.array((list(stub.outputs.items()))[-1][1].copyout_float16()) if False \
|
||||
else np.array((list(stub.outputs.items()))[-1][1].copyout_float())
|
||||
|
||||
# print(output)
|
||||
|
||||
with nvtx.annotate("argmax".format(i), color="green"):
|
||||
q = np.argmax(output)
|
||||
if q == 2:
|
||||
break
|
||||
|
||||
output_tokens.append(q)
|
||||
avg_time = avg_time / count
|
||||
print("avg_time_cost =", avg_time*1000, "ms")
|
||||
text = tokenizer.decode(output_tokens)
|
||||
return text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
comm = MPI.COMM_WORLD
|
||||
args.rank = comm.Get_rank()
|
||||
args.nproc_per_node = comm.Get_size()
|
||||
world_size = args.num_nodes * args.nproc_per_node
|
||||
|
||||
if not os.path.exists(ONNX_MODEL_PATH):
|
||||
print("exporting onnx graph")
|
||||
generate_onnx(ONNX_MODEL_PATH)
|
||||
else:
|
||||
print("will use exsiting onnx graph")
|
||||
onnx_model = onnx.load(ONNX_MODEL_PATH)
|
||||
print("data loaded")
|
||||
|
||||
|
||||
#query = '''Beijing is the captial'''
|
||||
#query = '''什么是PTX?'''
|
||||
#query = '''生病了怎么办?'''
|
||||
#query = '''Happy'''
|
||||
query = '''def gcd(a, b):'''
|
||||
|
||||
####################################
|
||||
# infinitensor dist
|
||||
####################################
|
||||
# run distributed parallel.
|
||||
pred = start_worker(world_size, args.rank, args.rank %
|
||||
args.nproc_per_node, onnx_model, query)
|
||||
if args.rank == 0:
|
||||
print("输入:\n\n", query, "\n")
|
||||
print("输出:", pred)
|
|
@ -0,0 +1,491 @@
|
|||
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
|
||||
from tqdm import tqdm
|
||||
import argparse
|
||||
import torch
|
||||
import onnx
|
||||
import onnx_graphsurgeon as gs
|
||||
import os
|
||||
import numpy as np
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import time
|
||||
import nvtx
|
||||
from mpi4py import MPI
|
||||
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
|
||||
parser.add_argument('--layer', dest='n_layers', type=int, default=32)
|
||||
parser.add_argument("--num_nodes", dest='num_nodes',
|
||||
type=int, default=1, help="number of nodes")
|
||||
parser.add_argument("--nproc_per_node", dest="nproc_per_node",
|
||||
type=int, default=1, help="number of processes per node")
|
||||
parser.add_argument("--world_size", dest="world_size",
|
||||
type=int, default=1, help="")
|
||||
parser.add_argument("--n_max_length", dest="n_max_length",
|
||||
type=int, default=1024, help="")
|
||||
parser.add_argument("--vocab_size", dest="vocab_size",
|
||||
type=int, default=32000, help="vocabulary size")
|
||||
parser.add_argument("--hidden_size", dest="hidden_size",
|
||||
type=int, default=4096)
|
||||
parser.add_argument("--head_size", dest="head_size",
|
||||
type=int, default=32)
|
||||
parser.add_argument("--head_dim", dest="head_dim",
|
||||
type=int, default=128)
|
||||
parser.add_argument('--rank', dest='rank', type=int, default=0)
|
||||
parser.add_argument('--no_cudagraph', action='store_true')
|
||||
parser.add_argument('--fp16', action='store_true')
|
||||
parser.add_argument('--is_1st_graph', action='store_true')
|
||||
parser.add_argument('--speedup', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
args.rank = comm.Get_rank()
|
||||
args.nproc_per_node = comm.Get_size()
|
||||
args.world_size = args.num_nodes * args.nproc_per_node
|
||||
|
||||
PRETRAINED_LLAMA_PATH = "/data0/shared/data/public/opensource_models/meta-llama/Llama-2-7b-hf/"
|
||||
ONNX_MODEL_PATH = "/data3/shared/xnsong/llama2/" + ("1st" if args.is_1st_graph else "2nd")
|
||||
ONNX_MODEL_ORIGIN_PATH = ONNX_MODEL_PATH + "/origin/llama2_origin_bs{}_layer{}.onnx".format(
|
||||
args.batchsize, args.n_layers)
|
||||
ONNX_MODEL_SIM_PATH = ONNX_MODEL_PATH + "/sim/llama2_sim_bs{}_layer{}.onnx".format(
|
||||
args.batchsize, args.n_layers)
|
||||
ONNX_MODEL_FUSION_PATH = ONNX_MODEL_PATH + "/fusion/llama2_fusion_bs{}_layer{}.onnx".format(
|
||||
args.batchsize, args.n_layers)
|
||||
ONNX_MODEL_SPECIAL_PATH = ONNX_MODEL_PATH + "/special/llama2_special_bs{}_layer{}.onnx".format(
|
||||
args.batchsize, args.n_layers)
|
||||
ONNX_MODEL_FP16_PATH = ONNX_MODEL_PATH + "/fp16/llama2_fp16_bs{}_layer{}.onnx".format(
|
||||
args.batchsize, args.n_layers)
|
||||
ONNX_MODEL_DIST_PATH = ONNX_MODEL_PATH + "/dist/llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format(
|
||||
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
|
||||
|
||||
def parallel_model(onnx_model, world_size, rank):
|
||||
graph = gs.import_onnx(onnx_model)
|
||||
tmap = graph.tensors()
|
||||
|
||||
for i in range(args.n_layers):
|
||||
tmap[graph.inputs[2+i*2].name].shape[1] = tmap[graph.inputs[2+i*2].name].shape[1]//world_size
|
||||
tmap[graph.inputs[3+i*2].name].shape[1] = tmap[graph.inputs[3+i*2].name].shape[1]//world_size
|
||||
for node in graph.nodes:
|
||||
if node.name == "/model/layers." + str(i) + "/self_attn/q_proj/MatMul":
|
||||
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/k_proj/MatMul":
|
||||
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/v_proj/MatMul":
|
||||
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/o_proj/MatMul":
|
||||
node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank]
|
||||
reduce_sum_output = gs.Variable("reduce_sum_output_" + str(i) + "_0",
|
||||
dtype=np.float32)
|
||||
reduce_sum = gs.Node(op="ReduceSum", name="reduce_sum_"+str(i)+"_0",
|
||||
inputs=node.outputs, outputs=[reduce_sum_output],
|
||||
attrs={"noop_with_empty_axes":1, "communicator":0})
|
||||
graph.nodes.append(reduce_sum)
|
||||
next_node = node.outputs[0].outputs[0]
|
||||
next_node.inputs[1] = reduce_sum_output
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_0" or \
|
||||
node.name == "/model/layers." + str(i) + "/self_attn/Reshape_1":
|
||||
node.inputs[1].values = np.array(
|
||||
[1, 1,
|
||||
args.head_size//world_size,
|
||||
args.hidden_size//args.head_size])
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_2":
|
||||
node.inputs[1] = gs.Constant(name="/model/layers."+str(i)+"/self_attn/vreshape_input",
|
||||
values=np.array(
|
||||
[1, 1,
|
||||
args.head_size//world_size,
|
||||
args.hidden_size//args.head_size]))
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_3":
|
||||
node.inputs[1] = gs.Constant(name="/model/layers." + str(i) + "/self_attn/Reshape_3_shape",
|
||||
values=np.array(
|
||||
[1, 1, args.hidden_size//world_size]))
|
||||
|
||||
elif node.name == "/model/layers." + str(i) + "/mlp/up_proj/MatMul":
|
||||
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
|
||||
elif node.name == "/model/layers." + str(i) + "/mlp/gate_proj/MatMul":
|
||||
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
|
||||
elif node.name == "/model/layers." + str(i) + "/mlp/down_proj/MatMul":
|
||||
node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank]
|
||||
reduce_sum_output_1 = gs.Variable("reduce_sum_output_" + str(i) + "_1",
|
||||
dtype=np.float32)
|
||||
reduce_sum_1 = gs.Node(op="ReduceSum", inputs=node.outputs, outputs=[reduce_sum_output_1],
|
||||
attrs={"noop_with_empty_axes":1, "communicator":0})
|
||||
graph.nodes.append(reduce_sum_1)
|
||||
next_node = node.outputs[0].outputs[0]
|
||||
next_node.inputs[1] = reduce_sum_output_1
|
||||
|
||||
# new_out_1 = tmap["/model/layers.0/mlp/down_proj/MatMul_output_0"] #reduce_sum_output
|
||||
# new_out_1.dtype = np.float32
|
||||
# new_out_1.shape = [1,1,4096]
|
||||
# graph.outputs.append(new_out_1)
|
||||
graph.cleanup(True).toposort()
|
||||
return gs.export_onnx(graph)
|
||||
|
||||
def simplify(onnx_model):
|
||||
graph = gs.import_onnx(onnx_model)
|
||||
for node in graph.nodes:
|
||||
if node.op == "Cast":
|
||||
inp_node = node.i()
|
||||
inp_node.outputs = node.outputs
|
||||
node.outputs.clear()
|
||||
|
||||
for i in range(args.n_layers):
|
||||
nodename = "/model/layers." + str(i) + "/self_attn/Add_2"
|
||||
node = [node for node in graph.nodes if node.name == nodename][0]
|
||||
inp_node = node.i()
|
||||
inp_node.outputs = node.outputs
|
||||
node.outputs.clear()
|
||||
|
||||
graph.cleanup().toposort()
|
||||
return gs.export_onnx(graph)
|
||||
|
||||
@gs.Graph.register()
|
||||
def replace_with_RMSNorm(self, inputs, outputs):
|
||||
inputs[0].outputs.pop(0)
|
||||
inputs[0].outputs.pop(0)
|
||||
|
||||
for out in outputs:
|
||||
out.inputs.clear()
|
||||
return self.layer(op="RMSNorm", inputs=inputs, outputs=outputs, name="rmsnorm")
|
||||
|
||||
@gs.Graph.register()
|
||||
def replace_with_silu(self, inputs, outputs):
|
||||
for inp in inputs:
|
||||
inp.outputs.clear()
|
||||
for out in outputs:
|
||||
out.inputs.clear()
|
||||
return self.layer(op="Silu", inputs=inputs, outputs=outputs, name="silu")
|
||||
|
||||
@gs.Graph.register()
|
||||
def replace_with_RoPE(self, a, b):
|
||||
return self.layer(op="RoPE", inputs=a, outputs=b, name="rope")
|
||||
|
||||
@gs.Graph.register()
|
||||
def replace_with_attention(self, inputs, outputs, inputs_added, outputs_removed):
|
||||
for inp in inputs:
|
||||
inp.outputs.clear()
|
||||
for out in outputs:
|
||||
out.inputs.clear()
|
||||
for inp in inputs_added:
|
||||
inputs.append(inp)
|
||||
for out in outputs_removed:
|
||||
out.inputs.clear()
|
||||
return self.layer(op="AttentionKVCache", inputs=inputs, outputs=outputs, name="attention")
|
||||
|
||||
def fusion(model):
|
||||
graph = gs.import_onnx(model)
|
||||
tmap = graph.tensors()
|
||||
|
||||
tmap["onnx::Reshape_1"].outputs.clear()
|
||||
|
||||
inputs = [tmap["/model/layers.0/input_layernorm/Cast_output_0"], tmap["model.layers.0.input_layernorm.weight"]]
|
||||
rmsnorm_outputs = [tmap["/model/layers.0/input_layernorm/Mul_1_output_0"]]
|
||||
graph.replace_with_RMSNorm(inputs, rmsnorm_outputs)
|
||||
|
||||
for i in range(args.n_layers):
|
||||
# rotary embedding op
|
||||
tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"].inputs.clear()
|
||||
tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"].inputs.clear()
|
||||
attn_qreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/qreshape_input",
|
||||
values=np.array([1,1,args.head_size,args.hidden_size//args.head_size]))
|
||||
attn_kreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/kreshape_input",
|
||||
values=np.array([1,1,args.head_size,args.hidden_size//args.head_size]))
|
||||
attn_qrope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qrope_output")
|
||||
attn_krope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/krope_output")
|
||||
attn_qreshape_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qreshape_output")
|
||||
attn_kreshape_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/kreshape_output")
|
||||
|
||||
attn_qreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_0", inputs=[attn_qrope_output, attn_qreshape_input], outputs=[attn_qreshape_output])
|
||||
attn_kreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_1", inputs=[attn_krope_output, attn_kreshape_input], outputs=[attn_kreshape_output])
|
||||
attn_qtrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_qreshape_output],
|
||||
outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"]])
|
||||
attn_ktrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_kreshape_output],
|
||||
outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"]])
|
||||
|
||||
graph.nodes.append(attn_qreshape)
|
||||
graph.nodes.append(attn_kreshape)
|
||||
graph.nodes.append(attn_qtrans)
|
||||
graph.nodes.append(attn_ktrans)
|
||||
inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/q_proj/MatMul_output_0"]]
|
||||
graph.replace_with_RoPE(inputs, [attn_qrope_output])
|
||||
inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/k_proj/MatMul_output_0"]]
|
||||
graph.replace_with_RoPE(inputs, [attn_krope_output])
|
||||
|
||||
# rms-norm op
|
||||
inputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Cast_output_0"], \
|
||||
tmap["model.layers." + str(i) + ".post_attention_layernorm.weight"]]
|
||||
outputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Mul_1_output_0"]]
|
||||
graph.replace_with_RMSNorm(inputs, outputs)
|
||||
inputs = [tmap["/model/layers." + str(i+1) + "/input_layernorm/Cast_output_0"] if i != args.n_layers-1 else \
|
||||
tmap["/model/norm/Cast_output_0"], \
|
||||
tmap["model.layers." + str(i+1) + ".input_layernorm.weight"] if i != args.n_layers-1 else \
|
||||
tmap["model.norm.weight"]]
|
||||
outputs = [tmap["/model/layers."+ str(i+1) + "/input_layernorm/Mul_1_output_0"]] if i != args.n_layers-1 else \
|
||||
[tmap["/model/norm/Mul_1_output_0"]]
|
||||
graph.replace_with_RMSNorm(inputs, outputs)
|
||||
|
||||
# silu op
|
||||
inputs = [tmap["/model/layers." + str(i) + "/mlp/gate_proj/MatMul_output_0"]]
|
||||
outputs = [tmap["/model/layers." + str(i) + "/mlp/act_fn/Mul_output_0"]]
|
||||
graph.replace_with_silu(inputs, outputs)
|
||||
|
||||
inputs = [
|
||||
tmap["onnx::Concat_" + str((i+1)*2)],
|
||||
tmap["onnx::Concat_" + str((i+1)*2+1)],
|
||||
tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"],
|
||||
tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
|
||||
tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
|
||||
outputs = [
|
||||
tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"],]
|
||||
|
||||
inputs_added = [graph.inputs[1]]
|
||||
outputs_removed = []
|
||||
graph.replace_with_attention(
|
||||
inputs, outputs, inputs_added, outputs_removed)
|
||||
|
||||
graph.outputs = [tmap[graph.outputs[0].name]]
|
||||
graph.cleanup(True).toposort()
|
||||
|
||||
return gs.export_onnx(graph)
|
||||
|
||||
def special_pass(model):
|
||||
graph = gs.import_onnx(model)
|
||||
tmap = graph.tensors()
|
||||
for node in graph.nodes:
|
||||
if node.op == "Transpose" or node.op == "Reshape":
|
||||
inp_node = node.i()
|
||||
inp_node.outputs = node.outputs
|
||||
node.outputs.clear()
|
||||
graph.cleanup(True).toposort()
|
||||
return gs.export_onnx(graph)
|
||||
|
||||
def convert_to_fp16(model):
|
||||
graph = gs.import_onnx(model)
|
||||
|
||||
for node in graph.nodes:
|
||||
if node.op == "Gather" and node.name == "/model/embed_tokens/Gather":
|
||||
node.inputs[0].values = np.float16(node.inputs[0].values)
|
||||
|
||||
if node.op == "RMSNorm":
|
||||
node.inputs[1].values = np.float16(node.inputs[1].values)
|
||||
|
||||
if node.op == "MatMul":
|
||||
node.inputs[1].values = np.float16(node.inputs[1].values)
|
||||
if node.name == "/lm_head/MatMul":
|
||||
cast_1_out = gs.Variable(node.name+"_cast_out_output_0", dtype=np.float32, shape=node.outputs[0].shape)
|
||||
cast_1 = gs.Node(op="Cast", inputs=[node.outputs[0]], outputs=[cast_1_out])
|
||||
cast_1.attrs["to"] = np.float32
|
||||
cast_1.name = node.name+"_cast_out_0"
|
||||
graph.nodes.append(cast_1)
|
||||
graph.outputs[0] = cast_1_out
|
||||
node.outputs[0].dtype = np.float16
|
||||
|
||||
graph.cleanup(True).toposort()
|
||||
return gs.export_onnx(graph)
|
||||
|
||||
def export_onnx(model: AutoModelForCausalLM):
|
||||
if not os.path.exists(ONNX_MODEL_ORIGIN_PATH):
|
||||
print("exporting origin onnx model...")
|
||||
with torch.no_grad():
|
||||
param = torch.zeros(
|
||||
(args.batchsize, model.config.max_position_embeddings-1), dtype=torch.long)
|
||||
logits = model(param, past_key_values=None)
|
||||
|
||||
if not args.is_1st_graph:
|
||||
param_kvcache = torch.zeros((args.batchsize, 1), dtype=torch.long)
|
||||
torch.onnx.export(model, (param_kvcache, {"past_key_values": logits.past_key_values,
|
||||
"position_ids": param_kvcache}), \
|
||||
ONNX_MODEL_ORIGIN_PATH, verbose=False,
|
||||
do_constant_folding=True,)
|
||||
else:
|
||||
position_ids = torch.tile(torch.arange(0, model.config.max_position_embeddings-1), (args.batchsize, 1))
|
||||
attention_mask = torch.ones((args.batchsize, model.config.max_position_embeddings-1), dtype=torch.bool)
|
||||
torch.onnx.export(model, (param, {"attention_mask": attention_mask,
|
||||
"position_ids": position_ids}),\
|
||||
ONNX_MODEL_ORIGIN_PATH, verbose=False,
|
||||
do_constant_folding=True,)
|
||||
print("export origin onnx finished.")
|
||||
|
||||
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SIM_PATH):
|
||||
print("exporting sim onnx model...")
|
||||
onnx_model = onnx.load(ONNX_MODEL_ORIGIN_PATH)
|
||||
onnx_model = simplify(onnx_model)
|
||||
onnx.save(onnx_model, ONNX_MODEL_SIM_PATH, save_as_external_data=True, \
|
||||
location="llama2_sim_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
|
||||
print("exporting sim onnx model finished.")
|
||||
|
||||
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_FUSION_PATH):
|
||||
print("exporting fusion onnx model...")
|
||||
onnx_model = onnx.load(ONNX_MODEL_SIM_PATH)
|
||||
onnx_model = fusion(onnx_model)
|
||||
onnx.save(onnx_model, ONNX_MODEL_FUSION_PATH, save_as_external_data=True, \
|
||||
location="llama2_fusion_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
|
||||
print("exporting fusion onnx model finished.")
|
||||
|
||||
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SPECIAL_PATH):
|
||||
print("exporting special onnx model...")
|
||||
onnx_model = onnx.load(ONNX_MODEL_FUSION_PATH)
|
||||
onnx_model = special_pass(onnx_model)
|
||||
onnx.save(onnx_model, ONNX_MODEL_SPECIAL_PATH, save_as_external_data=True, \
|
||||
location="llama2_special_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
|
||||
print("exporting special onnx model finished.")
|
||||
|
||||
if not args.is_1st_graph and args.fp16 and not os.path.exists(ONNX_MODEL_FP16_PATH):
|
||||
print("exporting fp16 onnx model...")
|
||||
onnx_model = onnx.load(ONNX_MODEL_SPECIAL_PATH)
|
||||
onnx_model = convert_to_fp16(onnx_model)
|
||||
onnx.save(onnx_model, ONNX_MODEL_FP16_PATH, save_as_external_data=True, \
|
||||
location="llama2_fp16_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
|
||||
print("exporting fp16 onnx model finished.")
|
||||
|
||||
print("world_size =", args.world_size)
|
||||
if not args.is_1st_graph and args.world_size > 1 and not os.path.exists(ONNX_MODEL_DIST_PATH):
|
||||
print("exporting dist onnx model...")
|
||||
onnx_model = onnx.load(ONNX_MODEL_FP16_PATH) if args.fp16 else onnx.load(ONNX_MODEL_SPECIAL_PATH)
|
||||
onnx_model = parallel_model(onnx_model, args.world_size, args.rank)
|
||||
onnx.save(onnx_model, ONNX_MODEL_DIST_PATH, save_as_external_data=True, \
|
||||
location="llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format(
|
||||
args.batchsize, args.n_layers,
|
||||
16 if args.fp16 else 32, args.world_size, args.rank))
|
||||
print("exporting dist onnx model finished.")
|
||||
|
||||
def get_it_logit(onnx_model, input_ids):
|
||||
# initialization
|
||||
runtime = backend.CudaRuntime(args.rank)
|
||||
runtime.init_comm(
|
||||
"dist",
|
||||
args.world_size,
|
||||
args.rank,
|
||||
)
|
||||
print("[{}] comm init.".format(args.rank))
|
||||
stub = OnnxStub(onnx_model, runtime)
|
||||
print("[{}] stub init.".format(args.rank))
|
||||
|
||||
# warm up
|
||||
for i in range(10):
|
||||
if args.no_cudagraph:
|
||||
stub.run()
|
||||
else:
|
||||
stub.run_with_cudagraph()
|
||||
print("[{}] stub warmup.".format(args.rank))
|
||||
|
||||
logits = np.zeros((args.batchsize, args.n_max_length, args.vocab_size), dtype=np.float32)
|
||||
output_ids = np.zeros((args.batchsize, args.n_max_length), dtype=np.int64)
|
||||
avg_inference_time = 0
|
||||
t0 = time.time()
|
||||
for i in tqdm(range(0, args.n_max_length)):
|
||||
with nvtx.annotate("seq_length = {}".format(i), color="red"):
|
||||
assert input_ids.shape[0] == args.batchsize
|
||||
input_id = input_ids[:, i] if i < input_ids.shape[1] else output_ids[:, i-1]
|
||||
position_id = i*np.ones((args.batchsize, 1), dtype=np.int32)
|
||||
|
||||
# copyin input
|
||||
with nvtx.annotate("[it] copyin", color="blue"):
|
||||
(list(stub.inputs.items()))[0][1].copyin_int64(
|
||||
input_id.reshape(-1).tolist())
|
||||
(list(stub.inputs.items()))[1][1].copyin_int64(
|
||||
position_id.reshape(-1).tolist())
|
||||
|
||||
# run
|
||||
t10 = time.time()
|
||||
with nvtx.annotate("[it] run", color="green"):
|
||||
if args.no_cudagraph:
|
||||
stub.run()
|
||||
else:
|
||||
stub.run_with_cudagraph()
|
||||
t11 = time.time()
|
||||
avg_inference_time += (t11 - t10)
|
||||
|
||||
# copyout output
|
||||
if not args.speedup:
|
||||
with nvtx.annotate("[it] copyout", color="blue"):
|
||||
logits[:,i, :] = np.array((list(stub.outputs.items()))[0][1].copyout_float()).reshape(args.batchsize, -1)
|
||||
output_ids[:, i] = np.argmax(logits[:, i, :], -1).astype(np.int64)
|
||||
|
||||
|
||||
t1 = time.time()
|
||||
if args.rank == 0:
|
||||
result = "[it] e2e: {} gpus, {} layers, e2e time: {:.2f}s, average inference time: {:.2f}ms"\
|
||||
.format(args.num_nodes * args.nproc_per_node, args.n_layers, t1-t0, \
|
||||
avg_inference_time*1000/args.n_max_length)
|
||||
print(result)
|
||||
del stub
|
||||
return output_ids
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch_model = LlamaForCausalLM.from_pretrained(
|
||||
PRETRAINED_LLAMA_PATH, num_hidden_layers=int(args.n_layers)).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LLAMA_PATH)
|
||||
#prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
#prompt = "What is PTX?"
|
||||
#prompt = "Tell me a joke."
|
||||
#prompt = "What are the key principles of smart investing?"
|
||||
prompt = "What is DeepSpeed?"
|
||||
prompts=[prompt]*args.batchsize
|
||||
inputs = tokenizer(prompts, return_tensors="pt")
|
||||
|
||||
input_ids = inputs.input_ids
|
||||
print("prompt ids =", input_ids)
|
||||
|
||||
##########################################################
|
||||
# inference with InfiniTensor
|
||||
##########################################################
|
||||
print("exporting onnx...")
|
||||
export_onnx(torch_model)
|
||||
print("exporting onnx finished.")
|
||||
|
||||
onnx_to_run_path = ONNX_MODEL_DIST_PATH if args.world_size > 1 else \
|
||||
(ONNX_MODEL_FP16_PATH if args.fp16 else ONNX_MODEL_SPECIAL_PATH)
|
||||
print("loading onnx", onnx_to_run_path, "...")
|
||||
onnx_model = onnx.load(onnx_to_run_path)
|
||||
print("loading onnx finished.")
|
||||
output_ids_it = get_it_logit(onnx_model, input_ids)
|
||||
it_output_text = tokenizer.batch_decode(output_ids_it[:, input_ids.shape[-1]:output_ids_it.shape[-1]])
|
||||
if args.rank == 0:
|
||||
for i in range(args.batchsize):
|
||||
print("prompt: ", prompts[i])
|
||||
print("answer: [it]", it_output_text[i])
|
||||
|
||||
##########################################################
|
||||
# validation with pytorch
|
||||
##########################################################
|
||||
"""
|
||||
generate_ids = torch_model.generate(inputs.input_ids, max_length=args.n_max_length)#, num_beams=4, do_sample=True)
|
||||
outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"""
|
||||
if not args.speedup and not args.is_1st_graph:
|
||||
kvcache_torch = None
|
||||
output_ids_pt = torch.zeros(args.batchsize, args.n_max_length).int() # + input_ids.shape[-1] - 1).int()
|
||||
if args.fp16:
|
||||
torch_model = torch_model.half()
|
||||
|
||||
torch_model = torch_model.cuda()
|
||||
# print(torch.cuda.memory_summary())
|
||||
|
||||
avg_inference_time = 0
|
||||
with torch.no_grad():
|
||||
t0 = time.time()
|
||||
for i in range(args.n_max_length):
|
||||
input_id = input_ids[:,i] if i < input_ids.shape[1] else out_token
|
||||
input_id = input_id.view(args.batchsize,1).cuda()
|
||||
t00 = time.time()
|
||||
outputs = torch_model(input_id, past_key_values=kvcache_torch)
|
||||
t01 = time.time()
|
||||
avg_inference_time += (t01-t00)
|
||||
|
||||
logits = outputs['logits']
|
||||
kvcache_torch = outputs['past_key_values']
|
||||
out_token = torch.argmax(logits, dim=-1)
|
||||
output_ids_pt[:, i:i+1] = out_token
|
||||
t1 = time.time()
|
||||
avg_inference_time /= args.n_max_length
|
||||
result = "[pt] e2e time: {:.2f}s, average inference time: {:.2f}ms"\
|
||||
.format(t1-t0, avg_inference_time*1000)
|
||||
|
||||
if args.rank == 0:
|
||||
print(result)
|
||||
pt_output_text = tokenizer.batch_decode(output_ids_pt[:,input_ids.shape[-1]:args.n_max_length])
|
||||
for i in range(args.batchsize):
|
||||
print("[pt]", args.rank, pt_output_text[i])
|
||||
|
||||
if not args.is_1st_graph:
|
||||
assert(output_ids_it.shape[-1] == args.n_max_length)
|
||||
np.testing.assert_equal(output_ids_pt[:, input_ids.shape[-1]:args.n_max_length], output_ids_it[:,input_ids.shape[-1]:args.n_max_length])
|
|
@ -3,14 +3,17 @@
|
|||
#include <cstdio>
|
||||
|
||||
struct AttentionKVCacheMetadata {
|
||||
int dimSize[4];
|
||||
int stride[4];
|
||||
int head_dim;
|
||||
int num_heads;
|
||||
int num_seqs;
|
||||
int max_kv_seqlen;
|
||||
};
|
||||
|
||||
namespace infini {
|
||||
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
|
||||
float *input_q, float *input_k, float *input_v,
|
||||
int *position_id, float *output_matmul,
|
||||
void attention_kvcache_kernel(int dType, void *input_k_cache,
|
||||
void *input_v_cache, void *input_q, void *input_k,
|
||||
void *input_v, int64_t *position_id,
|
||||
void *output_matmul,
|
||||
const AttentionKVCacheMetadata &compMeta,
|
||||
float *output_O_temp, float *output_sum_temp);
|
||||
|
||||
|
|
|
@ -5,8 +5,7 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
void rope_kernel(int dType, int *pos, void *input, void *output, int size,
|
||||
int dim_model, int dim_head, int hidden_stride,
|
||||
int pos_stride);
|
||||
void rope_kernel(int dType, int64_t *pos, void *input, void *output,
|
||||
int dim_model, int dim_head, int batchsize, int pos_stride);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -3,8 +3,7 @@
|
|||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Fused Attention with KVCache input operator. All the input and output
|
||||
* tensors should have the same rank except for the position_id.
|
||||
* @brief Fused Attention with KVCache input operator.
|
||||
*
|
||||
*/
|
||||
class AttentionKVCacheObj : public OperatorObj {
|
||||
|
@ -16,12 +15,19 @@ class AttentionKVCacheObj : public OperatorObj {
|
|||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input_k_cache The k_cache input tensor.
|
||||
* Shape: [batchsize, num_heads, k_cache_seq_length, head_dim]
|
||||
* @param input_v_cache The v_cache input tensor.
|
||||
* Shape: [batchsize, num_heads, v_cache_seq_length, head_dim]
|
||||
* @param input_q The query input tensor.
|
||||
* Shape: [batchsize, q_seq_length, model_dim]
|
||||
* @param input_k The key input tensor.
|
||||
* Shape: [batchsize, q_seq_length, model_dim]
|
||||
* @param input_v The value input tensor.
|
||||
* @param position_id The positon id of the query,
|
||||
* Shape: [batchsize, q_seq_length, model_dim]
|
||||
* @param position_id The positon id of the query.
|
||||
* Shape: [batchsize, q_seq_length]
|
||||
* @param output_matmul The query output tensor.
|
||||
* Shape: [batchsize, q_seq_length, model_dim]
|
||||
*/
|
||||
AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
|
||||
Tensor input_v_cache, Tensor input_q, Tensor input_k,
|
||||
|
@ -30,6 +36,10 @@ class AttentionKVCacheObj : public OperatorObj {
|
|||
OP_CLONE(AttentionKVCacheObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override {
|
||||
return {inputs[2]->getDType()};
|
||||
};
|
||||
DataType getDType() const { return getInputs(2)->getDType(); }
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 6; }
|
||||
|
|
|
@ -21,6 +21,10 @@ class RoPEObj : public OperatorObj {
|
|||
int numOutputs() const override { return 1; }
|
||||
DataType getDType() const { return getInputs(1)->getDType(); }
|
||||
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override {
|
||||
return {inputs[1]->getDType()};
|
||||
};
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
|
|
|
@ -208,16 +208,39 @@ class OnnxStub:
|
|||
op[1],
|
||||
)
|
||||
elif node.op_type == "MatMul":
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.input[0]], # input
|
||||
tensors[node.input[1]], # weight
|
||||
tensors.get(node.output[0]),
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
backend.ActType.Linear,
|
||||
matmul_compute_type,
|
||||
)
|
||||
if node.input[1] in data.keys() \
|
||||
and to_array(data[node.input[1]]).dtype == np.float32 \
|
||||
and 'cuda_runtime' in dir(backend) \
|
||||
and tensors[node.input[0]].shape()[0] == 1 \
|
||||
and tensors[node.input[0]].shape()[1] == 1 \
|
||||
and len(tensors[node.input[1]].shape()) == 2 \
|
||||
and node.input[1] in data.keys():
|
||||
data[node.input[1]] = from_array(
|
||||
np.transpose(to_array(data[node.input[1]])))
|
||||
tensors[node.input[1]] = self.handler.tensor(
|
||||
[tensors[node.input[1]].shape()[1], tensors[node.input[1]].shape()[0]],
|
||||
tensors[node.input[1]].dtype())
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
False,
|
||||
True,
|
||||
None,
|
||||
backend.ActType.Linear,
|
||||
matmul_compute_type,
|
||||
)
|
||||
else:
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
backend.ActType.Linear,
|
||||
matmul_compute_type,
|
||||
)
|
||||
elif node.op_type == "Gemm":
|
||||
attributes = _parse_attribute(
|
||||
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
|
||||
|
|
|
@ -7,33 +7,37 @@ namespace infini {
|
|||
|
||||
class AttentionKVCacheCompute {
|
||||
void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
|
||||
Tensor tensor) const {
|
||||
int nDims = tensor->getRank();
|
||||
auto strides = tensor->getStride();
|
||||
Tensor input_v_cache,
|
||||
Tensor position_id) const {
|
||||
int nDims = input_v_cache->getRank();
|
||||
auto strides = input_v_cache->getStride();
|
||||
IT_ASSERT(nDims == 4);
|
||||
IT_ASSERT(strides.size() == (size_t)nDims);
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
metadata.dimSize[i] = tensor->getDims().at(i);
|
||||
metadata.stride[i] = strides.at(i);
|
||||
int dim_position_id = position_id->getRank();
|
||||
metadata.num_seqs = 1;
|
||||
for (int i = 0; i < dim_position_id; i++) {
|
||||
metadata.num_seqs *= position_id->getDims().at(i);
|
||||
}
|
||||
metadata.head_dim = input_v_cache->getDims().at(3);
|
||||
metadata.num_heads = input_v_cache->getDims().at(1);
|
||||
metadata.max_kv_seqlen = input_v_cache->getDims().at(2);
|
||||
}
|
||||
|
||||
public:
|
||||
void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
|
||||
Tensor input_k, Tensor input_v, Tensor position_id,
|
||||
Tensor output_matmul, CudaPtr p_workspace) const {
|
||||
void do_compute(int dType, Tensor input_k_cache, Tensor input_v_cache,
|
||||
Tensor input_q, Tensor input_k, Tensor input_v,
|
||||
Tensor position_id, Tensor output_matmul,
|
||||
CudaPtr p_workspace) const {
|
||||
AttentionKVCacheMetadata metadata;
|
||||
initAttentionKVCacheMetadata(metadata, input_v_cache);
|
||||
initAttentionKVCacheMetadata(metadata, input_v_cache, position_id);
|
||||
|
||||
attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
|
||||
input_v_cache->getRawDataPtr<float *>(),
|
||||
input_q->getRawDataPtr<float *>(),
|
||||
input_k->getRawDataPtr<float *>(),
|
||||
input_v->getRawDataPtr<float *>(),
|
||||
position_id->getRawDataPtr<int *>(),
|
||||
output_matmul->getRawDataPtr<float *>(),
|
||||
metadata, (float *)p_workspace,
|
||||
(float *)(p_workspace + (1ll << 30)));
|
||||
attention_kvcache_kernel(
|
||||
dType, input_k_cache->getRawDataPtr<void *>(),
|
||||
input_v_cache->getRawDataPtr<void *>(),
|
||||
input_q->getRawDataPtr<void *>(), input_k->getRawDataPtr<void *>(),
|
||||
input_v->getRawDataPtr<void *>(),
|
||||
position_id->getRawDataPtr<int64_t *>(),
|
||||
output_matmul->getRawDataPtr<void *>(), metadata,
|
||||
(float *)p_workspace, (float *)(p_workspace + (1ll << 30)));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -41,15 +45,17 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
|
|||
public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
IT_ASSERT(_op->getDType() == DataType::Float32);
|
||||
auto op = as<AttentionKVCacheObj>(_op);
|
||||
int dType = op->getDType().getIndex();
|
||||
int position_idx_dtype = op->getInputs()[5]->getDTypeIndex();
|
||||
IT_ASSERT(dType == 1 || dType == 10 || position_idx_dtype == 7);
|
||||
|
||||
size_t workspaceSize = 2ll << 30;
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
CudaPtr idxWsData = context->getWorkspace(workspaceSize);
|
||||
do_compute(_op->getInputs()[0], _op->getInputs()[1],
|
||||
_op->getInputs()[2], _op->getInputs()[3],
|
||||
_op->getInputs()[4], _op->getInputs()[5],
|
||||
_op->getOutputs()[0], idxWsData);
|
||||
do_compute(dType, op->getInputs()[0], op->getInputs()[1],
|
||||
op->getInputs()[2], op->getInputs()[3], op->getInputs()[4],
|
||||
op->getInputs()[5], op->getOutputs()[0], idxWsData);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1,171 +1,236 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_attention_kvcache.h"
|
||||
#define WARP_SIZE 32
|
||||
#define BLOCKSIZE WARP_SIZE
|
||||
#define SEQ_UNIT 16
|
||||
#define BLOCKSIZE_2 WARP_SIZE*4
|
||||
#define MAX_PARTITION 1024
|
||||
|
||||
// ASSUME SEQ_LEN OF Q IS 1
|
||||
__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
|
||||
float* input_v_cache,
|
||||
float* input_q,
|
||||
float* input_k,
|
||||
float* input_v,
|
||||
int* position_id,
|
||||
template <class T>
|
||||
__global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
|
||||
T* input_v_cache,
|
||||
T* input_q,
|
||||
T* input_k,
|
||||
T* input_v,
|
||||
int64_t* position_id,
|
||||
AttentionKVCacheMetadata compMeta,
|
||||
float* output_O_temp,
|
||||
half* output_O_temp,
|
||||
float* output_sum_temp) {
|
||||
int seq_length = position_id[0] + 1;
|
||||
int seq_length = position_id[blockIdx.y] + 1;
|
||||
int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
|
||||
if(blockIdx.y >= stride)
|
||||
if(blockIdx.z >= stride)
|
||||
return;
|
||||
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
int group_id = threadIdx.x / WARP_SIZE;
|
||||
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
|
||||
int idx_seq = blockIdx.y * SEQ_UNIT;
|
||||
int lane_id_x2 = threadIdx.x % WARP_SIZE * 2;
|
||||
int parallel_idx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
|
||||
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
|
||||
return;
|
||||
int idx_seq = blockIdx.z * SEQ_UNIT;
|
||||
|
||||
float ptr_V[SEQ_UNIT*4]; // V
|
||||
float ptr_K[SEQ_UNIT*4]; // K
|
||||
float ptr_Q[4]; // Q
|
||||
float ptr_P[SEQ_UNIT] = {0};
|
||||
half reg_V[4];
|
||||
half reg_K[4];
|
||||
half reg_Q[4];
|
||||
float reg_P;
|
||||
|
||||
float ptr_O[4] = {0};
|
||||
float ptr_sum[1] = {0};
|
||||
float reg_O[4] = {0};
|
||||
float reg_sum = 0;
|
||||
float temp[4];
|
||||
bool is_fp16 = sizeof(T) == 2 ? true : false;
|
||||
|
||||
int idx_qkv = lane_id_x2 + parallel_idx * compMeta.head_dim;
|
||||
|
||||
// readin Q
|
||||
(float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
|
||||
int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
|
||||
|
||||
// Q*K
|
||||
#pragma unroll
|
||||
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
|
||||
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
|
||||
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
|
||||
= (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
|
||||
}
|
||||
else{
|
||||
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
|
||||
= (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
|
||||
(float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
|
||||
(float4 &)ptr_K[idx_SEQ_UNIT * 4];
|
||||
}
|
||||
|
||||
|
||||
if(!is_fp16){
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i ++){
|
||||
ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i];
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset /= 2) {
|
||||
ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset);
|
||||
}
|
||||
ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i];
|
||||
for(int i = 0; i < 4; i += 2){
|
||||
(float2 &)temp[i] = (float2 &)input_q[idx_qkv + i*WARP_SIZE];
|
||||
*((half2*)(®_Q[i])) = __float22half2_rn(*((float2*)(&temp[i])));
|
||||
}
|
||||
}
|
||||
|
||||
// div sqrt(d)
|
||||
#pragma unroll
|
||||
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
|
||||
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
|
||||
ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
|
||||
else{
|
||||
#pragma unroll
|
||||
for(int i = 0; i < 4; i += 2){
|
||||
(half2 &)reg_Q[i] = (half2 &)input_q[idx_qkv + i*WARP_SIZE];
|
||||
}
|
||||
}
|
||||
|
||||
// softmax
|
||||
int common_idx = lane_id_x2 + (parallel_idx * compMeta.max_kv_seqlen * compMeta.head_dim);
|
||||
|
||||
#pragma unroll
|
||||
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
|
||||
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
|
||||
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
|
||||
}
|
||||
|
||||
// * V
|
||||
#pragma unroll
|
||||
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
|
||||
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
|
||||
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
|
||||
= (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
|
||||
reg_P = 0;
|
||||
int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.head_dim);
|
||||
// readin K & V
|
||||
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
|
||||
#pragma unroll
|
||||
for(int i = 0; i < 4; i += 2){
|
||||
*((half2*)(®_K[i])) = *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE]));
|
||||
*((half2*)(®_V[i])) = *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE]));
|
||||
}
|
||||
}
|
||||
else{
|
||||
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
|
||||
= (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
|
||||
(float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
|
||||
= (float4 &)ptr_V[idx_SEQ_UNIT * 4];
|
||||
if(!is_fp16){
|
||||
#pragma unroll
|
||||
for(int i = 0; i < 4; i += 2){
|
||||
(float2 &)temp[i] = (float2 &) input_k[idx_qkv + i*WARP_SIZE];
|
||||
*((half2*)(®_K[i])) = __float22half2_rn(*((float2*)(&temp[i])));
|
||||
*((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(®_K[i]));
|
||||
(float2 &)temp[i] = (float2 &) input_v[idx_qkv + i*WARP_SIZE];
|
||||
*((half2*)(®_V[i])) = __float22half2_rn(*((float2*)(&temp[i])));
|
||||
*((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(®_V[i]));
|
||||
}
|
||||
}
|
||||
else{
|
||||
#pragma unroll
|
||||
for(int i = 0; i < 4; i += 2){
|
||||
(half2 &)reg_K[i] = (half2 &)input_k[idx_qkv + i*WARP_SIZE];
|
||||
*((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(®_K[i]));
|
||||
(half2 &)reg_V[i] = (half2 &)input_v[idx_qkv + i*WARP_SIZE];
|
||||
*((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(®_V[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Q*K
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i += 2){
|
||||
(half2 &)reg_K[i] = (half2 &)reg_Q[i] * (half2 &)reg_K[i];
|
||||
#pragma unroll
|
||||
for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
|
||||
(half2 &)reg_K[i] += __shfl_xor_sync(0xffffffff, (half2 &)reg_K[i], offset);
|
||||
}
|
||||
(float2 &) temp[i] = __half22float2((half2 &)reg_K[i]);
|
||||
reg_P += (temp[i] + temp[i+1]);
|
||||
(float2 &) temp[i] = __half22float2((half2 &)reg_V[i]);
|
||||
}
|
||||
|
||||
// div sqrt(d)
|
||||
reg_P /= sqrt(128.0);
|
||||
|
||||
// softmax
|
||||
reg_P = expf(reg_P);
|
||||
reg_sum += reg_P;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i ++)
|
||||
ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]);
|
||||
reg_O[i] = fmaf(reg_P, temp[i], reg_O[i]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i ++)
|
||||
ptr_O[i] /= ptr_sum[0];
|
||||
reg_O[i] /= reg_sum;
|
||||
|
||||
(float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
|
||||
if(lane_id == 0){
|
||||
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
|
||||
#pragma unroll
|
||||
for(int i = 0; i < 4; i += 2)
|
||||
(half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + (blockIdx.z * compMeta.head_dim) + (parallel_idx * compMeta.head_dim * stride)] = __float22half2_rn((float2 &)reg_O[i]);
|
||||
if(lane_id_x2 == 0){
|
||||
output_sum_temp[blockIdx.z + parallel_idx * stride] = reg_sum;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
__global__ void _attention_kvcache_kernel_128_2(int* position_id,
|
||||
float* output_matmul,
|
||||
|
||||
template <class T>
|
||||
__global__ void _attention_kvcache_kernel_128_2(int64_t* position_id,
|
||||
T* output_matmul,
|
||||
AttentionKVCacheMetadata compMeta,
|
||||
float* output_O_temp,
|
||||
half* output_O_temp,
|
||||
float* output_sum_temp) {
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
int group_id = threadIdx.x / WARP_SIZE;
|
||||
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
|
||||
int parallel_idx = blockIdx.x;
|
||||
int offset = parallel_idx * compMeta.head_dim;
|
||||
|
||||
|
||||
float ptr_O[4] = {0};
|
||||
float ptr_O_sum[4] = {0};
|
||||
float ptr_sum = 0;
|
||||
float ptr_sum_temp;
|
||||
int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
|
||||
bool is_fp16 = sizeof(T) == 2 ? true : false;
|
||||
|
||||
if(size == 1){
|
||||
if(!is_fp16){
|
||||
#pragma unroll
|
||||
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
|
||||
output_matmul[i + offset]
|
||||
= __half2float(output_O_temp[i + offset]);
|
||||
}
|
||||
else{
|
||||
#pragma unroll
|
||||
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
|
||||
output_matmul[i + offset]
|
||||
= output_O_temp[i + offset];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
__shared__ float shm_sum_temp[MAX_PARTITION];
|
||||
__shared__ float shm_sum[WARP_SIZE];
|
||||
float temp_sum = 0;
|
||||
|
||||
#pragma unroll
|
||||
for(int i = 0; i < size; i ++){
|
||||
(float4 &)ptr_O[0]
|
||||
= (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
|
||||
ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
|
||||
|
||||
#pragma unroll
|
||||
for(int k = 0; k < 4; k ++)
|
||||
ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
|
||||
ptr_sum += ptr_sum_temp;
|
||||
for(int i = threadIdx.x; i < size; i += blockDim.x){
|
||||
shm_sum_temp[i] = output_sum_temp[i + parallel_idx * size];
|
||||
temp_sum += shm_sum_temp[i];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int k = 0; k < 4; k ++)
|
||||
ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
|
||||
for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
|
||||
temp_sum += __shfl_down_sync(0xffffffff, temp_sum, offset);
|
||||
if(lane_id == 0)
|
||||
shm_sum[threadIdx.x/WARP_SIZE] = temp_sum;
|
||||
__syncthreads();
|
||||
temp_sum = lane_id < (size + WARP_SIZE - 1) / WARP_SIZE ? shm_sum[lane_id] : 0;
|
||||
|
||||
(float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
|
||||
#pragma unroll
|
||||
for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
|
||||
temp_sum += __shfl_xor_sync(0xffffffff, temp_sum, offset);
|
||||
temp_sum = __fdividef(1.0f, temp_sum + 1e-6f);
|
||||
|
||||
#pragma unroll
|
||||
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x){
|
||||
float acc = 0.0f;
|
||||
for(int j = 0; j < size; j ++){
|
||||
acc = fma(__half2float(output_O_temp[i + (j * compMeta.head_dim) + offset * size]) * shm_sum_temp[j], temp_sum, acc);
|
||||
}
|
||||
|
||||
if(!is_fp16){
|
||||
output_matmul[i + offset] = acc;
|
||||
}
|
||||
else{
|
||||
output_matmul[i + offset] = __float2half(acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
namespace infini {
|
||||
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
|
||||
float *input_q, float *input_k,
|
||||
float *input_v, int *position_id, float *output_matmul,
|
||||
void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cache,
|
||||
void *input_q, void *input_k,
|
||||
void *input_v, int64_t *position_id, void *output_matmul,
|
||||
const AttentionKVCacheMetadata &compMeta,
|
||||
float *output_O_temp, float *output_sum_temp) {
|
||||
IT_ASSERT(compMeta.dimSize[3] == 128);
|
||||
IT_ASSERT(dType == 1 || dType == 10);
|
||||
|
||||
int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT;
|
||||
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
|
||||
dim3 blockDim(BLOCKSIZE, 1);
|
||||
int gridsize_y = (compMeta.max_kv_seqlen - 1 + SEQ_UNIT) / SEQ_UNIT;
|
||||
dim3 gridDim(compMeta.num_heads, compMeta.num_seqs, gridsize_y);
|
||||
dim3 blockDim(WARP_SIZE, 1);
|
||||
|
||||
_attention_kvcache_kernel_128_1
|
||||
<<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
|
||||
compMeta, output_O_temp, output_sum_temp);
|
||||
if(dType == 1){
|
||||
_attention_kvcache_kernel_128_1<float>
|
||||
<<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
|
||||
((float*)input_k_cache, (float*)input_v_cache, (float*)input_q, (float*)input_k, (float*)input_v,
|
||||
position_id, compMeta, (half*)output_O_temp, output_sum_temp);
|
||||
|
||||
_attention_kvcache_kernel_128_2<float>
|
||||
<<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
|
||||
0, CUDAStream::getCurrentStream()>>>
|
||||
(position_id, (float*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
|
||||
}
|
||||
else{
|
||||
_attention_kvcache_kernel_128_1<half>
|
||||
<<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
|
||||
((half*)input_k_cache, (half*)input_v_cache, (half*)input_q, (half*)input_k, (half*)input_v,
|
||||
position_id, compMeta, (half*)output_O_temp, output_sum_temp);
|
||||
|
||||
_attention_kvcache_kernel_128_2<half>
|
||||
<<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
|
||||
0, CUDAStream::getCurrentStream()>>>
|
||||
(position_id, (half*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
|
||||
}
|
||||
|
||||
_attention_kvcache_kernel_128_2
|
||||
<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE,
|
||||
0, CUDAStream::getCurrentStream()>>>
|
||||
(position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -36,7 +36,7 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
|
|||
|
||||
cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
|
||||
if (cuDataType == CUDA_R_16F) {
|
||||
return CUBLAS_COMPUTE_32F_FAST_16F;
|
||||
return CUBLAS_COMPUTE_16F;
|
||||
} else if (cuDataType == CUDA_R_16BF) {
|
||||
return CUBLAS_COMPUTE_32F_FAST_16BF;
|
||||
} else if (cuDataType == CUDA_R_32F) {
|
||||
|
|
|
@ -18,17 +18,18 @@ class RoPECuda : public CudaKernelWithoutConfig {
|
|||
const auto &inputShape = input->getDims();
|
||||
int nDims = input->getDims().size();
|
||||
|
||||
int size = input->size();
|
||||
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
|
||||
IT_ASSERT(inputShape[1] == pos->getDims()[1]);
|
||||
IT_ASSERT(inputShape[0] == pos->getDims()[0] &&
|
||||
inputShape[1] == pos->getDims()[1]);
|
||||
int position_idx_dtype = op->getInputs()[0]->getDTypeIndex();
|
||||
int dim_model = inputShape[2];
|
||||
int dim_head = 128;
|
||||
int hidden_stride = dim_model * inputShape[1];
|
||||
int dim_head = 128; // TODO: get dim_head from the framework
|
||||
int pos_stride = inputShape[1];
|
||||
int batchsize = inputShape[0];
|
||||
|
||||
const int dType = op->getDType().getIndex();
|
||||
rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData,
|
||||
size, dim_model, dim_head, hidden_stride, pos_stride);
|
||||
rope_kernel(dType, pos->getRawDataPtr<int64_t *>(), inputData,
|
||||
outputData, dim_model, dim_head, batchsize, pos_stride);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -4,13 +4,15 @@
|
|||
#include "utils/small_array.h"
|
||||
|
||||
template <class T>
|
||||
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
|
||||
int dim_head, int hidden_stride, int pos_stride) {
|
||||
__global__ void _rope_kernel(int64_t* pos, void *in, void *out, int dim_model,
|
||||
int dim_head, int batchsize, int pos_stride) {
|
||||
int batch_id = blockIdx.x;
|
||||
int target_pos = pos[batch_id * pos_stride + blockIdx.y];
|
||||
|
||||
int ith = blockIdx.z * blockDim.x + threadIdx.x;
|
||||
int col = ith % dim_head;
|
||||
int offset = batch_id * hidden_stride + blockIdx.y * dim_model;
|
||||
int batch_stride = pos_stride * dim_model;
|
||||
int offset = batch_id * batch_stride + blockIdx.y * dim_model;
|
||||
|
||||
if (ith >= dim_model)
|
||||
return;
|
||||
|
@ -34,7 +36,7 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
|
|||
#define CASE(T) \
|
||||
_rope_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
|
||||
(pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride);
|
||||
(pos, input, output, dim_model, dim_head, batchsize, pos_stride);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
|
@ -79,10 +81,10 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
|
|||
}
|
||||
|
||||
namespace infini {
|
||||
void rope_kernel(int dType, int * pos, void *input, void *output, int size,
|
||||
int dim_model, int dim_head, int hidden_stride, int pos_stride) {
|
||||
void rope_kernel(int dType, int64_t * pos, void *input, void *output,
|
||||
int dim_model, int dim_head, int batchsize, int pos_stride) {
|
||||
dim3 blocksize = dim3(32,1,1);
|
||||
dim3 gridsize = dim3(1, 1, dim_model/32);
|
||||
dim3 gridsize = dim3(batchsize, pos_stride, dim_model/32);
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
|
||||
|
|
|
@ -26,11 +26,12 @@ TEST(RoPE, Cuda) {
|
|||
cudaRuntime->run(gCuda);
|
||||
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
|
||||
oCpu->printData();
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773}));
|
||||
0.540302, 0.647906, 0.731761, 0.796458, 0.846009, 0.883756, 0.912396,
|
||||
0.934062, 0.950415, 0.962739, 0.972014, 0.978989, 0.98423, 0.988167,
|
||||
0.991122, 0.99334, 0.995004, 0.996253, 0.99719, 0.997892, 0.998419,
|
||||
0.998815, 0.999111, 0.999333, 0.9995, 0.999625, 0.999719, 0.999789,
|
||||
0.999842, 0.999881, 0.999911, 0.999933}));
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue