Compare commits

...

17 Commits

Author SHA1 Message Date
xiaonans 4a5b9572bb add test scripts for llama2 and 9G models 2024-04-10 16:23:02 +08:00
xiaonans 159642d6ae merge master 2024-04-10 10:03:11 +08:00
xiaonans c01e64db50 rope and attention ops support multiple batchs/sequences. 2024-04-09 09:16:42 +08:00
xiaonans eb3a2d123d accelerate cuda attention 2024-03-28 09:07:30 +08:00
xiaonans 4bdd33522b accelerate cuda fp32 matmul 2024-03-26 11:37:54 +08:00
xiaonans 0740d26f43 clean up 2024-03-21 10:17:06 +08:00
xiaonans fc3d38f80e attention support fp16 2024-03-20 14:56:15 +08:00
xiaonans d43364ac60 inter-block communication is fp16 2024-03-19 11:21:14 +08:00
xiaonans db053e32a4 kv register is fp16 2024-03-18 17:25:57 +08:00
xiaonans 1e797d4ffe cache is fp16 2024-03-18 15:51:19 +08:00
xiaonans 80412ae162 fix bugs when blocksize==64 2024-03-18 15:31:52 +08:00
xiaonans 83be7fa373 fix bugs in rmsnorm op 2024-02-20 10:59:53 +08:00
xiaonans 0f1c04d864 add fp16 support to silu cuda op 2024-02-19 11:39:21 +08:00
xiaonans 936797b960 support rmsnorm 2024-02-08 14:58:47 +08:00
xiaonans 17bd98d453 modify rope op 2024-02-06 17:04:05 +08:00
xiaonans 8cc6af0a83 modify code to pass the cuda_all_reduce test 2024-02-06 10:53:32 +08:00
xiaonans c04910f118 [feature] add cudagraph support 2024-02-05 16:19:58 +08:00
13 changed files with 1281 additions and 169 deletions

@ -1 +1 @@
Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77
Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98

512
examples/python/test_9G.py Normal file
View File

