From 4a5b9572bb331894aa81b8fc9cec8bc8e2eb7828 Mon Sep 17 00:00:00 2001 From: xiaonans Date: Wed, 10 Apr 2024 16:23:02 +0800 Subject: [PATCH] add test scripts for llama2 and 9G models --- examples/python/test_9G.py | 512 ++++++++++++++++++++++++++++++ examples/python/test_llama2_7b.py | 491 ++++++++++++++++++++++++++++ 2 files changed, 1003 insertions(+) create mode 100644 examples/python/test_9G.py create mode 100644 examples/python/test_llama2_7b.py diff --git a/examples/python/test_9G.py b/examples/python/test_9G.py new file mode 100644 index 00000000..6fc96f7f --- /dev/null +++ b/examples/python/test_9G.py @@ -0,0 +1,512 @@ +import os +from pyinfinitensor.onnx import OnnxStub, backend +import numpy as np +import onnx +import torch +from tqdm import tqdm +import onnx_graphsurgeon as gs +import time +import nvtx +import argparse +from mpi4py import MPI +from pytrie import StringTrie +import io +import json +import re +from typing import ( + Dict, + List, + IO, +) + +parser = argparse.ArgumentParser(description='') +parser.add_argument('--batchsize', dest='batchsize', type=int, default=1) +parser.add_argument('--layer', dest='n_layers', type=int, default=48) +parser.add_argument("--num_nodes", dest='num_nodes', + type=int, default=1, help="number of nodes") +parser.add_argument("--world_size", dest="world_size", + type=int, default=1, help="") +parser.add_argument("--nproc_per_node", dest="nproc_per_node", + type=int, default=1, help="number of processes per node") +parser.add_argument("--n_max_length", dest="n_max_length", + type=int, default=1024, help="number of processes per node") +parser.add_argument("--vocab_size", dest="vocab_size", + type=int, default=119696, help="vocabulary size") +parser.add_argument("--hidden_size", dest="hidden_size", + type=int, default=4096, help="vocabulary size") +parser.add_argument('--rank', dest='rank', type=int, default=0) +parser.add_argument('--speedup', action='store_true') +parser.add_argument('--no_cudagraph', action='store_true') +parser.add_argument('--fp16', action='store_true') +args = parser.parse_args() +comm = MPI.COMM_WORLD +args.rank = comm.Get_rank() +args.nproc_per_node = comm.Get_size() +args.world_size = args.num_nodes * args.nproc_per_node + +ONNX_MODEL_PATH = "/data3/shared/xnsong/9G/dist/9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format( + args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank) + +weight_path = "9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format( + args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank) + +model_dir = "/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/cpm9g-11b-sft.pt" + +@gs.Graph.register() +def RMSNorm(self, a, b): + return self.layer(op="RMSNorm", inputs=a, outputs=b) + +@gs.Graph.register() +def RoPE(self, a, b): + return self.layer(op="RoPE", inputs=a, outputs=b) + +@gs.Graph.register() +def AttentionKVCache(self, a, b): + return self.layer(op="AttentionKVCache", inputs=a, outputs=b) + +def to_numpy(dict): + ret = dict + if args.fp16: + ret = np.float16(ret) + else: + ret = np.float32(ret) + return ret + +def parallel(array, split='replicate'): + if args.world_size > 1 and split == 'partial_column': + return np.hsplit(array, args.world_size)[args.rank] + elif args.world_size > 1 and split == 'partial_row': + return np.vsplit(array, args.world_size)[args.rank] + return array + + +def generate_onnx(ONNX_MODEL_PATH): + state_dict = torch.load(f'{model_dir}', map_location='cpu') + new_state_dict = {name: param.cpu().numpy() + 
for name, param in state_dict.items() + } + + operators = [] + graph = gs.Graph(nodes=operators) + gather_input = gs.Variable(name="gather_input.0", dtype=np.int64, shape=(1,1)) + pos_input = gs.Variable(name="pos_input.0", dtype=np.int64, shape=(1,1)) + + embedding_weight = gs.Constant(name="embedding.weight", values=to_numpy(new_state_dict["input_embedding.weight"])) + gather_output = gs.Variable(name="gather_output.0") + gather = gs.Node(op="Gather", inputs=[embedding_weight, gather_input], outputs=[gather_output]) + operators.append(gather) + input = gather_output + + graph.inputs=[gather_input, pos_input] + graph.outputs=[] + + for i in tqdm(range(args.n_layers)): + # global input + attn_kcache_input = gs.Variable(name="/layers." + str(i) + "/attn/kcache_input", dtype=np.float32, shape=(1,32,1023,128)) + attn_vcache_input = gs.Variable(name="/layers." + str(i) + "/attn/vcache_input", dtype=np.float32, shape=(1,32,1023,128)) + graph.inputs.append(attn_kcache_input) + graph.inputs.append(attn_vcache_input) + + # weight + layernorm_0_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.0/mul_weight", + values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".self_att.layernorm_before_attention.weight"])) + attn_qproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/qproj_weight", + values=parallel( + np.transpose( + to_numpy( + new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_q.weight"])) + , 'partial_column')) + attn_kproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/kproj_weight", + values=parallel( + np.transpose( + to_numpy( + new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_k.weight"])) + , 'partial_column')) + attn_vproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/vproj_weight", + values=parallel( + np.transpose( + to_numpy( + new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_v.weight"])) + , 'partial_column')) + attn_outmatmul_input = gs.Constant(name="/layers." + str(i) + "/attn/outmatmul_weight", + values=parallel( + np.transpose( + to_numpy( + new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.attention_out.weight"])) + , 'partial_row')) + + layernorm_1_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.1/mul_weight", + values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".ffn.layernorm_before_ffn.weight"])) + ffn_matmul_0_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_0_weight", + values=parallel( + np.transpose( + to_numpy( + new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_0.weight"])) + , 'partial_column')) + ffn_matmul_1_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_1_weight", + values=parallel( + np.transpose( + to_numpy( + new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_1.weight"])) + , 'partial_column')) + ffn_matmul_out_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_out_weight", + values=parallel( + np.transpose( + to_numpy( + new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_out.weight"])) + , 'partial_row')) + + attn_qrope_output = gs.Variable(name="/layers." + str(i) + "/attn/qrope_output") + attn_krope_output = gs.Variable(name="/layers." + str(i) + "/attn/krope_output") + attn_kvcache_output = gs.Variable(name="/layers." + str(i) + "/attn/kvcache_output") + layernorm_0_mul_output_1 = gs.Variable(name="/layers." + str(i) + "/layernorm.0/mul_output_1") + layernorm_1_mul_output_1 = gs.Variable(name="/layers." 
+ str(i) + "/layernorm.1/mul_output_1") + attn_qproj_output = gs.Variable(name="/layers." + str(i) + "/attn/qproj_output") + attn_kproj_output = gs.Variable(name="/layers." + str(i) + "/attn/kproj_output") + attn_vproj_output = gs.Variable(name="/layers." + str(i) + "/attn/vproj_output") + attn_outmatmul_output = gs.Variable(name="/layers." + str(i) + "/attn/outmatmul_output") + attn_outadd_output = gs.Variable(name="/layers." + str(i) + "/attn/outadd_output") + ffn_matmul_0_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_0_output") + ffn_silu_output = gs.Variable(name="/layers." + str(i) + "/ffn/silu_output") + ffn_matmul_1_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_1_output") + ffn_mul_output = gs.Variable(name="/layers." + str(i) + "/ffn/mul_output") + ffn_matmul_out_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_out_output") + ffn_add_output = gs.Variable(name="/layers." + str(i) + "/ffn/add_output") + + graph.RMSNorm([input, layernorm_0_mul_weight], [layernorm_0_mul_output_1]) + attn_qproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_qproj_weight], outputs=[attn_qproj_output]) + operators.append(attn_qproj) + attn_kproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_kproj_weight], outputs=[attn_kproj_output]) + operators.append(attn_kproj) + attn_vproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_vproj_weight], outputs=[attn_vproj_output]) + operators.append(attn_vproj) + graph.RoPE([pos_input, attn_qproj_output], [attn_qrope_output]) + graph.RoPE([pos_input, attn_kproj_output], [attn_krope_output]) + graph.AttentionKVCache([attn_kcache_input, attn_vcache_input, attn_qrope_output, attn_krope_output, attn_vproj_output, pos_input],[attn_kvcache_output]) + attn_outproj = gs.Node(op="MatMul", inputs=[attn_kvcache_output, attn_outmatmul_input], outputs=[attn_outmatmul_output]) + operators.append(attn_outproj) + + attn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/attn/reducesum_output") + if args.world_size > 1: + reduce_sum = gs.Node(op="ReduceSum", name="/layers." + str(i) + "/attn/reducesum", + inputs=[attn_outmatmul_output], outputs=[attn_reduce_sum_output], + attrs={"noop_with_empty_axes":1, "communicator":0}) + graph.nodes.append(reduce_sum) + + attn_outadd = gs.Node(op="Add", inputs=[input, attn_outmatmul_output if args.world_size == 1 else attn_reduce_sum_output], outputs=[attn_outadd_output]) + operators.append(attn_outadd) + + graph.RMSNorm([attn_outadd_output, layernorm_1_mul_weight], [layernorm_1_mul_output_1]) + + ffn_matmul_0 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_0_input], outputs=[ffn_matmul_0_output]) + operators.append(ffn_matmul_0) + ffn_silu = gs.Node(op="Silu", inputs=[ffn_matmul_0_output], outputs=[ffn_silu_output]) + operators.append(ffn_silu) + ffn_matmul_1 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_1_input], outputs=[ffn_matmul_1_output]) + operators.append(ffn_matmul_1) + ffn_mul = gs.Node(op="Mul", inputs=[ffn_silu_output, ffn_matmul_1_output], outputs=[ffn_mul_output]) + operators.append(ffn_mul) + ffn_matmul_out = gs.Node(op="MatMul", inputs=[ffn_mul_output, ffn_matmul_out_input], outputs=[ffn_matmul_out_output]) + operators.append(ffn_matmul_out) + + ffn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/ffn/reducesum_output") + if args.world_size > 1: + reduce_sum = gs.Node(op="ReduceSum", name="/layers." 
+ str(i) + "/ffn/reducesum", + inputs=[ffn_matmul_out_output], outputs=[ffn_reduce_sum_output], + attrs={"noop_with_empty_axes":1, "communicator":0}) + graph.nodes.append(reduce_sum) + + ffn_add = gs.Node(op="Add", inputs=[attn_outadd_output, ffn_matmul_out_output if args.world_size == 1 else ffn_reduce_sum_output], outputs=[ffn_add_output]) + operators.append(ffn_add) + input = ffn_add_output + + layernorm_mul_weight = gs.Constant(name="/output/layernorm/mul_weight", values=to_numpy(new_state_dict["encoder.output_layernorm.weight"])) + layernorm_mul_output_1 = gs.Variable(name="/output/layernorm/mul_output_1") + + graph.RMSNorm([input, layernorm_mul_weight], [layernorm_mul_output_1]) + + lm_head_weight = gs.Constant(name="/output/lm_head/weight", values=np.transpose(to_numpy(new_state_dict["lm_head.weight"]))) + lm_head_output = gs.Variable(name="/output/lm_head/output") + lm_head = gs.Node(op="MatMul", inputs=[layernorm_mul_output_1, lm_head_weight], outputs=[lm_head_output]) + operators.append(lm_head) + + if args.fp16: + final_cast_output = gs.Variable(name="/output/cast/output", dtype=np.float32, shape=(1,1,args.vocab_size)) + final_cast = gs.Node(op="Cast", inputs=[lm_head_output], outputs=[final_cast_output]) + final_cast.attrs["to"] = np.float32 + operators.append(final_cast) + graph.outputs.append(final_cast_output) + else: + lm_head_output.dtype=np.float32 + lm_head_output.shape=(1,1,args.vocab_size) + graph.outputs.append(lm_head_output) + + onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True, location=weight_path) + return + + +def load_vocab(fp: IO[bytes]) -> Dict[str, int]: + """Loads a vocabulary file into a dictionary.""" + vocab: Dict[str, int] = {} + + reader = io.TextIOWrapper(fp, encoding="utf-8") + for token in reader.readlines(): + token = token.strip() + if len(token) == 0: + continue + token = json.loads(token) + vocab[token] = len(vocab) + return vocab + + +class CPM9GTokenizer(object): + def __init__(self, path): + self.unk_token = "" + self.bos_token = "" + self.eos_token = "" + self.byte_list = ["<0x0{}>".format(hex(i).upper()[2:]) for i in range(0x10)] + [ + "<0x{}>".format(hex(i).upper()[2:]) for i in range(0x10, 0x100) + ] + + self._special_token_set = set([self.unk_token, self.bos_token, self.eos_token] + self.byte_list) + + all_tokens = load_vocab(io.FileIO(path, "rb")) + + self.encoder: Dict[str, int] = {} + self._special_encoder: Dict[str, int] = {} + for token, token_id in all_tokens.items(): + if token in self._special_token_set: + self._special_encoder[token] = token_id + else: + self.encoder[token] = token_id + + self.decoder = {v: k for k, v in self.encoder.items()} + self._byte_decoder = {self._special_encoder[token]: i for i, token in enumerate(self.byte_list)} + + self._max_word_len = max([len(x) for x in self.encoder.keys()]) + + self._len_word_first = {} + for x in self.encoder.keys(): + if not x[0] in self._len_word_first: + self._len_word_first[x[0]] = 1 + if len(x) > self._len_word_first[x[0]]: + self._len_word_first[x[0]] = len(x) + self.tencoder = StringTrie(self.encoder) + + def get_piece(self, text: str) -> str: + if text[0] in self._len_word_first: + text = text[: self._len_word_first[text[0]]] + len_text = len(text) + for i in range(len(text)): + sub = text[: len_text - i] + if sub in self.encoder: + return sub + return text[0] + + @property + def vocab_size(self): + return len(self) + + @property + def eos_id(self): + return self._special_encoder[self.eos_token] + + @property + def bos_id(self): + return 
self._special_encoder[self.bos_token] + + @property + def unk_id(self): + return self._special_encoder[self.unk_token] + + def __len__(self): + return len(self.encoder) + len(self._special_encoder) + + def tokenize(self, text: str) -> List[str]: + output_tokens: List[str] = [] + st = 0 + while st < len(text): + piece = self.get_piece(text[st:]) + output_tokens.append(piece) + st += len(piece) + return output_tokens + + @staticmethod + def escape(text: str) -> str: + return text + + @staticmethod + def unescape(text: str) -> str: + return text + + def encode(self, text: str, with_bos = True) -> List[int]: + ret = [] + if with_bos: + ret.append(self.bos_id) + for x in self.tokenize(text): + if x in self.encoder: + ret.append(self.encoder[x]) + else: + ret.extend(self._encode_unicode(x)) + return ret + + def decode(self, tokens: List[int]): + """Decode ids into a string.""" + ret = [] + st = 0 + while st < len(tokens): + if tokens[st] in self.decoder: + ret.append(self.decoder[tokens[st]]) + st += 1 + elif tokens[st] in self._byte_decoder: + first = self._byte_decoder[tokens[st]] + length = 1 if first < 128 else len(re.search('^1+0', bin(first)[2:])[0])-1 + code = 0 + try: + for j in range(length): + code = code << 8 | self._byte_decoder[tokens[st + j]] + code = int.to_bytes(code, length, "big").decode("utf-8") + ret.append(code) + except: + pass + st = st + length + elif tokens[st] == self.eos_id: + ret.append(self.eos_token) + st += 1 + elif tokens[st] == self.bos_id: + ret.append(self.bos_token) + st += 1 + else: + ret.append(self.unk_token) + st += 1 + return "".join(ret) + + def _encode_unicode(self, token): + # wrap unicode encoding into a helper function + ids = [] + utf8_id = token.encode("utf-8") + for _id in utf8_id: + ids.append(self._special_encoder[self.byte_list[_id]]) + return ids + + def next_token(self, text): + # fast next token matching + token, token_id = self.tencoder.longest_prefix_item(text, (None, None)) + if token is None: + token = text[0] + token_ids = self._encode_unicode(token) + else: + token_ids = [token_id] + return token, token_ids + + +def start_worker( + world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, query +): + model = onnx.load(ONNX_MODEL_PATH) + runtime = backend.CudaRuntime(local_rank) + if args.nproc_per_node > 1: + runtime.init_comm( + "9g", + world_size, + rank, + ) + print("[{}] comm init.".format(rank)) + + stub = OnnxStub(model, runtime) + print("[{}] stub init.".format(rank)) + + for i in range(10): + if args.no_cudagraph: + stub.run() + else: + stub.run_with_cudagraph() + print("[{}] stub warmup.".format(rank)) + + tokenizer = CPM9GTokenizer("/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/vocabs.txt") + query = tokenizer.encode(query) + + output_tokens = [] + for i in range(len(query)): + q = np.array(query[i]) + (list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist()) + pos = i * np.ones((args.batchsize, 1), dtype=np.int64) + (list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist()) + + if args.no_cudagraph: + stub.run() + else: + stub.run_with_cudagraph() + + if i == len(query) - 1: + output = np.array((list(stub.outputs.items()))[-1][1].copyout_float16()) if False \ + else np.array((list(stub.outputs.items()))[-1][1].copyout_float()) + q = np.argmax(output) + output_tokens.append(q) + + avg_time = 0 + count = 0 + while i < 1000: + count = count + 1 + torch.cuda.synchronize() + with nvtx.annotate("gen {}-th token".format(i), color="red"): + i = i + 1 + 
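# Decode step: copy the previously generated token id (q) and its position (i) into
+            # the first two graph inputs, run the stub (directly or through the captured CUDA
+            # graph), then take the argmax of the fp32 logits as the next token; a token id
+            # of 2 is treated as end-of-sequence and stops generation.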
(list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist()) + pos = i * np.ones((args.batchsize, 1), dtype=np.int64) + (list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist()) + + t0 = time.time() + if args.no_cudagraph: + stub.run() + else: + stub.run_with_cudagraph() + t1 = time.time() + avg_time += t1 - t0 + + output = np.array((list(stub.outputs.items()))[-1][1].copyout_float16()) if False \ + else np.array((list(stub.outputs.items()))[-1][1].copyout_float()) + + # print(output) + + with nvtx.annotate("argmax".format(i), color="green"): + q = np.argmax(output) + if q == 2: + break + + output_tokens.append(q) + avg_time = avg_time / count + print("avg_time_cost =", avg_time*1000, "ms") + text = tokenizer.decode(output_tokens) + return text + + +if __name__ == "__main__": + comm = MPI.COMM_WORLD + args.rank = comm.Get_rank() + args.nproc_per_node = comm.Get_size() + world_size = args.num_nodes * args.nproc_per_node + + if not os.path.exists(ONNX_MODEL_PATH): + print("exporting onnx graph") + generate_onnx(ONNX_MODEL_PATH) + else: + print("will use exsiting onnx graph") + onnx_model = onnx.load(ONNX_MODEL_PATH) + print("data loaded") + + + #query = '''Beijing is the captial''' + #query = '''什么是PTX?''' + #query = '''生病了怎么办?''' + #query = '''Happy''' + query = '''def gcd(a, b):''' + + #################################### + # infinitensor dist + #################################### + # run distributed parallel. + pred = start_worker(world_size, args.rank, args.rank % + args.nproc_per_node, onnx_model, query) + if args.rank == 0: + print("输入:\n\n", query, "\n") + print("输出:", pred) diff --git a/examples/python/test_llama2_7b.py b/examples/python/test_llama2_7b.py new file mode 100644 index 00000000..98a699eb --- /dev/null +++ b/examples/python/test_llama2_7b.py @@ -0,0 +1,491 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM +from tqdm import tqdm +import argparse +import torch +import onnx +import onnx_graphsurgeon as gs +import os +import numpy as np +from pyinfinitensor.onnx import OnnxStub, backend +import time +import nvtx +from mpi4py import MPI + +parser = argparse.ArgumentParser(description='') +parser.add_argument('--batchsize', dest='batchsize', type=int, default=1) +parser.add_argument('--layer', dest='n_layers', type=int, default=32) +parser.add_argument("--num_nodes", dest='num_nodes', + type=int, default=1, help="number of nodes") +parser.add_argument("--nproc_per_node", dest="nproc_per_node", + type=int, default=1, help="number of processes per node") +parser.add_argument("--world_size", dest="world_size", + type=int, default=1, help="") +parser.add_argument("--n_max_length", dest="n_max_length", + type=int, default=1024, help="") +parser.add_argument("--vocab_size", dest="vocab_size", + type=int, default=32000, help="vocabulary size") +parser.add_argument("--hidden_size", dest="hidden_size", + type=int, default=4096) +parser.add_argument("--head_size", dest="head_size", + type=int, default=32) +parser.add_argument("--head_dim", dest="head_dim", + type=int, default=128) +parser.add_argument('--rank', dest='rank', type=int, default=0) +parser.add_argument('--no_cudagraph', action='store_true') +parser.add_argument('--fp16', action='store_true') +parser.add_argument('--is_1st_graph', action='store_true') +parser.add_argument('--speedup', action='store_true') +args = parser.parse_args() + +comm = MPI.COMM_WORLD +args.rank = comm.Get_rank() +args.nproc_per_node = comm.Get_size() +args.world_size = args.num_nodes * 
args.nproc_per_node + +PRETRAINED_LLAMA_PATH = "/data0/shared/data/public/opensource_models/meta-llama/Llama-2-7b-hf/" +ONNX_MODEL_PATH = "/data3/shared/xnsong/llama2/" + ("1st" if args.is_1st_graph else "2nd") +ONNX_MODEL_ORIGIN_PATH = ONNX_MODEL_PATH + "/origin/llama2_origin_bs{}_layer{}.onnx".format( + args.batchsize, args.n_layers) +ONNX_MODEL_SIM_PATH = ONNX_MODEL_PATH + "/sim/llama2_sim_bs{}_layer{}.onnx".format( + args.batchsize, args.n_layers) +ONNX_MODEL_FUSION_PATH = ONNX_MODEL_PATH + "/fusion/llama2_fusion_bs{}_layer{}.onnx".format( + args.batchsize, args.n_layers) +ONNX_MODEL_SPECIAL_PATH = ONNX_MODEL_PATH + "/special/llama2_special_bs{}_layer{}.onnx".format( + args.batchsize, args.n_layers) +ONNX_MODEL_FP16_PATH = ONNX_MODEL_PATH + "/fp16/llama2_fp16_bs{}_layer{}.onnx".format( + args.batchsize, args.n_layers) +ONNX_MODEL_DIST_PATH = ONNX_MODEL_PATH + "/dist/llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format( + args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank) + +def parallel_model(onnx_model, world_size, rank): + graph = gs.import_onnx(onnx_model) + tmap = graph.tensors() + + for i in range(args.n_layers): + tmap[graph.inputs[2+i*2].name].shape[1] = tmap[graph.inputs[2+i*2].name].shape[1]//world_size + tmap[graph.inputs[3+i*2].name].shape[1] = tmap[graph.inputs[3+i*2].name].shape[1]//world_size + for node in graph.nodes: + if node.name == "/model/layers." + str(i) + "/self_attn/q_proj/MatMul": + node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank] + elif node.name == "/model/layers." + str(i) + "/self_attn/k_proj/MatMul": + node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank] + elif node.name == "/model/layers." + str(i) + "/self_attn/v_proj/MatMul": + node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank] + elif node.name == "/model/layers." + str(i) + "/self_attn/o_proj/MatMul": + node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank] + reduce_sum_output = gs.Variable("reduce_sum_output_" + str(i) + "_0", + dtype=np.float32) + reduce_sum = gs.Node(op="ReduceSum", name="reduce_sum_"+str(i)+"_0", + inputs=node.outputs, outputs=[reduce_sum_output], + attrs={"noop_with_empty_axes":1, "communicator":0}) + graph.nodes.append(reduce_sum) + next_node = node.outputs[0].outputs[0] + next_node.inputs[1] = reduce_sum_output + elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_0" or \ + node.name == "/model/layers." + str(i) + "/self_attn/Reshape_1": + node.inputs[1].values = np.array( + [1, 1, + args.head_size//world_size, + args.hidden_size//args.head_size]) + elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_2": + node.inputs[1] = gs.Constant(name="/model/layers."+str(i)+"/self_attn/vreshape_input", + values=np.array( + [1, 1, + args.head_size//world_size, + args.hidden_size//args.head_size])) + elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_3": + node.inputs[1] = gs.Constant(name="/model/layers." + str(i) + "/self_attn/Reshape_3_shape", + values=np.array( + [1, 1, args.hidden_size//world_size])) + + elif node.name == "/model/layers." + str(i) + "/mlp/up_proj/MatMul": + node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank] + elif node.name == "/model/layers." + str(i) + "/mlp/gate_proj/MatMul": + node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank] + elif node.name == "/model/layers." 
+ str(i) + "/mlp/down_proj/MatMul": + node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank] + reduce_sum_output_1 = gs.Variable("reduce_sum_output_" + str(i) + "_1", + dtype=np.float32) + reduce_sum_1 = gs.Node(op="ReduceSum", inputs=node.outputs, outputs=[reduce_sum_output_1], + attrs={"noop_with_empty_axes":1, "communicator":0}) + graph.nodes.append(reduce_sum_1) + next_node = node.outputs[0].outputs[0] + next_node.inputs[1] = reduce_sum_output_1 + + # new_out_1 = tmap["/model/layers.0/mlp/down_proj/MatMul_output_0"] #reduce_sum_output + # new_out_1.dtype = np.float32 + # new_out_1.shape = [1,1,4096] + # graph.outputs.append(new_out_1) + graph.cleanup(True).toposort() + return gs.export_onnx(graph) + +def simplify(onnx_model): + graph = gs.import_onnx(onnx_model) + for node in graph.nodes: + if node.op == "Cast": + inp_node = node.i() + inp_node.outputs = node.outputs + node.outputs.clear() + + for i in range(args.n_layers): + nodename = "/model/layers." + str(i) + "/self_attn/Add_2" + node = [node for node in graph.nodes if node.name == nodename][0] + inp_node = node.i() + inp_node.outputs = node.outputs + node.outputs.clear() + + graph.cleanup().toposort() + return gs.export_onnx(graph) + +@gs.Graph.register() +def replace_with_RMSNorm(self, inputs, outputs): + inputs[0].outputs.pop(0) + inputs[0].outputs.pop(0) + + for out in outputs: + out.inputs.clear() + return self.layer(op="RMSNorm", inputs=inputs, outputs=outputs, name="rmsnorm") + +@gs.Graph.register() +def replace_with_silu(self, inputs, outputs): + for inp in inputs: + inp.outputs.clear() + for out in outputs: + out.inputs.clear() + return self.layer(op="Silu", inputs=inputs, outputs=outputs, name="silu") + +@gs.Graph.register() +def replace_with_RoPE(self, a, b): + return self.layer(op="RoPE", inputs=a, outputs=b, name="rope") + +@gs.Graph.register() +def replace_with_attention(self, inputs, outputs, inputs_added, outputs_removed): + for inp in inputs: + inp.outputs.clear() + for out in outputs: + out.inputs.clear() + for inp in inputs_added: + inputs.append(inp) + for out in outputs_removed: + out.inputs.clear() + return self.layer(op="AttentionKVCache", inputs=inputs, outputs=outputs, name="attention") + +def fusion(model): + graph = gs.import_onnx(model) + tmap = graph.tensors() + + tmap["onnx::Reshape_1"].outputs.clear() + + inputs = [tmap["/model/layers.0/input_layernorm/Cast_output_0"], tmap["model.layers.0.input_layernorm.weight"]] + rmsnorm_outputs = [tmap["/model/layers.0/input_layernorm/Mul_1_output_0"]] + graph.replace_with_RMSNorm(inputs, rmsnorm_outputs) + + for i in range(args.n_layers): + # rotary embedding op + tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"].inputs.clear() + tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"].inputs.clear() + attn_qreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/qreshape_input", + values=np.array([1,1,args.head_size,args.hidden_size//args.head_size])) + attn_kreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/kreshape_input", + values=np.array([1,1,args.head_size,args.hidden_size//args.head_size])) + attn_qrope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qrope_output") + attn_krope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/krope_output") + attn_qreshape_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qreshape_output") + attn_kreshape_output = gs.Variable(name="/model/layers." 
+ str(i) + "/self_attn/kreshape_output") + + attn_qreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_0", inputs=[attn_qrope_output, attn_qreshape_input], outputs=[attn_qreshape_output]) + attn_kreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_1", inputs=[attn_krope_output, attn_kreshape_input], outputs=[attn_kreshape_output]) + attn_qtrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_qreshape_output], + outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"]]) + attn_ktrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_kreshape_output], + outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"]]) + + graph.nodes.append(attn_qreshape) + graph.nodes.append(attn_kreshape) + graph.nodes.append(attn_qtrans) + graph.nodes.append(attn_ktrans) + inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/q_proj/MatMul_output_0"]] + graph.replace_with_RoPE(inputs, [attn_qrope_output]) + inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/k_proj/MatMul_output_0"]] + graph.replace_with_RoPE(inputs, [attn_krope_output]) + + # rms-norm op + inputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Cast_output_0"], \ + tmap["model.layers." + str(i) + ".post_attention_layernorm.weight"]] + outputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Mul_1_output_0"]] + graph.replace_with_RMSNorm(inputs, outputs) + inputs = [tmap["/model/layers." + str(i+1) + "/input_layernorm/Cast_output_0"] if i != args.n_layers-1 else \ + tmap["/model/norm/Cast_output_0"], \ + tmap["model.layers." + str(i+1) + ".input_layernorm.weight"] if i != args.n_layers-1 else \ + tmap["model.norm.weight"]] + outputs = [tmap["/model/layers."+ str(i+1) + "/input_layernorm/Mul_1_output_0"]] if i != args.n_layers-1 else \ + [tmap["/model/norm/Mul_1_output_0"]] + graph.replace_with_RMSNorm(inputs, outputs) + + # silu op + inputs = [tmap["/model/layers." + str(i) + "/mlp/gate_proj/MatMul_output_0"]] + outputs = [tmap["/model/layers." + str(i) + "/mlp/act_fn/Mul_output_0"]] + graph.replace_with_silu(inputs, outputs) + + inputs = [ + tmap["onnx::Concat_" + str((i+1)*2)], + tmap["onnx::Concat_" + str((i+1)*2+1)], + tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"], + tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"], + tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]] + outputs = [ + tmap["/model/layers." 
+ str(i) + "/self_attn/MatMul_1_output_0"],] + + inputs_added = [graph.inputs[1]] + outputs_removed = [] + graph.replace_with_attention( + inputs, outputs, inputs_added, outputs_removed) + + graph.outputs = [tmap[graph.outputs[0].name]] + graph.cleanup(True).toposort() + + return gs.export_onnx(graph) + +def special_pass(model): + graph = gs.import_onnx(model) + tmap = graph.tensors() + for node in graph.nodes: + if node.op == "Transpose" or node.op == "Reshape": + inp_node = node.i() + inp_node.outputs = node.outputs + node.outputs.clear() + graph.cleanup(True).toposort() + return gs.export_onnx(graph) + +def convert_to_fp16(model): + graph = gs.import_onnx(model) + + for node in graph.nodes: + if node.op == "Gather" and node.name == "/model/embed_tokens/Gather": + node.inputs[0].values = np.float16(node.inputs[0].values) + + if node.op == "RMSNorm": + node.inputs[1].values = np.float16(node.inputs[1].values) + + if node.op == "MatMul": + node.inputs[1].values = np.float16(node.inputs[1].values) + if node.name == "/lm_head/MatMul": + cast_1_out = gs.Variable(node.name+"_cast_out_output_0", dtype=np.float32, shape=node.outputs[0].shape) + cast_1 = gs.Node(op="Cast", inputs=[node.outputs[0]], outputs=[cast_1_out]) + cast_1.attrs["to"] = np.float32 + cast_1.name = node.name+"_cast_out_0" + graph.nodes.append(cast_1) + graph.outputs[0] = cast_1_out + node.outputs[0].dtype = np.float16 + + graph.cleanup(True).toposort() + return gs.export_onnx(graph) + +def export_onnx(model: AutoModelForCausalLM): + if not os.path.exists(ONNX_MODEL_ORIGIN_PATH): + print("exporting origin onnx model...") + with torch.no_grad(): + param = torch.zeros( + (args.batchsize, model.config.max_position_embeddings-1), dtype=torch.long) + logits = model(param, past_key_values=None) + + if not args.is_1st_graph: + param_kvcache = torch.zeros((args.batchsize, 1), dtype=torch.long) + torch.onnx.export(model, (param_kvcache, {"past_key_values": logits.past_key_values, + "position_ids": param_kvcache}), \ + ONNX_MODEL_ORIGIN_PATH, verbose=False, + do_constant_folding=True,) + else: + position_ids = torch.tile(torch.arange(0, model.config.max_position_embeddings-1), (args.batchsize, 1)) + attention_mask = torch.ones((args.batchsize, model.config.max_position_embeddings-1), dtype=torch.bool) + torch.onnx.export(model, (param, {"attention_mask": attention_mask, + "position_ids": position_ids}),\ + ONNX_MODEL_ORIGIN_PATH, verbose=False, + do_constant_folding=True,) + print("export origin onnx finished.") + + if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SIM_PATH): + print("exporting sim onnx model...") + onnx_model = onnx.load(ONNX_MODEL_ORIGIN_PATH) + onnx_model = simplify(onnx_model) + onnx.save(onnx_model, ONNX_MODEL_SIM_PATH, save_as_external_data=True, \ + location="llama2_sim_bs{}_layer{}.pb".format(args.batchsize, args.n_layers)) + print("exporting sim onnx model finished.") + + if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_FUSION_PATH): + print("exporting fusion onnx model...") + onnx_model = onnx.load(ONNX_MODEL_SIM_PATH) + onnx_model = fusion(onnx_model) + onnx.save(onnx_model, ONNX_MODEL_FUSION_PATH, save_as_external_data=True, \ + location="llama2_fusion_bs{}_layer{}.pb".format(args.batchsize, args.n_layers)) + print("exporting fusion onnx model finished.") + + if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SPECIAL_PATH): + print("exporting special onnx model...") + onnx_model = onnx.load(ONNX_MODEL_FUSION_PATH) + onnx_model = special_pass(onnx_model) + onnx.save(onnx_model, 
ONNX_MODEL_SPECIAL_PATH, save_as_external_data=True, \ + location="llama2_special_bs{}_layer{}.pb".format(args.batchsize, args.n_layers)) + print("exporting special onnx model finished.") + + if not args.is_1st_graph and args.fp16 and not os.path.exists(ONNX_MODEL_FP16_PATH): + print("exporting fp16 onnx model...") + onnx_model = onnx.load(ONNX_MODEL_SPECIAL_PATH) + onnx_model = convert_to_fp16(onnx_model) + onnx.save(onnx_model, ONNX_MODEL_FP16_PATH, save_as_external_data=True, \ + location="llama2_fp16_bs{}_layer{}.pb".format(args.batchsize, args.n_layers)) + print("exporting fp16 onnx model finished.") + + print("world_size =", args.world_size) + if not args.is_1st_graph and args.world_size > 1 and not os.path.exists(ONNX_MODEL_DIST_PATH): + print("exporting dist onnx model...") + onnx_model = onnx.load(ONNX_MODEL_FP16_PATH) if args.fp16 else onnx.load(ONNX_MODEL_SPECIAL_PATH) + onnx_model = parallel_model(onnx_model, args.world_size, args.rank) + onnx.save(onnx_model, ONNX_MODEL_DIST_PATH, save_as_external_data=True, \ + location="llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format( + args.batchsize, args.n_layers, + 16 if args.fp16 else 32, args.world_size, args.rank)) + print("exporting dist onnx model finished.") + +def get_it_logit(onnx_model, input_ids): + # initialization + runtime = backend.CudaRuntime(args.rank) + runtime.init_comm( + "dist", + args.world_size, + args.rank, + ) + print("[{}] comm init.".format(args.rank)) + stub = OnnxStub(onnx_model, runtime) + print("[{}] stub init.".format(args.rank)) + + # warm up + for i in range(10): + if args.no_cudagraph: + stub.run() + else: + stub.run_with_cudagraph() + print("[{}] stub warmup.".format(args.rank)) + + logits = np.zeros((args.batchsize, args.n_max_length, args.vocab_size), dtype=np.float32) + output_ids = np.zeros((args.batchsize, args.n_max_length), dtype=np.int64) + avg_inference_time = 0 + t0 = time.time() + for i in tqdm(range(0, args.n_max_length)): + with nvtx.annotate("seq_length = {}".format(i), color="red"): + assert input_ids.shape[0] == args.batchsize + input_id = input_ids[:, i] if i < input_ids.shape[1] else output_ids[:, i-1] + position_id = i*np.ones((args.batchsize, 1), dtype=np.int32) + + # copyin input + with nvtx.annotate("[it] copyin", color="blue"): + (list(stub.inputs.items()))[0][1].copyin_int64( + input_id.reshape(-1).tolist()) + (list(stub.inputs.items()))[1][1].copyin_int64( + position_id.reshape(-1).tolist()) + + # run + t10 = time.time() + with nvtx.annotate("[it] run", color="green"): + if args.no_cudagraph: + stub.run() + else: + stub.run_with_cudagraph() + t11 = time.time() + avg_inference_time += (t11 - t10) + + # copyout output + if not args.speedup: + with nvtx.annotate("[it] copyout", color="blue"): + logits[:,i, :] = np.array((list(stub.outputs.items()))[0][1].copyout_float()).reshape(args.batchsize, -1) + output_ids[:, i] = np.argmax(logits[:, i, :], -1).astype(np.int64) + + + t1 = time.time() + if args.rank == 0: + result = "[it] e2e: {} gpus, {} layers, e2e time: {:.2f}s, average inference time: {:.2f}ms"\ + .format(args.num_nodes * args.nproc_per_node, args.n_layers, t1-t0, \ + avg_inference_time*1000/args.n_max_length) + print(result) + del stub + return output_ids + +if __name__ == "__main__": + torch_model = LlamaForCausalLM.from_pretrained( + PRETRAINED_LLAMA_PATH, num_hidden_layers=int(args.n_layers)).eval() + tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LLAMA_PATH) + #prompt = "Hey, are you conscious? Can you talk to me?" + #prompt = "What is PTX?" 
+ #prompt = "Tell me a joke." + #prompt = "What are the key principles of smart investing?" + prompt = "What is DeepSpeed?" + prompts=[prompt]*args.batchsize + inputs = tokenizer(prompts, return_tensors="pt") + + input_ids = inputs.input_ids + print("prompt ids =", input_ids) + + ########################################################## + # inference with InfiniTensor + ########################################################## + print("exporting onnx...") + export_onnx(torch_model) + print("exporting onnx finished.") + + onnx_to_run_path = ONNX_MODEL_DIST_PATH if args.world_size > 1 else \ + (ONNX_MODEL_FP16_PATH if args.fp16 else ONNX_MODEL_SPECIAL_PATH) + print("loading onnx", onnx_to_run_path, "...") + onnx_model = onnx.load(onnx_to_run_path) + print("loading onnx finished.") + output_ids_it = get_it_logit(onnx_model, input_ids) + it_output_text = tokenizer.batch_decode(output_ids_it[:, input_ids.shape[-1]:output_ids_it.shape[-1]]) + if args.rank == 0: + for i in range(args.batchsize): + print("prompt: ", prompts[i]) + print("answer: [it]", it_output_text[i]) + + ########################################################## + # validation with pytorch + ########################################################## + """ + generate_ids = torch_model.generate(inputs.input_ids, max_length=args.n_max_length)#, num_beams=4, do_sample=True) + outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + """ + if not args.speedup and not args.is_1st_graph: + kvcache_torch = None + output_ids_pt = torch.zeros(args.batchsize, args.n_max_length).int() # + input_ids.shape[-1] - 1).int() + if args.fp16: + torch_model = torch_model.half() + + torch_model = torch_model.cuda() + # print(torch.cuda.memory_summary()) + + avg_inference_time = 0 + with torch.no_grad(): + t0 = time.time() + for i in range(args.n_max_length): + input_id = input_ids[:,i] if i < input_ids.shape[1] else out_token + input_id = input_id.view(args.batchsize,1).cuda() + t00 = time.time() + outputs = torch_model(input_id, past_key_values=kvcache_torch) + t01 = time.time() + avg_inference_time += (t01-t00) + + logits = outputs['logits'] + kvcache_torch = outputs['past_key_values'] + out_token = torch.argmax(logits, dim=-1) + output_ids_pt[:, i:i+1] = out_token + t1 = time.time() + avg_inference_time /= args.n_max_length + result = "[pt] e2e time: {:.2f}s, average inference time: {:.2f}ms"\ + .format(t1-t0, avg_inference_time*1000) + + if args.rank == 0: + print(result) + pt_output_text = tokenizer.batch_decode(output_ids_pt[:,input_ids.shape[-1]:args.n_max_length]) + for i in range(args.batchsize): + print("[pt]", args.rank, pt_output_text[i]) + + if not args.is_1st_graph: + assert(output_ids_it.shape[-1] == args.n_max_length) + np.testing.assert_equal(output_ids_pt[:, input_ids.shape[-1]:args.n_max_length], output_ids_it[:,input_ids.shape[-1]:args.n_max_length])
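Usage note: both scripts drive InfiniTensor with the same decode loop: load or export an ONNX graph, wrap it in an OnnxStub on a backend.CudaRuntime, copy the current token id and position id in with copyin_int64, run the stub (optionally through CUDA graphs after a warmup), read the logits back with copyout_float, and feed the argmax token into the next step. The sketch below is a minimal single-GPU illustration of that pattern, assuming a decode-style graph like the ones exported above whose first two inputs are the token id and the position id; model_path, prompt_ids and the length bound are illustrative placeholders rather than values taken from this patch, while eos_id mirrors the hard-coded stop token 2 used in test_9G.py.

import numpy as np
import onnx
from pyinfinitensor.onnx import OnnxStub, backend

model_path = "decode_model.onnx"   # placeholder path to an exported decode graph
eos_id = 2                         # stop token, as in test_9G.py

runtime = backend.CudaRuntime(0)                  # single GPU, rank 0, no init_comm needed
stub = OnnxStub(onnx.load(model_path), runtime)

prompt_ids = [1, 123, 456]                        # placeholder token ids from some tokenizer
generated = []
token = prompt_ids[0]
for pos in range(64):                             # upper bound on sequence length
    # the first two graph inputs are the current token id and its position
    graph_inputs = list(stub.inputs.items())
    graph_inputs[0][1].copyin_int64(np.array(token, dtype=np.int64).reshape(-1).tolist())
    graph_inputs[1][1].copyin_int64(np.array([[pos]], dtype=np.int64).reshape(-1).tolist())

    stub.run()                                    # or stub.run_with_cudagraph() after a warmup run

    logits = np.array(list(stub.outputs.items())[-1][1].copyout_float())
    # feed the prompt token by token first, then feed back the argmax of the logits
    if pos + 1 < len(prompt_ids):
        token = prompt_ids[pos + 1]
    else:
        token = int(np.argmax(logits))
        if token == eos_id:
            break
        generated.append(token)

print(generated)

The distributed paths in both scripts additionally call runtime.init_comm(...) before building the stub and take rank and nproc_per_node from MPI.COMM_WORLD, so they are meant to be launched once per GPU, for example with something along the lines of mpirun -n 2 python examples/python/test_llama2_7b.py --world_size 2 --fp16 (an illustrative command assuming one GPU per MPI rank).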