forked from jiuyuan/InfiniTensor
add test scripts for llama2 and 9G models
This commit is contained in:
parent
159642d6ae
commit
4a5b9572bb
|
@ -0,0 +1,512 @@
|
|||
import os
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import onnx_graphsurgeon as gs
|
||||
import time
|
||||
import nvtx
|
||||
import argparse
|
||||
from mpi4py import MPI
|
||||
from pytrie import StringTrie
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from typing import (
|
||||
Dict,
|
||||
List,
|
||||
IO,
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
|
||||
parser.add_argument('--layer', dest='n_layers', type=int, default=48)
|
||||
parser.add_argument("--num_nodes", dest='num_nodes',
|
||||
type=int, default=1, help="number of nodes")
|
||||
parser.add_argument("--world_size", dest="world_size",
|
||||
type=int, default=1, help="")
|
||||
parser.add_argument("--nproc_per_node", dest="nproc_per_node",
|
||||
type=int, default=1, help="number of processes per node")
|
||||
parser.add_argument("--n_max_length", dest="n_max_length",
|
||||
type=int, default=1024, help="number of processes per node")
|
||||
parser.add_argument("--vocab_size", dest="vocab_size",
|
||||
type=int, default=119696, help="vocabulary size")
|
||||
parser.add_argument("--hidden_size", dest="hidden_size",
|
||||
type=int, default=4096, help="vocabulary size")
|
||||
parser.add_argument('--rank', dest='rank', type=int, default=0)
|
||||
parser.add_argument('--speedup', action='store_true')
|
||||
parser.add_argument('--no_cudagraph', action='store_true')
|
||||
parser.add_argument('--fp16', action='store_true')
|
||||
args = parser.parse_args()
|
||||
comm = MPI.COMM_WORLD
|
||||
args.rank = comm.Get_rank()
|
||||
args.nproc_per_node = comm.Get_size()
|
||||
args.world_size = args.num_nodes * args.nproc_per_node
|
||||
|
||||
ONNX_MODEL_PATH = "/data3/shared/xnsong/9G/dist/9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format(
|
||||
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
|
||||
|
||||
weight_path = "9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format(
|
||||
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
|
||||
|
||||
model_dir = "/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/cpm9g-11b-sft.pt"
|
||||
|
||||
@gs.Graph.register()
|
||||
def RMSNorm(self, a, b):
|
||||
return self.layer(op="RMSNorm", inputs=a, outputs=b)
|
||||
|
||||
@gs.Graph.register()
|
||||
def RoPE(self, a, b):
|
||||
return self.layer(op="RoPE", inputs=a, outputs=b)
|
||||
|
||||
@gs.Graph.register()
|
||||
def AttentionKVCache(self, a, b):
|
||||
return self.layer(op="AttentionKVCache", inputs=a, outputs=b)
|
||||
|
||||
def to_numpy(dict):
|
||||
ret = dict
|
||||
if args.fp16:
|
||||
ret = np.float16(ret)
|
||||
else:
|
||||
ret = np.float32(ret)
|
||||
return ret
|
||||
|
||||
def parallel(array, split='replicate'):
|
||||
if args.world_size > 1 and split == 'partial_column':
|
||||
return np.hsplit(array, args.world_size)[args.rank]
|
||||
elif args.world_size > 1 and split == 'partial_row':
|
||||
return np.vsplit(array, args.world_size)[args.rank]
|
||||
return array
|
||||
|
||||
|
||||
def generate_onnx(ONNX_MODEL_PATH):
|
||||
state_dict = torch.load(f'{model_dir}', map_location='cpu')
|
||||
new_state_dict = {name: param.cpu().numpy()
|
||||
for name, param in state_dict.items()
|
||||
}
|
||||
|
||||
operators = []
|
||||
graph = gs.Graph(nodes=operators)
|
||||
gather_input = gs.Variable(name="gather_input.0", dtype=np.int64, shape=(1,1))
|
||||
pos_input = gs.Variable(name="pos_input.0", dtype=np.int64, shape=(1,1))
|
||||
|
||||
embedding_weight = gs.Constant(name="embedding.weight", values=to_numpy(new_state_dict["input_embedding.weight"]))
|
||||
gather_output = gs.Variable(name="gather_output.0")
|
||||
gather = gs.Node(op="Gather", inputs=[embedding_weight, gather_input], outputs=[gather_output])
|
||||
operators.append(gather)
|
||||
input = gather_output
|
||||
|
||||
graph.inputs=[gather_input, pos_input]
|
||||
graph.outputs=[]
|
||||
|
||||
for i in tqdm(range(args.n_layers)):
|
||||
# global input
|
||||
attn_kcache_input = gs.Variable(name="/layers." + str(i) + "/attn/kcache_input", dtype=np.float32, shape=(1,32,1023,128))
|
||||
attn_vcache_input = gs.Variable(name="/layers." + str(i) + "/attn/vcache_input", dtype=np.float32, shape=(1,32,1023,128))
|
||||
graph.inputs.append(attn_kcache_input)
|
||||
graph.inputs.append(attn_vcache_input)
|
||||
|
||||
# weight
|
||||
layernorm_0_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.0/mul_weight",
|
||||
values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".self_att.layernorm_before_attention.weight"]))
|
||||
attn_qproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/qproj_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_q.weight"]))
|
||||
, 'partial_column'))
|
||||
attn_kproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/kproj_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_k.weight"]))
|
||||
, 'partial_column'))
|
||||
attn_vproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/vproj_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_v.weight"]))
|
||||
, 'partial_column'))
|
||||
attn_outmatmul_input = gs.Constant(name="/layers." + str(i) + "/attn/outmatmul_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.attention_out.weight"]))
|
||||
, 'partial_row'))
|
||||
|
||||
layernorm_1_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.1/mul_weight",
|
||||
values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".ffn.layernorm_before_ffn.weight"]))
|
||||
ffn_matmul_0_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_0_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_0.weight"]))
|
||||
, 'partial_column'))
|
||||
ffn_matmul_1_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_1_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_1.weight"]))
|
||||
, 'partial_column'))
|
||||
ffn_matmul_out_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_out_weight",
|
||||
values=parallel(
|
||||
np.transpose(
|
||||
to_numpy(
|
||||
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_out.weight"]))
|
||||
, 'partial_row'))
|
||||
|
||||
attn_qrope_output = gs.Variable(name="/layers." + str(i) + "/attn/qrope_output")
|
||||
attn_krope_output = gs.Variable(name="/layers." + str(i) + "/attn/krope_output")
|
||||
attn_kvcache_output = gs.Variable(name="/layers." + str(i) + "/attn/kvcache_output")
|
||||
layernorm_0_mul_output_1 = gs.Variable(name="/layers." + str(i) + "/layernorm.0/mul_output_1")
|
||||
layernorm_1_mul_output_1 = gs.Variable(name="/layers." + str(i) + "/layernorm.1/mul_output_1")
|
||||
attn_qproj_output = gs.Variable(name="/layers." + str(i) + "/attn/qproj_output")
|
||||
attn_kproj_output = gs.Variable(name="/layers." + str(i) + "/attn/kproj_output")
|
||||
attn_vproj_output = gs.Variable(name="/layers." + str(i) + "/attn/vproj_output")
|
||||
attn_outmatmul_output = gs.Variable(name="/layers." + str(i) + "/attn/outmatmul_output")
|
||||
attn_outadd_output = gs.Variable(name="/layers." + str(i) + "/attn/outadd_output")
|
||||
ffn_matmul_0_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_0_output")
|
||||
ffn_silu_output = gs.Variable(name="/layers." + str(i) + "/ffn/silu_output")
|
||||
ffn_matmul_1_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_1_output")
|
||||
ffn_mul_output = gs.Variable(name="/layers." + str(i) + "/ffn/mul_output")
|
||||
ffn_matmul_out_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_out_output")
|
||||
ffn_add_output = gs.Variable(name="/layers." + str(i) + "/ffn/add_output")
|
||||
|
||||
graph.RMSNorm([input, layernorm_0_mul_weight], [layernorm_0_mul_output_1])
|
||||
attn_qproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_qproj_weight], outputs=[attn_qproj_output])
|
||||
operators.append(attn_qproj)
|
||||
attn_kproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_kproj_weight], outputs=[attn_kproj_output])
|
||||
operators.append(attn_kproj)
|
||||
attn_vproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_vproj_weight], outputs=[attn_vproj_output])
|
||||
operators.append(attn_vproj)
|
||||
graph.RoPE([pos_input, attn_qproj_output], [attn_qrope_output])
|
||||
graph.RoPE([pos_input, attn_kproj_output], [attn_krope_output])
|
||||
graph.AttentionKVCache([attn_kcache_input, attn_vcache_input, attn_qrope_output, attn_krope_output, attn_vproj_output, pos_input],[attn_kvcache_output])
|
||||
attn_outproj = gs.Node(op="MatMul", inputs=[attn_kvcache_output, attn_outmatmul_input], outputs=[attn_outmatmul_output])
|
||||
operators.append(attn_outproj)
|
||||
|
||||
attn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/attn/reducesum_output")
|
||||
if args.world_size > 1:
|
||||
reduce_sum = gs.Node(op="ReduceSum", name="/layers." + str(i) + "/attn/reducesum",
|
||||
inputs=[attn_outmatmul_output], outputs=[attn_reduce_sum_output],
|
||||
attrs={"noop_with_empty_axes":1, "communicator":0})
|
||||
graph.nodes.append(reduce_sum)
|
||||
|
||||
attn_outadd = gs.Node(op="Add", inputs=[input, attn_outmatmul_output if args.world_size == 1 else attn_reduce_sum_output], outputs=[attn_outadd_output])
|
||||
operators.append(attn_outadd)
|
||||
|
||||
graph.RMSNorm([attn_outadd_output, layernorm_1_mul_weight], [layernorm_1_mul_output_1])
|
||||
|
||||
ffn_matmul_0 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_0_input], outputs=[ffn_matmul_0_output])
|
||||
operators.append(ffn_matmul_0)
|
||||
ffn_silu = gs.Node(op="Silu", inputs=[ffn_matmul_0_output], outputs=[ffn_silu_output])
|
||||
operators.append(ffn_silu)
|
||||
ffn_matmul_1 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_1_input], outputs=[ffn_matmul_1_output])
|
||||
operators.append(ffn_matmul_1)
|
||||
ffn_mul = gs.Node(op="Mul", inputs=[ffn_silu_output, ffn_matmul_1_output], outputs=[ffn_mul_output])
|
||||
operators.append(ffn_mul)
|
||||
ffn_matmul_out = gs.Node(op="MatMul", inputs=[ffn_mul_output, ffn_matmul_out_input], outputs=[ffn_matmul_out_output])
|
||||
operators.append(ffn_matmul_out)
|
||||
|
||||
ffn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/ffn/reducesum_output")
|
||||
if args.world_size > 1:
|
||||
reduce_sum = gs.Node(op="ReduceSum", name="/layers." + str(i) + "/ffn/reducesum",
|
||||
inputs=[ffn_matmul_out_output], outputs=[ffn_reduce_sum_output],
|
||||
attrs={"noop_with_empty_axes":1, "communicator":0})
|
||||
graph.nodes.append(reduce_sum)
|
||||
|
||||
ffn_add = gs.Node(op="Add", inputs=[attn_outadd_output, ffn_matmul_out_output if args.world_size == 1 else ffn_reduce_sum_output], outputs=[ffn_add_output])
|
||||
operators.append(ffn_add)
|
||||
input = ffn_add_output
|
||||
|
||||
layernorm_mul_weight = gs.Constant(name="/output/layernorm/mul_weight", values=to_numpy(new_state_dict["encoder.output_layernorm.weight"]))
|
||||
layernorm_mul_output_1 = gs.Variable(name="/output/layernorm/mul_output_1")
|
||||
|
||||
graph.RMSNorm([input, layernorm_mul_weight], [layernorm_mul_output_1])
|
||||
|
||||
lm_head_weight = gs.Constant(name="/output/lm_head/weight", values=np.transpose(to_numpy(new_state_dict["lm_head.weight"])))
|
||||
lm_head_output = gs.Variable(name="/output/lm_head/output")
|
||||
lm_head = gs.Node(op="MatMul", inputs=[layernorm_mul_output_1, lm_head_weight], outputs=[lm_head_output])
|
||||
operators.append(lm_head)
|
||||
|
||||
if args.fp16:
|
||||
final_cast_output = gs.Variable(name="/output/cast/output", dtype=np.float32, shape=(1,1,args.vocab_size))
|
||||
final_cast = gs.Node(op="Cast", inputs=[lm_head_output], outputs=[final_cast_output])
|
||||
final_cast.attrs["to"] = np.float32
|
||||
operators.append(final_cast)
|
||||
graph.outputs.append(final_cast_output)
|
||||
else:
|
||||
lm_head_output.dtype=np.float32
|
||||
lm_head_output.shape=(1,1,args.vocab_size)
|
||||
graph.outputs.append(lm_head_output)
|
||||
|
||||
onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True, location=weight_path)
|
||||
return
|
||||
|
||||
|
||||
def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab: Dict[str, int] = {}
|
||||
|
||||
reader = io.TextIOWrapper(fp, encoding="utf-8")
|
||||
for token in reader.readlines():
|
||||
token = token.strip()
|
||||
if len(token) == 0:
|
||||
continue
|
||||
token = json.loads(token)
|
||||
vocab[token] = len(vocab)
|
||||
return vocab
|
||||
|
||||
|
||||
class CPM9GTokenizer(object):
|
||||
def __init__(self, path):
|
||||
self.unk_token = "<unk>"
|
||||
self.bos_token = "<s>"
|
||||
self.eos_token = "</s>"
|
||||
self.byte_list = ["<0x0{}>".format(hex(i).upper()[2:]) for i in range(0x10)] + [
|
||||
"<0x{}>".format(hex(i).upper()[2:]) for i in range(0x10, 0x100)
|
||||
]
|
||||
|
||||
self._special_token_set = set([self.unk_token, self.bos_token, self.eos_token] + self.byte_list)
|
||||
|
||||
all_tokens = load_vocab(io.FileIO(path, "rb"))
|
||||
|
||||
self.encoder: Dict[str, int] = {}
|
||||
self._special_encoder: Dict[str, int] = {}
|
||||
for token, token_id in all_tokens.items():
|
||||
if token in self._special_token_set:
|
||||
self._special_encoder[token] = token_id
|
||||
else:
|
||||
self.encoder[token] = token_id
|
||||
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self._byte_decoder = {self._special_encoder[token]: i for i, token in enumerate(self.byte_list)}
|
||||
|
||||
self._max_word_len = max([len(x) for x in self.encoder.keys()])
|
||||
|
||||
self._len_word_first = {}
|
||||
for x in self.encoder.keys():
|
||||
if not x[0] in self._len_word_first:
|
||||
self._len_word_first[x[0]] = 1
|
||||
if len(x) > self._len_word_first[x[0]]:
|
||||
self._len_word_first[x[0]] = len(x)
|
||||
self.tencoder = StringTrie(self.encoder)
|
||||
|
||||
def get_piece(self, text: str) -> str:
|
||||
if text[0] in self._len_word_first:
|
||||
text = text[: self._len_word_first[text[0]]]
|
||||
len_text = len(text)
|
||||
for i in range(len(text)):
|
||||
sub = text[: len_text - i]
|
||||
if sub in self.encoder:
|
||||
return sub
|
||||
return text[0]
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def eos_id(self):
|
||||
return self._special_encoder[self.eos_token]
|
||||
|
||||
@property
|
||||
def bos_id(self):
|
||||
return self._special_encoder[self.bos_token]
|
||||
|
||||
@property
|
||||
def unk_id(self):
|
||||
return self._special_encoder[self.unk_token]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self._special_encoder)
|
||||
|
||||
def tokenize(self, text: str) -> List[str]:
|
||||
output_tokens: List[str] = []
|
||||
st = 0
|
||||
while st < len(text):
|
||||
piece = self.get_piece(text[st:])
|
||||
output_tokens.append(piece)
|
||||
st += len(piece)
|
||||
return output_tokens
|
||||
|
||||
@staticmethod
|
||||
def escape(text: str) -> str:
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def unescape(text: str) -> str:
|
||||
return text
|
||||
|
||||
def encode(self, text: str, with_bos = True) -> List[int]:
|
||||
ret = []
|
||||
if with_bos:
|
||||
ret.append(self.bos_id)
|
||||
for x in self.tokenize(text):
|
||||
if x in self.encoder:
|
||||
ret.append(self.encoder[x])
|
||||
else:
|
||||
ret.extend(self._encode_unicode(x))
|
||||
return ret
|
||||
|
||||
def decode(self, tokens: List[int]):
|
||||
"""Decode ids into a string."""
|
||||
ret = []
|
||||
st = 0
|
||||
while st < len(tokens):
|
||||
if tokens[st] in self.decoder:
|
||||
ret.append(self.decoder[tokens[st]])
|
||||
st += 1
|
||||
elif tokens[st] in self._byte_decoder:
|
||||
first = self._byte_decoder[tokens[st]]
|
||||
length = 1 if first < 128 else len(re.search('^1+0', bin(first)[2:])[0])-1
|
||||
code = 0
|
||||
try:
|
||||
for j in range(length):
|
||||
code = code << 8 | self._byte_decoder[tokens[st + j]]
|
||||
code = int.to_bytes(code, length, "big").decode("utf-8")
|
||||
ret.append(code)
|
||||
except:
|
||||
pass
|
||||
st = st + length
|
||||
elif tokens[st] == self.eos_id:
|
||||
ret.append(self.eos_token)
|
||||
st += 1
|
||||
elif tokens[st] == self.bos_id:
|
||||
ret.append(self.bos_token)
|
||||
st += 1
|
||||
else:
|
||||
ret.append(self.unk_token)
|
||||
st += 1
|
||||
return "".join(ret)
|
||||
|
||||
def _encode_unicode(self, token):
|
||||
# wrap unicode encoding into a helper function
|
||||
ids = []
|
||||
utf8_id = token.encode("utf-8")
|
||||
for _id in utf8_id:
|
||||
ids.append(self._special_encoder[self.byte_list[_id]])
|
||||
return ids
|
||||
|
||||
def next_token(self, text):
|
||||
# fast next token matching
|
||||
token, token_id = self.tencoder.longest_prefix_item(text, (None, None))
|
||||
if token is None:
|
||||
token = text[0]
|
||||
token_ids = self._encode_unicode(token)
|
||||
else:
|
||||
token_ids = [token_id]
|
||||
return token, token_ids
|
||||
|
||||
|
||||
def start_worker(
|
||||
world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, query
|
||||
):
|
||||
model = onnx.load(ONNX_MODEL_PATH)
|
||||
runtime = backend.CudaRuntime(local_rank)
|
||||
if args.nproc_per_node > 1:
|
||||
runtime.init_comm(
|
||||
"9g",
|
||||
world_size,
|
||||
rank,
|
||||
)
|
||||
print("[{}] comm init.".format(rank))
|
||||
|
||||
stub = OnnxStub(model, runtime)
|
||||
print("[{}] stub init.".format(rank))
|
||||
|
||||
for i in range(10):
|
||||
if args.no_cudagraph:
|
||||
stub.run()
|
||||
else:
|
||||
stub.run_with_cudagraph()
|
||||
print("[{}] stub warmup.".format(rank))
|
||||
|
||||
tokenizer = CPM9GTokenizer("/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/vocabs.txt")
|
||||
query = tokenizer.encode(query)
|
||||
|
||||
output_tokens = []
|
||||
for i in range(len(query)):
|
||||
q = np.array(query[i])
|
||||
(list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist())
|
||||
pos = i * np.ones((args.batchsize, 1), dtype=np.int64)
|
||||
(list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist())
|
||||
|
||||
if args.no_cudagraph:
|
||||
stub.run()
|
||||
else:
|
||||
stub.run_with_cudagraph()
|
||||
|
||||
if i == len(query) - 1:
|
||||
output = np.array((list(stub.outputs.items()))[-1][1].copyout_float16()) if False \
|
||||
else np.array((list(stub.outputs.items()))[-1][1].copyout_float())
|
||||
q = np.argmax(output)
|
||||
output_tokens.append(q)
|
||||
|
||||
avg_time = 0
|
||||
count = 0
|
||||
while i < 1000:
|
||||
count = count + 1
|
||||
torch.cuda.synchronize()
|
||||
with nvtx.annotate("gen {}-th token".format(i), color="red"):
|
||||
i = i + 1
|
||||
(list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist())
|
||||
pos = i * np.ones((args.batchsize, 1), dtype=np.int64)
|
||||
(list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist())
|
||||
|
||||
t0 = time.time()
|
||||
if args.no_cudagraph:
|
||||
stub.run()
|
||||
else:
|
||||
stub.run_with_cudagraph()
|
||||
t1 = time.time()
|
||||
avg_time += t1 - t0
|
||||
|
||||
output = np.array((list(stub.outputs.items()))[-1][1].copyout_float16()) if False \
|
||||
else np.array((list(stub.outputs.items()))[-1][1].copyout_float())
|
||||
|
||||
# print(output)
|
||||
|
||||
with nvtx.annotate("argmax".format(i), color="green"):
|
||||
q = np.argmax(output)
|
||||
if q == 2:
|
||||
break
|
||||
|
||||
output_tokens.append(q)
|
||||
avg_time = avg_time / count
|
||||
print("avg_time_cost =", avg_time*1000, "ms")
|
||||
text = tokenizer.decode(output_tokens)
|
||||
return text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
comm = MPI.COMM_WORLD
|
||||
args.rank = comm.Get_rank()
|
||||
args.nproc_per_node = comm.Get_size()
|
||||
world_size = args.num_nodes * args.nproc_per_node
|
||||
|
||||
if not os.path.exists(ONNX_MODEL_PATH):
|
||||
print("exporting onnx graph")
|
||||
generate_onnx(ONNX_MODEL_PATH)
|
||||
else:
|
||||
print("will use exsiting onnx graph")
|
||||
onnx_model = onnx.load(ONNX_MODEL_PATH)
|
||||
print("data loaded")
|
||||
|
||||
|
||||
#query = '''Beijing is the captial'''
|
||||
#query = '''什么是PTX?'''
|
||||
#query = '''生病了怎么办?'''
|
||||
#query = '''Happy'''
|
||||
query = '''def gcd(a, b):'''
|
||||
|
||||
####################################
|
||||
# infinitensor dist
|
||||
####################################
|
||||
# run distributed parallel.
|
||||
pred = start_worker(world_size, args.rank, args.rank %
|
||||
args.nproc_per_node, onnx_model, query)
|
||||
if args.rank == 0:
|
||||
print("输入:\n\n", query, "\n")
|
||||
print("输出:", pred)
|
|
@ -0,0 +1,491 @@
|
|||
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
|
||||
from tqdm import tqdm
|
||||
import argparse
|
||||
import torch
|
||||
import onnx
|
||||
import onnx_graphsurgeon as gs
|
||||
import os
|
||||
import numpy as np
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import time
|
||||
import nvtx
|
||||
from mpi4py import MPI
|
||||
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
|
||||
parser.add_argument('--layer', dest='n_layers', type=int, default=32)
|
||||
parser.add_argument("--num_nodes", dest='num_nodes',
|
||||
type=int, default=1, help="number of nodes")
|
||||
parser.add_argument("--nproc_per_node", dest="nproc_per_node",
|
||||
type=int, default=1, help="number of processes per node")
|
||||
parser.add_argument("--world_size", dest="world_size",
|
||||
type=int, default=1, help="")
|
||||
parser.add_argument("--n_max_length", dest="n_max_length",
|
||||
type=int, default=1024, help="")
|
||||
parser.add_argument("--vocab_size", dest="vocab_size",
|
||||
type=int, default=32000, help="vocabulary size")
|
||||
parser.add_argument("--hidden_size", dest="hidden_size",
|
||||
type=int, default=4096)
|
||||
parser.add_argument("--head_size", dest="head_size",
|
||||
type=int, default=32)
|
||||
parser.add_argument("--head_dim", dest="head_dim",
|
||||
type=int, default=128)
|
||||
parser.add_argument('--rank', dest='rank', type=int, default=0)
|
||||
parser.add_argument('--no_cudagraph', action='store_true')
|
||||
parser.add_argument('--fp16', action='store_true')
|
||||
parser.add_argument('--is_1st_graph', action='store_true')
|
||||
parser.add_argument('--speedup', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
args.rank = comm.Get_rank()
|
||||
args.nproc_per_node = comm.Get_size()
|
||||
args.world_size = args.num_nodes * args.nproc_per_node
|
||||
|
||||
PRETRAINED_LLAMA_PATH = "/data0/shared/data/public/opensource_models/meta-llama/Llama-2-7b-hf/"
|
||||
ONNX_MODEL_PATH = "/data3/shared/xnsong/llama2/" + ("1st" if args.is_1st_graph else "2nd")
|
||||
ONNX_MODEL_ORIGIN_PATH = ONNX_MODEL_PATH + "/origin/llama2_origin_bs{}_layer{}.onnx".format(
|
||||
args.batchsize, args.n_layers)
|
||||
ONNX_MODEL_SIM_PATH = ONNX_MODEL_PATH + "/sim/llama2_sim_bs{}_layer{}.onnx".format(
|
||||
args.batchsize, args.n_layers)
|
||||
ONNX_MODEL_FUSION_PATH = ONNX_MODEL_PATH + "/fusion/llama2_fusion_bs{}_layer{}.onnx".format(
|
||||
args.batchsize, args.n_layers)
|
||||
ONNX_MODEL_SPECIAL_PATH = ONNX_MODEL_PATH + "/special/llama2_special_bs{}_layer{}.onnx".format(
|
||||
args.batchsize, args.n_layers)
|
||||
ONNX_MODEL_FP16_PATH = ONNX_MODEL_PATH + "/fp16/llama2_fp16_bs{}_layer{}.onnx".format(
|
||||
args.batchsize, args.n_layers)
|
||||
ONNX_MODEL_DIST_PATH = ONNX_MODEL_PATH + "/dist/llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format(
|
||||
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
|
||||
|
||||
def parallel_model(onnx_model, world_size, rank):
|
||||
graph = gs.import_onnx(onnx_model)
|
||||
tmap = graph.tensors()
|
||||
|
||||
for i in range(args.n_layers):
|
||||
tmap[graph.inputs[2+i*2].name].shape[1] = tmap[graph.inputs[2+i*2].name].shape[1]//world_size
|
||||
tmap[graph.inputs[3+i*2].name].shape[1] = tmap[graph.inputs[3+i*2].name].shape[1]//world_size
|
||||
for node in graph.nodes:
|
||||
if node.name == "/model/layers." + str(i) + "/self_attn/q_proj/MatMul":
|
||||
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/k_proj/MatMul":
|
||||
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/v_proj/MatMul":
|
||||
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/o_proj/MatMul":
|
||||
node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank]
|
||||
reduce_sum_output = gs.Variable("reduce_sum_output_" + str(i) + "_0",
|
||||
dtype=np.float32)
|
||||
reduce_sum = gs.Node(op="ReduceSum", name="reduce_sum_"+str(i)+"_0",
|
||||
inputs=node.outputs, outputs=[reduce_sum_output],
|
||||
attrs={"noop_with_empty_axes":1, "communicator":0})
|
||||
graph.nodes.append(reduce_sum)
|
||||
next_node = node.outputs[0].outputs[0]
|
||||
next_node.inputs[1] = reduce_sum_output
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_0" or \
|
||||
node.name == "/model/layers." + str(i) + "/self_attn/Reshape_1":
|
||||
node.inputs[1].values = np.array(
|
||||
[1, 1,
|
||||
args.head_size//world_size,
|
||||
args.hidden_size//args.head_size])
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_2":
|
||||
node.inputs[1] = gs.Constant(name="/model/layers."+str(i)+"/self_attn/vreshape_input",
|
||||
values=np.array(
|
||||
[1, 1,
|
||||
args.head_size//world_size,
|
||||
args.hidden_size//args.head_size]))
|
||||
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_3":
|
||||
node.inputs[1] = gs.Constant(name="/model/layers." + str(i) + "/self_attn/Reshape_3_shape",
|
||||
values=np.array(
|
||||
[1, 1, args.hidden_size//world_size]))
|
||||
|
||||
elif node.name == "/model/layers." + str(i) + "/mlp/up_proj/MatMul":
|
||||
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
|
||||
elif node.name == "/model/layers." + str(i) + "/mlp/gate_proj/MatMul":
|
||||
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
|
||||
elif node.name == "/model/layers." + str(i) + "/mlp/down_proj/MatMul":
|
||||
node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank]
|
||||
reduce_sum_output_1 = gs.Variable("reduce_sum_output_" + str(i) + "_1",
|
||||
dtype=np.float32)
|
||||
reduce_sum_1 = gs.Node(op="ReduceSum", inputs=node.outputs, outputs=[reduce_sum_output_1],
|
||||
attrs={"noop_with_empty_axes":1, "communicator":0})
|
||||
graph.nodes.append(reduce_sum_1)
|
||||
next_node = node.outputs[0].outputs[0]
|
||||
next_node.inputs[1] = reduce_sum_output_1
|
||||
|
||||
# new_out_1 = tmap["/model/layers.0/mlp/down_proj/MatMul_output_0"] #reduce_sum_output
|
||||
# new_out_1.dtype = np.float32
|
||||
# new_out_1.shape = [1,1,4096]
|
||||
# graph.outputs.append(new_out_1)
|
||||
graph.cleanup(True).toposort()
|
||||
return gs.export_onnx(graph)
|
||||
|
||||
def simplify(onnx_model):
|
||||
graph = gs.import_onnx(onnx_model)
|
||||
for node in graph.nodes:
|
||||
if node.op == "Cast":
|
||||
inp_node = node.i()
|
||||
inp_node.outputs = node.outputs
|
||||
node.outputs.clear()
|
||||
|
||||
for i in range(args.n_layers):
|
||||
nodename = "/model/layers." + str(i) + "/self_attn/Add_2"
|
||||
node = [node for node in graph.nodes if node.name == nodename][0]
|
||||
inp_node = node.i()
|
||||
inp_node.outputs = node.outputs
|
||||
node.outputs.clear()
|
||||
|
||||
graph.cleanup().toposort()
|
||||
return gs.export_onnx(graph)
|
||||
|
||||
@gs.Graph.register()
|
||||
def replace_with_RMSNorm(self, inputs, outputs):
|
||||
inputs[0].outputs.pop(0)
|
||||
inputs[0].outputs.pop(0)
|
||||
|
||||
for out in outputs:
|
||||
out.inputs.clear()
|
||||
return self.layer(op="RMSNorm", inputs=inputs, outputs=outputs, name="rmsnorm")
|
||||
|
||||
@gs.Graph.register()
|
||||
def replace_with_silu(self, inputs, outputs):
|
||||
for inp in inputs:
|
||||
inp.outputs.clear()
|
||||
for out in outputs:
|
||||
out.inputs.clear()
|
||||
return self.layer(op="Silu", inputs=inputs, outputs=outputs, name="silu")
|
||||
|
||||
@gs.Graph.register()
|
||||
def replace_with_RoPE(self, a, b):
|
||||
return self.layer(op="RoPE", inputs=a, outputs=b, name="rope")
|
||||
|
||||
@gs.Graph.register()
|
||||
def replace_with_attention(self, inputs, outputs, inputs_added, outputs_removed):
|
||||
for inp in inputs:
|
||||
inp.outputs.clear()
|
||||
for out in outputs:
|
||||
out.inputs.clear()
|
||||
for inp in inputs_added:
|
||||
inputs.append(inp)
|
||||
for out in outputs_removed:
|
||||
out.inputs.clear()
|
||||
return self.layer(op="AttentionKVCache", inputs=inputs, outputs=outputs, name="attention")
|
||||
|
||||
def fusion(model):
|
||||
graph = gs.import_onnx(model)
|
||||
tmap = graph.tensors()
|
||||
|
||||
tmap["onnx::Reshape_1"].outputs.clear()
|
||||
|
||||
inputs = [tmap["/model/layers.0/input_layernorm/Cast_output_0"], tmap["model.layers.0.input_layernorm.weight"]]
|
||||
rmsnorm_outputs = [tmap["/model/layers.0/input_layernorm/Mul_1_output_0"]]
|
||||
graph.replace_with_RMSNorm(inputs, rmsnorm_outputs)
|
||||
|
||||
for i in range(args.n_layers):
|
||||
# rotary embedding op
|
||||
tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"].inputs.clear()
|
||||
tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"].inputs.clear()
|
||||
attn_qreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/qreshape_input",
|
||||
values=np.array([1,1,args.head_size,args.hidden_size//args.head_size]))
|
||||
attn_kreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/kreshape_input",
|
||||
values=np.array([1,1,args.head_size,args.hidden_size//args.head_size]))
|
||||
attn_qrope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qrope_output")
|
||||
attn_krope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/krope_output")
|
||||
attn_qreshape_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qreshape_output")
|
||||
attn_kreshape_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/kreshape_output")
|
||||
|
||||
attn_qreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_0", inputs=[attn_qrope_output, attn_qreshape_input], outputs=[attn_qreshape_output])
|
||||
attn_kreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_1", inputs=[attn_krope_output, attn_kreshape_input], outputs=[attn_kreshape_output])
|
||||
attn_qtrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_qreshape_output],
|
||||
outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"]])
|
||||
attn_ktrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_kreshape_output],
|
||||
outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"]])
|
||||
|
||||
graph.nodes.append(attn_qreshape)
|
||||
graph.nodes.append(attn_kreshape)
|
||||
graph.nodes.append(attn_qtrans)
|
||||
graph.nodes.append(attn_ktrans)
|
||||
inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/q_proj/MatMul_output_0"]]
|
||||
graph.replace_with_RoPE(inputs, [attn_qrope_output])
|
||||
inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/k_proj/MatMul_output_0"]]
|
||||
graph.replace_with_RoPE(inputs, [attn_krope_output])
|
||||
|
||||
# rms-norm op
|
||||
inputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Cast_output_0"], \
|
||||
tmap["model.layers." + str(i) + ".post_attention_layernorm.weight"]]
|
||||
outputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Mul_1_output_0"]]
|
||||
graph.replace_with_RMSNorm(inputs, outputs)
|
||||
inputs = [tmap["/model/layers." + str(i+1) + "/input_layernorm/Cast_output_0"] if i != args.n_layers-1 else \
|
||||
tmap["/model/norm/Cast_output_0"], \
|
||||
tmap["model.layers." + str(i+1) + ".input_layernorm.weight"] if i != args.n_layers-1 else \
|
||||
tmap["model.norm.weight"]]
|
||||
outputs = [tmap["/model/layers."+ str(i+1) + "/input_layernorm/Mul_1_output_0"]] if i != args.n_layers-1 else \
|
||||
[tmap["/model/norm/Mul_1_output_0"]]
|
||||
graph.replace_with_RMSNorm(inputs, outputs)
|
||||
|
||||
# silu op
|
||||
inputs = [tmap["/model/layers." + str(i) + "/mlp/gate_proj/MatMul_output_0"]]
|
||||
outputs = [tmap["/model/layers." + str(i) + "/mlp/act_fn/Mul_output_0"]]
|
||||
graph.replace_with_silu(inputs, outputs)
|
||||
|
||||
inputs = [
|
||||
tmap["onnx::Concat_" + str((i+1)*2)],
|
||||
tmap["onnx::Concat_" + str((i+1)*2+1)],
|
||||
tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"],
|
||||
tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
|
||||
tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
|
||||
outputs = [
|
||||
tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"],]
|
||||
|
||||
inputs_added = [graph.inputs[1]]
|
||||
outputs_removed = []
|
||||
graph.replace_with_attention(
|
||||
inputs, outputs, inputs_added, outputs_removed)
|
||||
|
||||
graph.outputs = [tmap[graph.outputs[0].name]]
|
||||
graph.cleanup(True).toposort()
|
||||
|
||||
return gs.export_onnx(graph)
|
||||
|
||||
def special_pass(model):
|
||||
graph = gs.import_onnx(model)
|
||||
tmap = graph.tensors()
|
||||
for node in graph.nodes:
|
||||
if node.op == "Transpose" or node.op == "Reshape":
|
||||
inp_node = node.i()
|
||||
inp_node.outputs = node.outputs
|
||||
node.outputs.clear()
|
||||
graph.cleanup(True).toposort()
|
||||
return gs.export_onnx(graph)
|
||||
|
||||
def convert_to_fp16(model):
|
||||
graph = gs.import_onnx(model)
|
||||
|
||||
for node in graph.nodes:
|
||||
if node.op == "Gather" and node.name == "/model/embed_tokens/Gather":
|
||||
node.inputs[0].values = np.float16(node.inputs[0].values)
|
||||
|
||||
if node.op == "RMSNorm":
|
||||
node.inputs[1].values = np.float16(node.inputs[1].values)
|
||||
|
||||
if node.op == "MatMul":
|
||||
node.inputs[1].values = np.float16(node.inputs[1].values)
|
||||
if node.name == "/lm_head/MatMul":
|
||||
cast_1_out = gs.Variable(node.name+"_cast_out_output_0", dtype=np.float32, shape=node.outputs[0].shape)
|
||||
cast_1 = gs.Node(op="Cast", inputs=[node.outputs[0]], outputs=[cast_1_out])
|
||||
cast_1.attrs["to"] = np.float32
|
||||
cast_1.name = node.name+"_cast_out_0"
|
||||
graph.nodes.append(cast_1)
|
||||
graph.outputs[0] = cast_1_out
|
||||
node.outputs[0].dtype = np.float16
|
||||
|
||||
graph.cleanup(True).toposort()
|
||||
return gs.export_onnx(graph)
|
||||
|
||||
def export_onnx(model: AutoModelForCausalLM):
|
||||
if not os.path.exists(ONNX_MODEL_ORIGIN_PATH):
|
||||
print("exporting origin onnx model...")
|
||||
with torch.no_grad():
|
||||
param = torch.zeros(
|
||||
(args.batchsize, model.config.max_position_embeddings-1), dtype=torch.long)
|
||||
logits = model(param, past_key_values=None)
|
||||
|
||||
if not args.is_1st_graph:
|
||||
param_kvcache = torch.zeros((args.batchsize, 1), dtype=torch.long)
|
||||
torch.onnx.export(model, (param_kvcache, {"past_key_values": logits.past_key_values,
|
||||
"position_ids": param_kvcache}), \
|
||||
ONNX_MODEL_ORIGIN_PATH, verbose=False,
|
||||
do_constant_folding=True,)
|
||||
else:
|
||||
position_ids = torch.tile(torch.arange(0, model.config.max_position_embeddings-1), (args.batchsize, 1))
|
||||
attention_mask = torch.ones((args.batchsize, model.config.max_position_embeddings-1), dtype=torch.bool)
|
||||
torch.onnx.export(model, (param, {"attention_mask": attention_mask,
|
||||
"position_ids": position_ids}),\
|
||||
ONNX_MODEL_ORIGIN_PATH, verbose=False,
|
||||
do_constant_folding=True,)
|
||||
print("export origin onnx finished.")
|
||||
|
||||
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SIM_PATH):
|
||||
print("exporting sim onnx model...")
|
||||
onnx_model = onnx.load(ONNX_MODEL_ORIGIN_PATH)
|
||||
onnx_model = simplify(onnx_model)
|
||||
onnx.save(onnx_model, ONNX_MODEL_SIM_PATH, save_as_external_data=True, \
|
||||
location="llama2_sim_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
|
||||
print("exporting sim onnx model finished.")
|
||||
|
||||
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_FUSION_PATH):
|
||||
print("exporting fusion onnx model...")
|
||||
onnx_model = onnx.load(ONNX_MODEL_SIM_PATH)
|
||||
onnx_model = fusion(onnx_model)
|
||||
onnx.save(onnx_model, ONNX_MODEL_FUSION_PATH, save_as_external_data=True, \
|
||||
location="llama2_fusion_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
|
||||
print("exporting fusion onnx model finished.")
|
||||
|
||||
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SPECIAL_PATH):
|
||||
print("exporting special onnx model...")
|
||||
onnx_model = onnx.load(ONNX_MODEL_FUSION_PATH)
|
||||
onnx_model = special_pass(onnx_model)
|
||||
onnx.save(onnx_model, ONNX_MODEL_SPECIAL_PATH, save_as_external_data=True, \
|
||||
location="llama2_special_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
|
||||
print("exporting special onnx model finished.")
|
||||
|
||||
if not args.is_1st_graph and args.fp16 and not os.path.exists(ONNX_MODEL_FP16_PATH):
|
||||
print("exporting fp16 onnx model...")
|
||||
onnx_model = onnx.load(ONNX_MODEL_SPECIAL_PATH)
|
||||
onnx_model = convert_to_fp16(onnx_model)
|
||||
onnx.save(onnx_model, ONNX_MODEL_FP16_PATH, save_as_external_data=True, \
|
||||
location="llama2_fp16_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
|
||||
print("exporting fp16 onnx model finished.")
|
||||
|
||||
print("world_size =", args.world_size)
|
||||
if not args.is_1st_graph and args.world_size > 1 and not os.path.exists(ONNX_MODEL_DIST_PATH):
|
||||
print("exporting dist onnx model...")
|
||||
onnx_model = onnx.load(ONNX_MODEL_FP16_PATH) if args.fp16 else onnx.load(ONNX_MODEL_SPECIAL_PATH)
|
||||
onnx_model = parallel_model(onnx_model, args.world_size, args.rank)
|
||||
onnx.save(onnx_model, ONNX_MODEL_DIST_PATH, save_as_external_data=True, \
|
||||
location="llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format(
|
||||
args.batchsize, args.n_layers,
|
||||
16 if args.fp16 else 32, args.world_size, args.rank))
|
||||
print("exporting dist onnx model finished.")
|
||||
|
||||
def get_it_logit(onnx_model, input_ids):
|
||||
# initialization
|
||||
runtime = backend.CudaRuntime(args.rank)
|
||||
runtime.init_comm(
|
||||
"dist",
|
||||
args.world_size,
|
||||
args.rank,
|
||||
)
|
||||
print("[{}] comm init.".format(args.rank))
|
||||
stub = OnnxStub(onnx_model, runtime)
|
||||
print("[{}] stub init.".format(args.rank))
|
||||
|
||||
# warm up
|
||||
for i in range(10):
|
||||
if args.no_cudagraph:
|
||||
stub.run()
|
||||
else:
|
||||
stub.run_with_cudagraph()
|
||||
print("[{}] stub warmup.".format(args.rank))
|
||||
|
||||
logits = np.zeros((args.batchsize, args.n_max_length, args.vocab_size), dtype=np.float32)
|
||||
output_ids = np.zeros((args.batchsize, args.n_max_length), dtype=np.int64)
|
||||
avg_inference_time = 0
|
||||
t0 = time.time()
|
||||
for i in tqdm(range(0, args.n_max_length)):
|
||||
with nvtx.annotate("seq_length = {}".format(i), color="red"):
|
||||
assert input_ids.shape[0] == args.batchsize
|
||||
input_id = input_ids[:, i] if i < input_ids.shape[1] else output_ids[:, i-1]
|
||||
position_id = i*np.ones((args.batchsize, 1), dtype=np.int32)
|
||||
|
||||
# copyin input
|
||||
with nvtx.annotate("[it] copyin", color="blue"):
|
||||
(list(stub.inputs.items()))[0][1].copyin_int64(
|
||||
input_id.reshape(-1).tolist())
|
||||
(list(stub.inputs.items()))[1][1].copyin_int64(
|
||||
position_id.reshape(-1).tolist())
|
||||
|
||||
# run
|
||||
t10 = time.time()
|
||||
with nvtx.annotate("[it] run", color="green"):
|
||||
if args.no_cudagraph:
|
||||
stub.run()
|
||||
else:
|
||||
stub.run_with_cudagraph()
|
||||
t11 = time.time()
|
||||
avg_inference_time += (t11 - t10)
|
||||
|
||||
# copyout output
|
||||
if not args.speedup:
|
||||
with nvtx.annotate("[it] copyout", color="blue"):
|
||||
logits[:,i, :] = np.array((list(stub.outputs.items()))[0][1].copyout_float()).reshape(args.batchsize, -1)
|
||||
output_ids[:, i] = np.argmax(logits[:, i, :], -1).astype(np.int64)
|
||||
|
||||
|
||||
t1 = time.time()
|
||||
if args.rank == 0:
|
||||
result = "[it] e2e: {} gpus, {} layers, e2e time: {:.2f}s, average inference time: {:.2f}ms"\
|
||||
.format(args.num_nodes * args.nproc_per_node, args.n_layers, t1-t0, \
|
||||
avg_inference_time*1000/args.n_max_length)
|
||||
print(result)
|
||||
del stub
|
||||
return output_ids
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch_model = LlamaForCausalLM.from_pretrained(
|
||||
PRETRAINED_LLAMA_PATH, num_hidden_layers=int(args.n_layers)).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LLAMA_PATH)
|
||||
#prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
#prompt = "What is PTX?"
|
||||
#prompt = "Tell me a joke."
|
||||
#prompt = "What are the key principles of smart investing?"
|
||||
prompt = "What is DeepSpeed?"
|
||||
prompts=[prompt]*args.batchsize
|
||||
inputs = tokenizer(prompts, return_tensors="pt")
|
||||
|
||||
input_ids = inputs.input_ids
|
||||
print("prompt ids =", input_ids)
|
||||
|
||||
##########################################################
|
||||
# inference with InfiniTensor
|
||||
##########################################################
|
||||
print("exporting onnx...")
|
||||
export_onnx(torch_model)
|
||||
print("exporting onnx finished.")
|
||||
|
||||
onnx_to_run_path = ONNX_MODEL_DIST_PATH if args.world_size > 1 else \
|
||||
(ONNX_MODEL_FP16_PATH if args.fp16 else ONNX_MODEL_SPECIAL_PATH)
|
||||
print("loading onnx", onnx_to_run_path, "...")
|
||||
onnx_model = onnx.load(onnx_to_run_path)
|
||||
print("loading onnx finished.")
|
||||
output_ids_it = get_it_logit(onnx_model, input_ids)
|
||||
it_output_text = tokenizer.batch_decode(output_ids_it[:, input_ids.shape[-1]:output_ids_it.shape[-1]])
|
||||
if args.rank == 0:
|
||||
for i in range(args.batchsize):
|
||||
print("prompt: ", prompts[i])
|
||||
print("answer: [it]", it_output_text[i])
|
||||
|
||||
##########################################################
|
||||
# validation with pytorch
|
||||
##########################################################
|
||||
"""
|
||||
generate_ids = torch_model.generate(inputs.input_ids, max_length=args.n_max_length)#, num_beams=4, do_sample=True)
|
||||
outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"""
|
||||
if not args.speedup and not args.is_1st_graph:
|
||||
kvcache_torch = None
|
||||
output_ids_pt = torch.zeros(args.batchsize, args.n_max_length).int() # + input_ids.shape[-1] - 1).int()
|
||||
if args.fp16:
|
||||
torch_model = torch_model.half()
|
||||
|
||||
torch_model = torch_model.cuda()
|
||||
# print(torch.cuda.memory_summary())
|
||||
|
||||
avg_inference_time = 0
|
||||
with torch.no_grad():
|
||||
t0 = time.time()
|
||||
for i in range(args.n_max_length):
|
||||
input_id = input_ids[:,i] if i < input_ids.shape[1] else out_token
|
||||
input_id = input_id.view(args.batchsize,1).cuda()
|
||||
t00 = time.time()
|
||||
outputs = torch_model(input_id, past_key_values=kvcache_torch)
|
||||
t01 = time.time()
|
||||
avg_inference_time += (t01-t00)
|
||||
|
||||
logits = outputs['logits']
|
||||
kvcache_torch = outputs['past_key_values']
|
||||
out_token = torch.argmax(logits, dim=-1)
|
||||
output_ids_pt[:, i:i+1] = out_token
|
||||
t1 = time.time()
|
||||
avg_inference_time /= args.n_max_length
|
||||
result = "[pt] e2e time: {:.2f}s, average inference time: {:.2f}ms"\
|
||||
.format(t1-t0, avg_inference_time*1000)
|
||||
|
||||
if args.rank == 0:
|
||||
print(result)
|
||||
pt_output_text = tokenizer.batch_decode(output_ids_pt[:,input_ids.shape[-1]:args.n_max_length])
|
||||
for i in range(args.batchsize):
|
||||
print("[pt]", args.rank, pt_output_text[i])
|
||||
|
||||
if not args.is_1st_graph:
|
||||
assert(output_ids_it.shape[-1] == args.n_max_length)
|
||||
np.testing.assert_equal(output_ids_pt[:, input_ids.shape[-1]:args.n_max_length], output_ids_it[:,input_ids.shape[-1]:args.n_max_length])
|
Loading…
Reference in New Issue