@ -0,0 +1,512 @@
import os
from pyinfinitensor.onnx import OnnxStub, backend
import numpy as np
import onnx
import torch
from tqdm import tqdm
import onnx_graphsurgeon as gs
import time
import nvtx
import argparse
from mpi4py import MPI
from pytrie import StringTrie
import io
import json
import re
from typing import (
Dict,
List,
IO,
)
parser = argparse.ArgumentParser(description='')
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
parser.add_argument('--layer', dest='n_layers', type=int, default=48)
parser.add_argument("--num_nodes", dest='num_nodes',
type=int, default=1, help="number of nodes")
parser.add_argument("--world_size", dest="world_size",
type=int, default=1, help="")
parser.add_argument("--nproc_per_node", dest="nproc_per_node",
type=int, default=1, help="number of processes per node")
parser.add_argument("--n_max_length", dest="n_max_length",
type=int, default=1024, help="number of processes per node")
parser.add_argument("--vocab_size", dest="vocab_size",
type=int, default=119696, help="vocabulary size")
parser.add_argument("--hidden_size", dest="hidden_size",
type=int, default=4096, help="vocabulary size")
parser.add_argument('--rank', dest='rank', type=int, default=0)
parser.add_argument('--speedup', action='store_true')
parser.add_argument('--no_cudagraph', action='store_true')
parser.add_argument('--fp16', action='store_true')
args = parser.parse_args()
comm = MPI.COMM_WORLD
args.rank = comm.Get_rank()
args.nproc_per_node = comm.Get_size()
args.world_size = args.num_nodes * args.nproc_per_node
ONNX_MODEL_PATH = "/data3/shared/xnsong/9G/dist/9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format(
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
weight_path = "9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format(
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
model_dir = "/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/cpm9g-11b-sft.pt"
@gs.Graph.register()
def RMSNorm(self, a, b):
return self.layer(op="RMSNorm", inputs=a, outputs=b)
@gs.Graph.register()
def RoPE(self, a, b):
return self.layer(op="RoPE", inputs=a, outputs=b)
@gs.Graph.register()
def AttentionKVCache(self, a, b):
return self.layer(op="AttentionKVCache", inputs=a, outputs=b)
def to_numpy(dict):
ret = dict
if args.fp16:
ret = np.float16(ret)
else:
ret = np.float32(ret)
return ret
def parallel(array, split='replicate'):
if args.world_size > 1 and split == 'partial_column':
return np.hsplit(array, args.world_size)[args.rank]
elif args.world_size > 1 and split == 'partial_row':
return np.vsplit(array, args.world_size)[args.rank]
return array
def generate_onnx(ONNX_MODEL_PATH):
state_dict = torch.load(f'{model_dir}', map_location='cpu')
new_state_dict = {name: param.cpu().numpy()
for name, param in state_dict.items()
}
operators = []
graph = gs.Graph(nodes=operators)
gather_input = gs.Variable(name="gather_input.0", dtype=np.int64, shape=(1,1))
pos_input = gs.Variable(name="pos_input.0", dtype=np.int64, shape=(1,1))
embedding_weight = gs.Constant(name="embedding.weight", values=to_numpy(new_state_dict["input_embedding.weight"]))
gather_output = gs.Variable(name="gather_output.0")
gather = gs.Node(op="Gather", inputs=[embedding_weight, gather_input], outputs=[gather_output])
operators.append(gather)
input = gather_output
graph.inputs=[gather_input, pos_input]
graph.outputs=[]
for i in tqdm(range(args.n_layers)):
# global input
attn_kcache_input = gs.Variable(name="/layers." + str(i) + "/attn/kcache_input", dtype=np.float32, shape=(1,32,1023,128))
attn_vcache_input = gs.Variable(name="/layers." + str(i) + "/attn/vcache_input", dtype=np.float32, shape=(1,32,1023,128))
graph.inputs.append(attn_kcache_input)
graph.inputs.append(attn_vcache_input)
# weight
layernorm_0_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.0/mul_weight",
values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".self_att.layernorm_before_attention.weight"]))
attn_qproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/qproj_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_q.weight"]))
, 'partial_column'))
attn_kproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/kproj_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_k.weight"]))
, 'partial_column'))
attn_vproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/vproj_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_v.weight"]))
, 'partial_column'))
attn_outmatmul_input = gs.Constant(name="/layers." + str(i) + "/attn/outmatmul_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.attention_out.weight"]))
, 'partial_row'))
layernorm_1_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.1/mul_weight",
values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".ffn.layernorm_before_ffn.weight"]))
ffn_matmul_0_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_0_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_0.weight"]))
, 'partial_column'))
ffn_matmul_1_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_1_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_1.weight"]))
, 'partial_column'))
ffn_matmul_out_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_out_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_out.weight"]))
, 'partial_row'))
attn_qrope_output = gs.Variable(name="/layers." + str(i) + "/attn/qrope_output")
attn_krope_output = gs.Variable(name="/layers." + str(i) + "/attn/krope_output")
attn_kvcache_output = gs.Variable(name="/layers." + str(i) + "/attn/kvcache_output")
layernorm_0_mul_output_1 = gs.Variable(name="/layers." + str(i) + "/layernorm.0/mul_output_1")
layernorm_1_mul_output_1 = gs.Variable(name="/layers." + str(i) + "/layernorm.1/mul_output_1")
attn_qproj_output = gs.Variable(name="/layers." + str(i) + "/attn/qproj_output")
attn_kproj_output = gs.Variable(name="/layers." + str(i) + "/attn/kproj_output")
attn_vproj_output = gs.Variable(name="/layers." + str(i) + "/attn/vproj_output")
attn_outmatmul_output = gs.Variable(name="/layers." + str(i) + "/attn/outmatmul_output")
attn_outadd_output = gs.Variable(name="/layers." + str(i) + "/attn/outadd_output")
ffn_matmul_0_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_0_output")
ffn_silu_output = gs.Variable(name="/layers." + str(i) + "/ffn/silu_output")
ffn_matmul_1_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_1_output")
ffn_mul_output = gs.Variable(name="/layers." + str(i) + "/ffn/mul_output")
ffn_matmul_out_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_out_output")
ffn_add_output = gs.Variable(name="/layers." + str(i) + "/ffn/add_output")
graph.RMSNorm([input, layernorm_0_mul_weight], [layernorm_0_mul_output_1])
attn_qproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_qproj_weight], outputs=[attn_qproj_output])
operators.append(attn_qproj)
attn_kproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_kproj_weight], outputs=[attn_kproj_output])
operators.append(attn_kproj)
attn_vproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_vproj_weight], outputs=[attn_vproj_output])
operators.append(attn_vproj)
graph.RoPE([pos_input, attn_qproj_output], [attn_qrope_output])
graph.RoPE([pos_input, attn_kproj_output], [attn_krope_output])
graph.AttentionKVCache([attn_kcache_input, attn_vcache_input, attn_qrope_output, attn_krope_output, attn_vproj_output, pos_input],[attn_kvcache_output])
attn_outproj = gs.Node(op="MatMul", inputs=[attn_kvcache_output, attn_outmatmul_input], outputs=[attn_outmatmul_output])
operators.append(attn_outproj)
attn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/attn/reducesum_output")
if args.world_size > 1:
reduce_sum = gs.Node(op="ReduceSum", name="/layers." + str(i) + "/attn/reducesum",
inputs=[attn_outmatmul_output], outputs=[attn_reduce_sum_output],
attrs={"noop_with_empty_axes":1, "communicator":0})
graph.nodes.append(reduce_sum)
attn_outadd = gs.Node(op="Add", inputs=[input, attn_outmatmul_output if args.world_size == 1 else attn_reduce_sum_output], outputs=[attn_outadd_output])
operators.append(attn_outadd)
graph.RMSNorm([attn_outadd_output, layernorm_1_mul_weight], [layernorm_1_mul_output_1])
ffn_matmul_0 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_0_input], outputs=[ffn_matmul_0_output])
operators.append(ffn_matmul_0)
ffn_silu = gs.Node(op="Silu", inputs=[ffn_matmul_0_output], outputs=[ffn_silu_output])
operators.append(ffn_silu)
ffn_matmul_1 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_1_input], outputs=[ffn_matmul_1_output])
operators.append(ffn_matmul_1)
ffn_mul = gs.Node(op="Mul", inputs=[ffn_silu_output, ffn_matmul_1_output], outputs=[ffn_mul_output])
operators.append(ffn_mul)
ffn_matmul_out = gs.Node(op="MatMul", inputs=[ffn_mul_output, ffn_matmul_out_input], outputs=[ffn_matmul_out_output])
operators.append(ffn_matmul_out)
ffn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/ffn/reducesum_output")
if args.world_size > 1:
reduce_sum = gs.Node(op="ReduceSum", name="/layers." + str(i) + "/ffn/reducesum",
inputs=[ffn_matmul_out_output], outputs=[ffn_reduce_sum_output],
attrs={"noop_with_empty_axes":1, "communicator":0})
graph.nodes.append(reduce_sum)
ffn_add = gs.Node(op="Add", inputs=[attn_outadd_output, ffn_matmul_out_output if args.world_size == 1 else ffn_reduce_sum_output], outputs=[ffn_add_output])
operators.append(ffn_add)
input = ffn_add_output
layernorm_mul_weight = gs.Constant(name="/output/layernorm/mul_weight", values=to_numpy(new_state_dict["encoder.output_layernorm.weight"]))
layernorm_mul_output_1 = gs.Variable(name="/output/layernorm/mul_output_1")
graph.RMSNorm([input, layernorm_mul_weight], [layernorm_mul_output_1])
lm_head_weight = gs.Constant(name="/output/lm_head/weight", values=np.transpose(to_numpy(new_state_dict["lm_head.weight"])))
lm_head_output = gs.Variable(name="/output/lm_head/output")
lm_head = gs.Node(op="MatMul", inputs=[layernorm_mul_output_1, lm_head_weight], outputs=[lm_head_output])
operators.append(lm_head)
if args.fp16:
final_cast_output = gs.Variable(name="/output/cast/output", dtype=np.float32, shape=(1,1,args.vocab_size))
final_cast = gs.Node(op="Cast", inputs=[lm_head_output], outputs=[final_cast_output])
final_cast.attrs["to"] = np.float32
operators.append(final_cast)
graph.outputs.append(final_cast_output)
else:
lm_head_output.dtype=np.float32
lm_head_output.shape=(1,1,args.vocab_size)
graph.outputs.append(lm_head_output)
onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True, location=weight_path)
return
def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
"""Loads a vocabulary file into a dictionary."""
vocab: Dict[str, int] = {}
reader = io.TextIOWrapper(fp, encoding="utf-8")
for token in reader.readlines():
token = token.strip()
if len(token) == 0:
continue
token = json.loads(token)
vocab[token] = len(vocab)
return vocab
class CPM9GTokenizer(object):
def __init__(self, path):
self.unk_token = "<unk>"
self.bos_token = "<s>"
self.eos_token = "</s>"
self.byte_list = ["<0x0{}>".format(hex(i).upper()[2:]) for i in range(0x10)] + [
"<0x{}>".format(hex(i).upper()[2:]) for i in range(0x10, 0x100)
]
self._special_token_set = set([self.unk_token, self.bos_token, self.eos_token] + self.byte_list)
all_tokens = load_vocab(io.FileIO(path, "rb"))
self.encoder: Dict[str, int] = {}
self._special_encoder: Dict[str, int] = {}
for token, token_id in all_tokens.items():
if token in self._special_token_set:
self._special_encoder[token] = token_id
else:
self.encoder[token] = token_id
self.decoder = {v: k for k, v in self.encoder.items()}
self._byte_decoder = {self._special_encoder[token]: i for i, token in enumerate(self.byte_list)}
self._max_word_len = max([len(x) for x in self.encoder.keys()])
self._len_word_first = {}
for x in self.encoder.keys():
if not x[0] in self._len_word_first:
self._len_word_first[x[0]] = 1
if len(x) > self._len_word_first[x[0]]:
self._len_word_first[x[0]] = len(x)
self.tencoder = StringTrie(self.encoder)
def get_piece(self, text: str) -> str:
if text[0] in self._len_word_first:
text = text[: self._len_word_first[text[0]]]
len_text = len(text)
for i in range(len(text)):
sub = text[: len_text - i]
if sub in self.encoder:
return sub
return text[0]
@property
def vocab_size(self):
return len(self)
@property
def eos_id(self):
return self._special_encoder[self.eos_token]
@property
def bos_id(self):
return self._special_encoder[self.bos_token]
@property
def unk_id(self):
return self._special_encoder[self.unk_token]
def __len__(self):
return len(self.encoder) + len(self._special_encoder)
def tokenize(self, text: str) -> List[str]:
output_tokens: List[str] = []
st = 0
while st < len(text):
piece = self.get_piece(text[st:])
output_tokens.append(piece)
st += len(piece)
return output_tokens
@staticmethod
def escape(text: str) -> str:
return text
@staticmethod
def unescape(text: str) -> str:
return text
def encode(self, text: str, with_bos = True) -> List[int]:
ret = []
if with_bos:
ret.append(self.bos_id)
for x in self.tokenize(text):
if x in self.encoder:
ret.append(self.encoder[x])
else:
ret.extend(self._encode_unicode(x))
return ret
def decode(self, tokens: List[int]):
"""Decode ids into a string."""
ret = []
st = 0
while st < len(tokens):
if tokens[st] in self.decoder:
ret.append(self.decoder[tokens[st]])
st += 1
elif tokens[st] in self._byte_decoder:
first = self._byte_decoder[tokens[st]]
length = 1 if first < 128 else len(re.search('^1+0', bin(first)[2:])[0])-1
code = 0
try:
for j in range(length):
code = code << 8 | self._byte_decoder[tokens[st + j]]
code = int.to_bytes(code, length, "big").decode("utf-8")
ret.append(code)
except:
pass
st = st + length
elif tokens[st] == self.eos_id:
ret.append(self.eos_token)
st += 1
elif tokens[st] == self.bos_id:
ret.append(self.bos_token)
st += 1
else:
ret.append(self.unk_token)
st += 1
return "".join(ret)
def _encode_unicode(self, token):
# wrap unicode encoding into a helper function
ids = []
utf8_id = token.encode("utf-8")
for _id in utf8_id:
ids.append(self._special_encoder[self.byte_list[_id]])
return ids
def next_token(self, text):
# fast next token matching
token, token_id = self.tencoder.longest_prefix_item(text, (None, None))
if token is None:
token = text[0]
token_ids = self._encode_unicode(token)
else:
token_ids = [token_id]
return token, token_ids
def start_worker(
world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, query
):
model = onnx.load(ONNX_MODEL_PATH)
runtime = backend.CudaRuntime(local_rank)
if args.nproc_per_node > 1:
runtime.init_comm(
"9g",
world_size,
rank,
)
print("[{}] comm init.".format(rank))
stub = OnnxStub(model, runtime)
print("[{}] stub init.".format(rank))
for i in range(10):
if args.no_cudagraph:
stub.run()
else:
stub.run_with_cudagraph()
print("[{}] stub warmup.".format(rank))
tokenizer = CPM9GTokenizer("/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/vocabs.txt")
query = tokenizer.encode(query)
output_tokens = []
for i in range(len(query)):
q = np.array(query[i])
(list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist())
pos = i * np.ones((args.batchsize, 1), dtype=np.int64)
(list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist())
if args.no_cudagraph:
stub.run()
else:
stub.run_with_cudagraph()
if i == len(query) - 1:
output = np.array((list(stub.outputs.items()))[-1][1].copyout_float16()) if False \
else np.array((list(stub.outputs.items()))[-1][1].copyout_float())
q = np.argmax(output)
output_tokens.append(q)
avg_time = 0
count = 0
while i < 1000:
count = count + 1
torch.cuda.synchronize()
with nvtx.annotate("gen {}-th token".format(i), color="red"):
i = i + 1
(list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist())
pos = i * np.ones((args.batchsize, 1), dtype=np.int64)
(list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist())
t0 = time.time()
if args.no_cudagraph:
stub.run()
else:
stub.run_with_cudagraph()
t1 = time.time()
avg_time += t1 - t0
output = np.array((list(stub.outputs.items()))[-1][1].copyout_float16()) if False \
else np.array((list(stub.outputs.items()))[-1][1].copyout_float())
# print(output)
with nvtx.annotate("argmax".format(i), color="green"):
q = np.argmax(output)
if q == 2:
break
output_tokens.append(q)
avg_time = avg_time / count
print("avg_time_cost =", avg_time*1000, "ms")
text = tokenizer.decode(output_tokens)
return text
if __name__ == "__main__":
comm = MPI.COMM_WORLD
args.rank = comm.Get_rank()
args.nproc_per_node = comm.Get_size()
world_size = args.num_nodes * args.nproc_per_node
if not os.path.exists(ONNX_MODEL_PATH):
print("exporting onnx graph")
generate_onnx(ONNX_MODEL_PATH)
else:
print("will use exsiting onnx graph")
onnx_model = onnx.load(ONNX_MODEL_PATH)
print("data loaded")
#query = '''Beijing is the captial'''
#query = '''什么是PTX'''
#query = '''生病了怎么办?'''
#query = '''Happy'''
query = '''def gcd(a, b):'''
####################################
# infinitensor dist
####################################
# run distributed parallel.
pred = start_worker(world_size, args.rank, args.rank %
args.nproc_per_node, onnx_model, query)
if args.rank == 0:
print("输入:\n\n", query, "\n")
print("输出:", pred)

View File

@ -0,0 +1,491 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
from tqdm import tqdm
import argparse
import torch
import onnx
import onnx_graphsurgeon as gs
import os
import numpy as np
from pyinfinitensor.onnx import OnnxStub, backend
import time
import nvtx
from mpi4py import MPI
parser = argparse.ArgumentParser(description='')
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
parser.add_argument('--layer', dest='n_layers', type=int, default=32)
parser.add_argument("--num_nodes", dest='num_nodes',
type=int, default=1, help="number of nodes")
parser.add_argument("--nproc_per_node", dest="nproc_per_node",
type=int, default=1, help="number of processes per node")
parser.add_argument("--world_size", dest="world_size",
type=int, default=1, help="")
parser.add_argument("--n_max_length", dest="n_max_length",
type=int, default=1024, help="")
parser.add_argument("--vocab_size", dest="vocab_size",
type=int, default=32000, help="vocabulary size")
parser.add_argument("--hidden_size", dest="hidden_size",
type=int, default=4096)
parser.add_argument("--head_size", dest="head_size",
type=int, default=32)
parser.add_argument("--head_dim", dest="head_dim",
type=int, default=128)
parser.add_argument('--rank', dest='rank', type=int, default=0)
parser.add_argument('--no_cudagraph', action='store_true')
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--is_1st_graph', action='store_true')
parser.add_argument('--speedup', action='store_true')
args = parser.parse_args()
comm = MPI.COMM_WORLD
args.rank = comm.Get_rank()
args.nproc_per_node = comm.Get_size()
args.world_size = args.num_nodes * args.nproc_per_node
PRETRAINED_LLAMA_PATH = "/data0/shared/data/public/opensource_models/meta-llama/Llama-2-7b-hf/"
ONNX_MODEL_PATH = "/data3/shared/xnsong/llama2/" + ("1st" if args.is_1st_graph else "2nd")
ONNX_MODEL_ORIGIN_PATH = ONNX_MODEL_PATH + "/origin/llama2_origin_bs{}_layer{}.onnx".format(
args.batchsize, args.n_layers)
ONNX_MODEL_SIM_PATH = ONNX_MODEL_PATH + "/sim/llama2_sim_bs{}_layer{}.onnx".format(
args.batchsize, args.n_layers)
ONNX_MODEL_FUSION_PATH = ONNX_MODEL_PATH + "/fusion/llama2_fusion_bs{}_layer{}.onnx".format(
args.batchsize, args.n_layers)
ONNX_MODEL_SPECIAL_PATH = ONNX_MODEL_PATH + "/special/llama2_special_bs{}_layer{}.onnx".format(
args.batchsize, args.n_layers)
ONNX_MODEL_FP16_PATH = ONNX_MODEL_PATH + "/fp16/llama2_fp16_bs{}_layer{}.onnx".format(
args.batchsize, args.n_layers)
ONNX_MODEL_DIST_PATH = ONNX_MODEL_PATH + "/dist/llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format(
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
def parallel_model(onnx_model, world_size, rank):
graph = gs.import_onnx(onnx_model)
tmap = graph.tensors()
for i in range(args.n_layers):
tmap[graph.inputs[2+i*2].name].shape[1] = tmap[graph.inputs[2+i*2].name].shape[1]//world_size
tmap[graph.inputs[3+i*2].name].shape[1] = tmap[graph.inputs[3+i*2].name].shape[1]//world_size
for node in graph.nodes:
if node.name == "/model/layers." + str(i) + "/self_attn/q_proj/MatMul":
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
elif node.name == "/model/layers." + str(i) + "/self_attn/k_proj/MatMul":
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
elif node.name == "/model/layers." + str(i) + "/self_attn/v_proj/MatMul":
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
elif node.name == "/model/layers." + str(i) + "/self_attn/o_proj/MatMul":
node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank]
reduce_sum_output = gs.Variable("reduce_sum_output_" + str(i) + "_0",
dtype=np.float32)
reduce_sum = gs.Node(op="ReduceSum", name="reduce_sum_"+str(i)+"_0",
inputs=node.outputs, outputs=[reduce_sum_output],
attrs={"noop_with_empty_axes":1, "communicator":0})
graph.nodes.append(reduce_sum)
next_node = node.outputs[0].outputs[0]
next_node.inputs[1] = reduce_sum_output
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_0" or \
node.name == "/model/layers." + str(i) + "/self_attn/Reshape_1":
node.inputs[1].values = np.array(
[1, 1,
args.head_size//world_size,
args.hidden_size//args.head_size])
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_2":
node.inputs[1] = gs.Constant(name="/model/layers."+str(i)+"/self_attn/vreshape_input",
values=np.array(
[1, 1,
args.head_size//world_size,
args.hidden_size//args.head_size]))
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_3":
node.inputs[1] = gs.Constant(name="/model/layers." + str(i) + "/self_attn/Reshape_3_shape",
values=np.array(
[1, 1, args.hidden_size//world_size]))
elif node.name == "/model/layers." + str(i) + "/mlp/up_proj/MatMul":
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
elif node.name == "/model/layers." + str(i) + "/mlp/gate_proj/MatMul":
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
elif node.name == "/model/layers." + str(i) + "/mlp/down_proj/MatMul":
node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank]
reduce_sum_output_1 = gs.Variable("reduce_sum_output_" + str(i) + "_1",
dtype=np.float32)
reduce_sum_1 = gs.Node(op="ReduceSum", inputs=node.outputs, outputs=[reduce_sum_output_1],
attrs={"noop_with_empty_axes":1, "communicator":0})
graph.nodes.append(reduce_sum_1)
next_node = node.outputs[0].outputs[0]
next_node.inputs[1] = reduce_sum_output_1
# new_out_1 = tmap["/model/layers.0/mlp/down_proj/MatMul_output_0"] #reduce_sum_output
# new_out_1.dtype = np.float32
# new_out_1.shape = [1,1,4096]
# graph.outputs.append(new_out_1)
graph.cleanup(True).toposort()
return gs.export_onnx(graph)
def simplify(onnx_model):
graph = gs.import_onnx(onnx_model)
for node in graph.nodes:
if node.op == "Cast":
inp_node = node.i()
inp_node.outputs = node.outputs
node.outputs.clear()
for i in range(args.n_layers):
nodename = "/model/layers." + str(i) + "/self_attn/Add_2"
node = [node for node in graph.nodes if node.name == nodename][0]
inp_node = node.i()
inp_node.outputs = node.outputs
node.outputs.clear()
graph.cleanup().toposort()
return gs.export_onnx(graph)
@gs.Graph.register()
def replace_with_RMSNorm(self, inputs, outputs):
inputs[0].outputs.pop(0)
inputs[0].outputs.pop(0)
for out in outputs:
out.inputs.clear()
return self.layer(op="RMSNorm", inputs=inputs, outputs=outputs, name="rmsnorm")
@gs.Graph.register()
def replace_with_silu(self, inputs, outputs):
for inp in inputs:
inp.outputs.clear()
for out in outputs:
out.inputs.clear()
return self.layer(op="Silu", inputs=inputs, outputs=outputs, name="silu")
@gs.Graph.register()
def replace_with_RoPE(self, a, b):
return self.layer(op="RoPE", inputs=a, outputs=b, name="rope")
@gs.Graph.register()
def replace_with_attention(self, inputs, outputs, inputs_added, outputs_removed):
for inp in inputs:
inp.outputs.clear()
for out in outputs:
out.inputs.clear()
for inp in inputs_added:
inputs.append(inp)
for out in outputs_removed:
out.inputs.clear()
return self.layer(op="AttentionKVCache", inputs=inputs, outputs=outputs, name="attention")
def fusion(model):
graph = gs.import_onnx(model)
tmap = graph.tensors()
tmap["onnx::Reshape_1"].outputs.clear()
inputs = [tmap["/model/layers.0/input_layernorm/Cast_output_0"], tmap["model.layers.0.input_layernorm.weight"]]
rmsnorm_outputs = [tmap["/model/layers.0/input_layernorm/Mul_1_output_0"]]
graph.replace_with_RMSNorm(inputs, rmsnorm_outputs)
for i in range(args.n_layers):
# rotary embedding op
tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"].inputs.clear()
tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"].inputs.clear()
attn_qreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/qreshape_input",
values=np.array([1,1,args.head_size,args.hidden_size//args.head_size]))
attn_kreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/kreshape_input",
values=np.array([1,1,args.head_size,args.hidden_size//args.head_size]))
attn_qrope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qrope_output")
attn_krope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/krope_output")
attn_qreshape_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qreshape_output")
attn_kreshape_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/kreshape_output")
attn_qreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_0", inputs=[attn_qrope_output, attn_qreshape_input], outputs=[attn_qreshape_output])
attn_kreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_1", inputs=[attn_krope_output, attn_kreshape_input], outputs=[attn_kreshape_output])
attn_qtrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_qreshape_output],
outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"]])
attn_ktrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_kreshape_output],
outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"]])
graph.nodes.append(attn_qreshape)
graph.nodes.append(attn_kreshape)
graph.nodes.append(attn_qtrans)
graph.nodes.append(attn_ktrans)
inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/q_proj/MatMul_output_0"]]
graph.replace_with_RoPE(inputs, [attn_qrope_output])
inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/k_proj/MatMul_output_0"]]
graph.replace_with_RoPE(inputs, [attn_krope_output])
# rms-norm op
inputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Cast_output_0"], \
tmap["model.layers." + str(i) + ".post_attention_layernorm.weight"]]
outputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Mul_1_output_0"]]
graph.replace_with_RMSNorm(inputs, outputs)
inputs = [tmap["/model/layers." + str(i+1) + "/input_layernorm/Cast_output_0"] if i != args.n_layers-1 else \
tmap["/model/norm/Cast_output_0"], \
tmap["model.layers." + str(i+1) + ".input_layernorm.weight"] if i != args.n_layers-1 else \
tmap["model.norm.weight"]]
outputs = [tmap["/model/layers."+ str(i+1) + "/input_layernorm/Mul_1_output_0"]] if i != args.n_layers-1 else \
[tmap["/model/norm/Mul_1_output_0"]]
graph.replace_with_RMSNorm(inputs, outputs)
# silu op
inputs = [tmap["/model/layers." + str(i) + "/mlp/gate_proj/MatMul_output_0"]]
outputs = [tmap["/model/layers." + str(i) + "/mlp/act_fn/Mul_output_0"]]
graph.replace_with_silu(inputs, outputs)
inputs = [
tmap["onnx::Concat_" + str((i+1)*2)],
tmap["onnx::Concat_" + str((i+1)*2+1)],
tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"],
tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
outputs = [
tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"],]
inputs_added = [graph.inputs[1]]
outputs_removed = []
graph.replace_with_attention(
inputs, outputs, inputs_added, outputs_removed)
graph.outputs = [tmap[graph.outputs[0].name]]
graph.cleanup(True).toposort()
return gs.export_onnx(graph)
def special_pass(model):
graph = gs.import_onnx(model)
tmap = graph.tensors()
for node in graph.nodes:
if node.op == "Transpose" or node.op == "Reshape":
inp_node = node.i()
inp_node.outputs = node.outputs
node.outputs.clear()
graph.cleanup(True).toposort()
return gs.export_onnx(graph)
def convert_to_fp16(model):
graph = gs.import_onnx(model)
for node in graph.nodes:
if node.op == "Gather" and node.name == "/model/embed_tokens/Gather":
node.inputs[0].values = np.float16(node.inputs[0].values)
if node.op == "RMSNorm":
node.inputs[1].values = np.float16(node.inputs[1].values)
if node.op == "MatMul":
node.inputs[1].values = np.float16(node.inputs[1].values)
if node.name == "/lm_head/MatMul":
cast_1_out = gs.Variable(node.name+"_cast_out_output_0", dtype=np.float32, shape=node.outputs[0].shape)
cast_1 = gs.Node(op="Cast", inputs=[node.outputs[0]], outputs=[cast_1_out])
cast_1.attrs["to"] = np.float32
cast_1.name = node.name+"_cast_out_0"
graph.nodes.append(cast_1)
graph.outputs[0] = cast_1_out
node.outputs[0].dtype = np.float16
graph.cleanup(True).toposort()
return gs.export_onnx(graph)
def export_onnx(model: AutoModelForCausalLM):
if not os.path.exists(ONNX_MODEL_ORIGIN_PATH):
print("exporting origin onnx model...")
with torch.no_grad():
param = torch.zeros(
(args.batchsize, model.config.max_position_embeddings-1), dtype=torch.long)
logits = model(param, past_key_values=None)
if not args.is_1st_graph:
param_kvcache = torch.zeros((args.batchsize, 1), dtype=torch.long)
torch.onnx.export(model, (param_kvcache, {"past_key_values": logits.past_key_values,
"position_ids": param_kvcache}), \
ONNX_MODEL_ORIGIN_PATH, verbose=False,
do_constant_folding=True,)
else:
position_ids = torch.tile(torch.arange(0, model.config.max_position_embeddings-1), (args.batchsize, 1))
attention_mask = torch.ones((args.batchsize, model.config.max_position_embeddings-1), dtype=torch.bool)
torch.onnx.export(model, (param, {"attention_mask": attention_mask,
"position_ids": position_ids}),\
ONNX_MODEL_ORIGIN_PATH, verbose=False,
do_constant_folding=True,)
print("export origin onnx finished.")
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SIM_PATH):
print("exporting sim onnx model...")
onnx_model = onnx.load(ONNX_MODEL_ORIGIN_PATH)
onnx_model = simplify(onnx_model)
onnx.save(onnx_model, ONNX_MODEL_SIM_PATH, save_as_external_data=True, \
location="llama2_sim_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
print("exporting sim onnx model finished.")
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_FUSION_PATH):
print("exporting fusion onnx model...")
onnx_model = onnx.load(ONNX_MODEL_SIM_PATH)
onnx_model = fusion(onnx_model)
onnx.save(onnx_model, ONNX_MODEL_FUSION_PATH, save_as_external_data=True, \
location="llama2_fusion_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
print("exporting fusion onnx model finished.")
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SPECIAL_PATH):
print("exporting special onnx model...")
onnx_model = onnx.load(ONNX_MODEL_FUSION_PATH)
onnx_model = special_pass(onnx_model)
onnx.save(onnx_model, ONNX_MODEL_SPECIAL_PATH, save_as_external_data=True, \
location="llama2_special_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
print("exporting special onnx model finished.")
if not args.is_1st_graph and args.fp16 and not os.path.exists(ONNX_MODEL_FP16_PATH):
print("exporting fp16 onnx model...")
onnx_model = onnx.load(ONNX_MODEL_SPECIAL_PATH)
onnx_model = convert_to_fp16(onnx_model)
onnx.save(onnx_model, ONNX_MODEL_FP16_PATH, save_as_external_data=True, \
location="llama2_fp16_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
print("exporting fp16 onnx model finished.")
print("world_size =", args.world_size)
if not args.is_1st_graph and args.world_size > 1 and not os.path.exists(ONNX_MODEL_DIST_PATH):
print("exporting dist onnx model...")
onnx_model = onnx.load(ONNX_MODEL_FP16_PATH) if args.fp16 else onnx.load(ONNX_MODEL_SPECIAL_PATH)
onnx_model = parallel_model(onnx_model, args.world_size, args.rank)
onnx.save(onnx_model, ONNX_MODEL_DIST_PATH, save_as_external_data=True, \
location="llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format(
args.batchsize, args.n_layers,
16 if args.fp16 else 32, args.world_size, args.rank))
print("exporting dist onnx model finished.")
def get_it_logit(onnx_model, input_ids):
# initialization
runtime = backend.CudaRuntime(args.rank)
runtime.init_comm(
"dist",
args.world_size,
args.rank,
)
print("[{}] comm init.".format(args.rank))
stub = OnnxStub(onnx_model, runtime)
print("[{}] stub init.".format(args.rank))
# warm up
for i in range(10):
if args.no_cudagraph:
stub.run()
else:
stub.run_with_cudagraph()
print("[{}] stub warmup.".format(args.rank))
logits = np.zeros((args.batchsize, args.n_max_length, args.vocab_size), dtype=np.float32)
output_ids = np.zeros((args.batchsize, args.n_max_length), dtype=np.int64)
avg_inference_time = 0
t0 = time.time()
for i in tqdm(range(0, args.n_max_length)):
with nvtx.annotate("seq_length = {}".format(i), color="red"):
assert input_ids.shape[0] == args.batchsize
input_id = input_ids[:, i] if i < input_ids.shape[1] else output_ids[:, i-1]
position_id = i*np.ones((args.batchsize, 1), dtype=np.int32)
# copyin input
with nvtx.annotate("[it] copyin", color="blue"):
(list(stub.inputs.items()))[0][1].copyin_int64(
input_id.reshape(-1).tolist())
(list(stub.inputs.items()))[1][1].copyin_int64(
position_id.reshape(-1).tolist())
# run
t10 = time.time()
with nvtx.annotate("[it] run", color="green"):
if args.no_cudagraph:
stub.run()
else:
stub.run_with_cudagraph()
t11 = time.time()
avg_inference_time += (t11 - t10)
# copyout output
if not args.speedup:
with nvtx.annotate("[it] copyout", color="blue"):
logits[:,i, :] = np.array((list(stub.outputs.items()))[0][1].copyout_float()).reshape(args.batchsize, -1)
output_ids[:, i] = np.argmax(logits[:, i, :], -1).astype(np.int64)
t1 = time.time()
if args.rank == 0:
result = "[it] e2e: {} gpus, {} layers, e2e time: {:.2f}s, average inference time: {:.2f}ms"\
.format(args.num_nodes * args.nproc_per_node, args.n_layers, t1-t0, \
avg_inference_time*1000/args.n_max_length)
print(result)
del stub
return output_ids
if __name__ == "__main__":
torch_model = LlamaForCausalLM.from_pretrained(
PRETRAINED_LLAMA_PATH, num_hidden_layers=int(args.n_layers)).eval()
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LLAMA_PATH)
#prompt = "Hey, are you conscious? Can you talk to me?"
#prompt = "What is PTX?"
#prompt = "Tell me a joke."
#prompt = "What are the key principles of smart investing?"
prompt = "What is DeepSpeed?"
prompts=[prompt]*args.batchsize
inputs = tokenizer(prompts, return_tensors="pt")
input_ids = inputs.input_ids
print("prompt ids =", input_ids)
##########################################################
# inference with InfiniTensor
##########################################################
print("exporting onnx...")
export_onnx(torch_model)
print("exporting onnx finished.")
onnx_to_run_path = ONNX_MODEL_DIST_PATH if args.world_size > 1 else \
(ONNX_MODEL_FP16_PATH if args.fp16 else ONNX_MODEL_SPECIAL_PATH)
print("loading onnx", onnx_to_run_path, "...")
onnx_model = onnx.load(onnx_to_run_path)
print("loading onnx finished.")
output_ids_it = get_it_logit(onnx_model, input_ids)
it_output_text = tokenizer.batch_decode(output_ids_it[:, input_ids.shape[-1]:output_ids_it.shape[-1]])
if args.rank == 0:
for i in range(args.batchsize):
print("prompt: ", prompts[i])
print("answer: [it]", it_output_text[i])
##########################################################
# validation with pytorch
##########################################################
"""
generate_ids = torch_model.generate(inputs.input_ids, max_length=args.n_max_length)#, num_beams=4, do_sample=True)
outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"""
if not args.speedup and not args.is_1st_graph:
kvcache_torch = None
output_ids_pt = torch.zeros(args.batchsize, args.n_max_length).int() # + input_ids.shape[-1] - 1).int()
if args.fp16:
torch_model = torch_model.half()
torch_model = torch_model.cuda()
# print(torch.cuda.memory_summary())
avg_inference_time = 0
with torch.no_grad():
t0 = time.time()
for i in range(args.n_max_length):
input_id = input_ids[:,i] if i < input_ids.shape[1] else out_token
input_id = input_id.view(args.batchsize,1).cuda()
t00 = time.time()
outputs = torch_model(input_id, past_key_values=kvcache_torch)
t01 = time.time()
avg_inference_time += (t01-t00)
logits = outputs['logits']
kvcache_torch = outputs['past_key_values']
out_token = torch.argmax(logits, dim=-1)
output_ids_pt[:, i:i+1] = out_token
t1 = time.time()
avg_inference_time /= args.n_max_length
result = "[pt] e2e time: {:.2f}s, average inference time: {:.2f}ms"\
.format(t1-t0, avg_inference_time*1000)
if args.rank == 0:
print(result)
pt_output_text = tokenizer.batch_decode(output_ids_pt[:,input_ids.shape[-1]:args.n_max_length])
for i in range(args.batchsize):
print("[pt]", args.rank, pt_output_text[i])
if not args.is_1st_graph:
assert(output_ids_it.shape[-1] == args.n_max_length)
np.testing.assert_equal(output_ids_pt[:, input_ids.shape[-1]:args.n_max_length], output_ids_it[:,input_ids.shape[-1]:args.n_max_length])

View File

@ -3,14 +3,17 @@
#include <cstdio>
struct AttentionKVCacheMetadata {
int dimSize[4];
int stride[4];
int head_dim;
int num_heads;
int num_seqs;
int max_kv_seqlen;
};
namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
float *input_q, float *input_k, float *input_v,
int *position_id, float *output_matmul,
void attention_kvcache_kernel(int dType, void *input_k_cache,
void *input_v_cache, void *input_q, void *input_k,
void *input_v, int64_t *position_id,
void *output_matmul,
const AttentionKVCacheMetadata &compMeta,
float *output_O_temp, float *output_sum_temp);

View File

@ -5,8 +5,7 @@
namespace infini {
void rope_kernel(int dType, int *pos, void *input, void *output, int size,
int dim_model, int dim_head, int hidden_stride,
int pos_stride);
void rope_kernel(int dType, int64_t *pos, void *input, void *output,
int dim_model, int dim_head, int batchsize, int pos_stride);
}; // namespace infini

View File

@ -30,6 +30,10 @@ class AttentionKVCacheObj : public OperatorObj {
OP_CLONE(AttentionKVCacheObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
vector<DataType> inferDataType(const TensorVec &inputs) const override {
return {inputs[2]->getDType()};
};
DataType getDType() const { return getInputs(2)->getDType(); }
std::string toString() const override;
int numInputs() const override { return 6; }

View File

@ -21,6 +21,10 @@ class RoPEObj : public OperatorObj {
int numOutputs() const override { return 1; }
DataType getDType() const { return getInputs(1)->getDType(); }
vector<DataType> inferDataType(const TensorVec &inputs) const override {
return {inputs[1]->getDType()};
};
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;

View File

@ -208,9 +208,32 @@ class OnnxStub:
op[1],
)
elif node.op_type == "MatMul":
if node.input[1] in data.keys() \
and to_array(data[node.input[1]]).dtype == np.float32 \
and 'cuda_runtime' in dir(backend) \
and tensors[node.input[0]].shape()[0] == 1 \
and tensors[node.input[0]].shape()[1] == 1 \
and len(tensors[node.input[1]].shape()) == 2 \
and node.input[1] in data.keys():
data[node.input[1]] = from_array(
np.transpose(to_array(data[node.input[1]])))
tensors[node.input[1]] = self.handler.tensor(
[tensors[node.input[1]].shape()[1], tensors[node.input[1]].shape()[0]],
tensors[node.input[1]].dtype())
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]], # input
tensors[node.input[1]], # weight
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
False,
True,
None,
backend.ActType.Linear,
matmul_compute_type,
)
else:
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
False,
False,

View File

@ -7,33 +7,37 @@ namespace infini {
class AttentionKVCacheCompute {
void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
Tensor tensor) const {
int nDims = tensor->getRank();
auto strides = tensor->getStride();
Tensor input_v_cache,
Tensor position_id) const {
int nDims = input_v_cache->getRank();
auto strides = input_v_cache->getStride();
IT_ASSERT(nDims == 4);
IT_ASSERT(strides.size() == (size_t)nDims);
for (int i = 0; i < nDims; ++i) {
metadata.dimSize[i] = tensor->getDims().at(i);
metadata.stride[i] = strides.at(i);
int dim_position_id = position_id->getRank();
metadata.num_seqs = 1;
for (int i = 0; i < dim_position_id; i++) {
metadata.num_seqs *= position_id->getDims().at(i);
}
metadata.head_dim = input_v_cache->getDims().at(3);
metadata.num_heads = input_v_cache->getDims().at(1);
metadata.max_kv_seqlen = input_v_cache->getDims().at(2);
}
public:
void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v, Tensor position_id,
Tensor output_matmul, CudaPtr p_workspace) const {
void do_compute(int dType, Tensor input_k_cache, Tensor input_v_cache,
Tensor input_q, Tensor input_k, Tensor input_v,
Tensor position_id, Tensor output_matmul,
CudaPtr p_workspace) const {
AttentionKVCacheMetadata metadata;
initAttentionKVCacheMetadata(metadata, input_v_cache);
initAttentionKVCacheMetadata(metadata, input_v_cache, position_id);
attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
input_v_cache->getRawDataPtr<float *>(),
input_q->getRawDataPtr<float *>(),
input_k->getRawDataPtr<float *>(),
input_v->getRawDataPtr<float *>(),
position_id->getRawDataPtr<int *>(),
output_matmul->getRawDataPtr<float *>(),
metadata, (float *)p_workspace,
(float *)(p_workspace + (1ll << 30)));
attention_kvcache_kernel(
dType, input_k_cache->getRawDataPtr<void *>(),
input_v_cache->getRawDataPtr<void *>(),
input_q->getRawDataPtr<void *>(), input_k->getRawDataPtr<void *>(),
input_v->getRawDataPtr<void *>(),
position_id->getRawDataPtr<int64_t *>(),
output_matmul->getRawDataPtr<void *>(), metadata,
(float *)p_workspace, (float *)(p_workspace + (1ll << 30)));
}
};
@ -41,15 +45,17 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
IT_ASSERT(_op->getDType() == DataType::Float32);
auto op = as<AttentionKVCacheObj>(_op);
int dType = op->getDType().getIndex();
int position_idx_dtype = op->getInputs()[5]->getDTypeIndex();
IT_ASSERT(dType == 1 || dType == 10 || position_idx_dtype == 7);
size_t workspaceSize = 2ll << 30;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
CudaPtr idxWsData = context->getWorkspace(workspaceSize);
do_compute(_op->getInputs()[0], _op->getInputs()[1],
_op->getInputs()[2], _op->getInputs()[3],
_op->getInputs()[4], _op->getInputs()[5],
_op->getOutputs()[0], idxWsData);
do_compute(dType, op->getInputs()[0], op->getInputs()[1],
op->getInputs()[2], op->getInputs()[3], op->getInputs()[4],
op->getInputs()[5], op->getOutputs()[0], idxWsData);
}
};

View File

@ -1,171 +1,237 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_attention_kvcache.h"
#define WARP_SIZE 32
#define BLOCKSIZE WARP_SIZE
#define SEQ_UNIT 16
#define BLOCKSIZE_2 WARP_SIZE*4
#define MAX_PARTITION 1024
// ASSUME SEQ_LEN OF Q IS 1
__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
float* input_v_cache,
float* input_q,
float* input_k,
float* input_v,
int* position_id,
template <class T>
__global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
T* input_v_cache,
T* input_q,
T* input_k,
T* input_v,
int64_t* position_id,
AttentionKVCacheMetadata compMeta,
float* output_O_temp,
half* output_O_temp,
float* output_sum_temp) {
int seq_length = position_id[0] + 1;
int seq_length = position_id[blockIdx.y] + 1;
int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
if(blockIdx.y >= stride)
if(blockIdx.z >= stride)
return;
int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
int idx_seq = blockIdx.y * SEQ_UNIT;
int lane_id_x2 = threadIdx.x % WARP_SIZE * 2;
int parallel_idx = blockIdx.x + blockIdx.y * gridDim.x;
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
return;
int idx_seq = blockIdx.z * SEQ_UNIT;
float ptr_V[SEQ_UNIT*4]; // V
float ptr_K[SEQ_UNIT*4]; // K
float ptr_Q[4]; // Q
float ptr_P[SEQ_UNIT] = {0};
half reg_V[4];
half reg_K[4];
half reg_Q[4];
float reg_P;
float ptr_O[4] = {0};
float ptr_sum[1] = {0};
float reg_O[4] = {0};
float reg_sum = 0;
float temp[4];
bool is_fp16 = sizeof(T) == 2 ? true : false;
int idx_qkv = lane_id_x2 + parallel_idx * compMeta.head_dim;
// readin Q
(float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
if(!is_fp16){
#pragma unroll
for(int i = 0; i < 4; i += 2){
(float2 &)temp[i] = (float2 &)input_q[idx_qkv + i*WARP_SIZE];
*((half2*)(&reg_Q[i])) = __float22half2_rn(*((float2*)(&temp[i])));
}
}
else{
#pragma unroll
for(int i = 0; i < 4; i += 2){
(half2 &)reg_Q[i] = (half2 &)input_q[idx_qkv + i*WARP_SIZE];
}
}
int common_idx = lane_id_x2 + (parallel_idx * compMeta.max_kv_seqlen * compMeta.head_dim);
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
reg_P = 0;
int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.head_dim);
// readin K & V
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
#pragma unroll
for(int i = 0; i < 4; i += 2){
*((half2*)(&reg_K[i])) = *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE]));
*((half2*)(&reg_V[i])) = *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE]));
}
}
else{
if(!is_fp16){
#pragma unroll
for(int i = 0; i < 4; i += 2){
(float2 &)temp[i] = (float2 &) input_k[idx_qkv + i*WARP_SIZE];
*((half2*)(&reg_K[i])) = __float22half2_rn(*((float2*)(&temp[i])));
*((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_K[i]));
(float2 &)temp[i] = (float2 &) input_v[idx_qkv + i*WARP_SIZE];
*((half2*)(&reg_V[i])) = __float22half2_rn(*((float2*)(&temp[i])));
*((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_V[i]));
}
}
else{
#pragma unroll
for(int i = 0; i < 4; i += 2){
(half2 &)reg_K[i] = (half2 &)input_k[idx_qkv + i*WARP_SIZE];
*((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_K[i]));
(half2 &)reg_V[i] = (half2 &)input_v[idx_qkv + i*WARP_SIZE];
*((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_V[i]));
}
}
}
// Q*K
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float4 &)ptr_K[idx_SEQ_UNIT * 4];
}
for (int i = 0; i < 4; i += 2){
(half2 &)reg_K[i] = (half2 &)reg_Q[i] * (half2 &)reg_K[i];
#pragma unroll
for (int i = 0; i < 4; i ++){
ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i];
for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
(half2 &)reg_K[i] += __shfl_xor_sync(0xffffffff, (half2 &)reg_K[i], offset);
}
(float2 &) temp[i] = __half22float2((half2 &)reg_K[i]);
reg_P += (temp[i] + temp[i+1]);
(float2 &) temp[i] = __half22float2((half2 &)reg_V[i]);
}
// div sqrt(d)
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
}
reg_P /= sqrt(128.0);
// softmax
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
}
reg_P = expf(reg_P);
reg_sum += reg_P;
// * V
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
= (float4 &)ptr_V[idx_SEQ_UNIT * 4];
for (int i = 0; i < 4; i ++)
reg_O[i] = fmaf(reg_P, temp[i], reg_O[i]);
}
#pragma unroll
for (int i = 0; i < 4; i ++)
ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]);
}
reg_O[i] /= reg_sum;
#pragma unroll
for (int i = 0; i < 4; i ++)
ptr_O[i] /= ptr_sum[0];
(float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
if(lane_id == 0){
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
for(int i = 0; i < 4; i += 2)
(half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + (blockIdx.z * compMeta.head_dim) + (parallel_idx * compMeta.head_dim * stride)] = __float22half2_rn((float2 &)reg_O[i]);
if(lane_id_x2 == 0){
output_sum_temp[blockIdx.z + parallel_idx * stride] = reg_sum;
}
}
__global__ void _attention_kvcache_kernel_128_2(int* position_id,
float* output_matmul,
template <class T>
__global__ void _attention_kvcache_kernel_128_2(int64_t* position_id,
T* output_matmul,
AttentionKVCacheMetadata compMeta,
float* output_O_temp,
half* output_O_temp,
float* output_sum_temp) {
int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
int parallel_idx = blockIdx.x;
int offset = parallel_idx * compMeta.head_dim;
float ptr_O[4] = {0};
float ptr_O_sum[4] = {0};
float ptr_sum = 0;
float ptr_sum_temp;
int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
bool is_fp16 = sizeof(T) == 2 ? true : false;
if(size == 1){
if(!is_fp16){
#pragma unroll
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
output_matmul[i + offset]
= __half2float(output_O_temp[i + offset]);
}
else{
#pragma unroll
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
output_matmul[i + offset]
= output_O_temp[i + offset];
}
return;
}
__shared__ float shm_sum_temp[MAX_PARTITION];
__shared__ float shm_sum[WARP_SIZE];
float temp_sum = 0;
#pragma unroll
for(int i = 0; i < size; i ++){
(float4 &)ptr_O[0]
= (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
#pragma unroll
for(int k = 0; k < 4; k ++)
ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
ptr_sum += ptr_sum_temp;
for(int i = threadIdx.x; i < size; i += blockDim.x){
shm_sum_temp[i] = output_sum_temp[i + parallel_idx * size];
temp_sum += shm_sum_temp[i];
}
#pragma unroll
for(int k = 0; k < 4; k ++)
ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
temp_sum += __shfl_down_sync(0xffffffff, temp_sum, offset);
if(lane_id == 0)
shm_sum[threadIdx.x/WARP_SIZE] = temp_sum;
__syncthreads();
temp_sum = lane_id < (size + WARP_SIZE - 1) / WARP_SIZE ? shm_sum[lane_id] : 0;
(float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
#pragma unroll
for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
temp_sum += __shfl_xor_sync(0xffffffff, temp_sum, offset);
temp_sum = __fdividef(1.0f, temp_sum + 1e-6f);
#pragma unroll
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x){
float acc = 0.0f;
for(int j = 0; j < size; j ++){
acc = fma(__half2float(output_O_temp[i + (j * compMeta.head_dim) + offset * size]) * shm_sum_temp[j], temp_sum, acc);
}
if(!is_fp16){
output_matmul[i + offset] = acc;
}
else{
output_matmul[i + offset] = __float2half(acc);
}
}
}
namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
float *input_q, float *input_k,
float *input_v, int *position_id, float *output_matmul,
void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cache,
void *input_q, void *input_k,
void *input_v, int64_t *position_id, void *output_matmul,
const AttentionKVCacheMetadata &compMeta,
float *output_O_temp, float *output_sum_temp) {
IT_ASSERT(compMeta.dimSize[3] == 128);
IT_ASSERT(dType == 1 || dType == 10);
int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT;
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
dim3 blockDim(BLOCKSIZE, 1);
int gridsize_y = (compMeta.max_kv_seqlen - 1 + SEQ_UNIT) / SEQ_UNIT;
dim3 gridDim(compMeta.num_heads, compMeta.num_seqs, gridsize_y);
dim3 blockDim(WARP_SIZE, 1);
_attention_kvcache_kernel_128_1
if(dType == 1){
_attention_kvcache_kernel_128_1<float>
<<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
(input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
compMeta, output_O_temp, output_sum_temp);
((float*)input_k_cache, (float*)input_v_cache, (float*)input_q, (float*)input_k, (float*)input_v,
position_id, compMeta, (half*)output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_2
<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE,
_attention_kvcache_kernel_128_2<float>
<<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
0, CUDAStream::getCurrentStream()>>>
(position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
(position_id, (float*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
}
else{
_attention_kvcache_kernel_128_1<half>
<<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
((half*)input_k_cache, (half*)input_v_cache, (half*)input_q, (half*)input_k, (half*)input_v,
position_id, compMeta, (half*)output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_2<half>
<<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
0, CUDAStream::getCurrentStream()>>>
(position_id, (half*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
}
}
} // namespace infini

View File

@ -36,7 +36,7 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
if (cuDataType == CUDA_R_16F) {
return CUBLAS_COMPUTE_32F_FAST_16F;
return CUBLAS_COMPUTE_16F;
} else if (cuDataType == CUDA_R_16BF) {
return CUBLAS_COMPUTE_32F_FAST_16BF;
} else if (cuDataType == CUDA_R_32F) {

View File

@ -18,17 +18,19 @@ class RoPECuda : public CudaKernelWithoutConfig {
const auto &inputShape = input->getDims();
int nDims = input->getDims().size();
int size = input->size();
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
IT_ASSERT(inputShape[1] == pos->getDims()[1]);
IT_ASSERT(inputShape[0] == pos->getDims()[0] &&
inputShape[1] == pos->getDims()[1]);
int position_idx_dtype = op->getInputs()[0]->getDTypeIndex();
IT_ASSERT(position_idx_dtype == 7);
int dim_model = inputShape[2];
int dim_head = 128;
int hidden_stride = dim_model * inputShape[1];
int dim_head = 128; // TODO: get dim_head from the framework
int pos_stride = inputShape[1];
int batchsize = inputShape[0];
const int dType = op->getDType().getIndex();
rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData,
size, dim_model, dim_head, hidden_stride, pos_stride);
rope_kernel(dType, pos->getRawDataPtr<int64_t *>(), inputData,
outputData, dim_model, dim_head, batchsize, pos_stride);
}
};

View File

@ -4,13 +4,15 @@
#include "utils/small_array.h"
template <class T>
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
int dim_head, int hidden_stride, int pos_stride) {
__global__ void _rope_kernel(int64_t* pos, void *in, void *out, int dim_model,
int dim_head, int batchsize, int pos_stride) {
int batch_id = blockIdx.x;
int target_pos = pos[batch_id * pos_stride + blockIdx.y];
int ith = blockIdx.z * blockDim.x + threadIdx.x;
int col = ith % dim_head;
int offset = batch_id * hidden_stride + blockIdx.y * dim_model;
int batch_stride = pos_stride * dim_model;
int offset = batch_id * batch_stride + blockIdx.y * dim_model;
if (ith >= dim_model)
return;
@ -34,7 +36,7 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
#define CASE(T) \
_rope_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride);
(pos, input, output, dim_model, dim_head, batchsize, pos_stride);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
@ -79,10 +81,10 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
}
namespace infini {
void rope_kernel(int dType, int * pos, void *input, void *output, int size,
int dim_model, int dim_head, int hidden_stride, int pos_stride) {
void rope_kernel(int dType, int64_t * pos, void *input, void *output,
int dim_model, int dim_head, int batchsize, int pos_stride) {
dim3 blocksize = dim3(32,1,1);
dim3 gridsize = dim3(1, 1, dim_model/32);
dim3 gridsize = dim3(batchsize, pos_stride, dim_model/32);
SWITCH_DTYPE(dType)
}