# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
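"""Offline text-generation example built on atb_llm's ModelRunner.

Runs either plain text prompts (warm_up + infer) or, for models that expose
``skip_word_embedding``, precomputed ``inputs_embeds`` tensors loaded from
.pt files (infer_from_embeds). Prefill and decode timings are logged on rank 0.
"""
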
import argparse
import json
import os
import time

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

from atb_llm.runner import ModelRunner
from atb_llm.utils.log import logger, print_log
from atb_llm.utils.file_utils import safe_open


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', help="model and tokenizer path")
    parser.add_argument(
        '--input_text',
        type=str,
        nargs='+',
        default="What's deep learning?")
    parser.add_argument(
        '--input_file',
        type=str,
        help='CSV or Numpy file containing tokenized input. Alternative to text input.',
        default=None)
    parser.add_argument('--max_input_length', type=int, default=512)
    parser.add_argument('--max_output_length', type=int, default=20)
    parser.add_argument('--max_position_embeddings', type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=1)

    parser.add_argument('--is_flash_causal_lm', action='store_true')

    parser.add_argument('--num_beams',
                        type=int,
                        help="Use beam search if num_beams > 1",
                        default=1)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--top_p', type=float, default=0.0)
    parser.add_argument('--length_penalty', type=float, default=1.0)
    parser.add_argument('--repetition_penalty', type=float, default=1.0)

    parser.add_argument('--inputs_embeds_dir', type=str, default=None,
                        help='Directory of .pt files containing inputs_embeds.')
    parser.add_argument('--min_length', type=int, default=10)
    parser.add_argument('--stop_words_ids', type=json.loads, default=None)
    # argparse's type=bool treats any non-empty string as True, so expose this as a flag instead.
    parser.add_argument('--do_sample', action='store_true')
    parser.add_argument('--results_save_path', type=str, default=None,
                        help='File path to save inference results.')

    return parser.parse_args()


class StoppingCriteriaSub(StoppingCriteria):
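    """Stop generation once the newest tokens match any of the given stop-word id sequences."""
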
    def __init__(self, stops=None):
        super().__init__()
        if stops is None:
            stops = []
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        for stop in self.stops:
            if torch.all(torch.eq(input_ids[:, -len(stop):], stop)).item():
                return True
        return False


class FARunner:
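    """Inference runner that wraps atb_llm's ModelRunner.

    Loads the model and tokenizer once, optionally warms up the device, and
    measures prefill and decode latency for text prompts or precomputed
    inputs_embeds.
    """
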
    def __init__(self, **kwargs):
        self.rank = kwargs.get('rank', 0)
        self.local_rank = kwargs.get('local_rank', self.rank)
        self.world_size = kwargs.get('world_size', 1)

        self.model_path = kwargs.get('model_path', None)
        self.max_input_length = kwargs.get('max_input_length', None)
        self.max_output_length = kwargs.get('max_output_length', None)
        self.max_position_embeddings = kwargs.get('max_position_embeddings', None)
        self.is_flash_causal_lm = kwargs.get('is_flash_causal_lm', False)
        self.batch_size = kwargs.get('batch_size', None)

        self.model = ModelRunner(
            self.model_path, rank=self.rank, world_size=self.world_size,
            local_rank=self.local_rank,
            is_flash_causal_lm=self.is_flash_causal_lm,
            max_position_embeddings=self.max_position_embeddings,
        )
        self.tokenizer = self.model.tokenizer
        self.device = self.model.device
        self.dtype = self.model.dtype
        self.quantize = self.model.quantize
        self.kv_quant = self.model.kv_quant
        self.model.load_weights()

        # Models that accept precomputed inputs_embeds (e.g. multimodal pipelines) expose this attribute.
        self.skip_word_embedding = False
        if hasattr(self.model.model, 'skip_word_embedding'):
            self.skip_word_embedding = self.model.model.skip_word_embedding

    def warm_up(self):
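        """Run one short dummy generation so later timing runs are not skewed by first-call setup."""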
        print_log(self.rank, logger.info, "---------------begin warm_up---------------")
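        # The upper bound 32000 assumes a LLaMA-sized vocabulary (an assumption of this example);
        # use the model's actual vocab size if it is smaller.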
        dummy_input_ids_full = torch.randint(
            0, 32000, [self.batch_size, self.max_input_length], dtype=torch.long).npu()
        self.model.generate(inputs=dummy_input_ids_full, do_sample=False, max_new_tokens=10)
        print_log(self.rank, logger.info, "---------------end warm_up---------------")

    def infer(self, input_text):
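        """Tokenize the prompt(s) and time generation.

        A single-new-token run measures prefill latency; a full run of
        max_output_length tokens measures end-to-end latency, from which the
        per-token decode time is derived.
        """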
        print_log(self.rank, logger.info, "---------------begin inference---------------")
        if isinstance(input_text, str):
            input_text = [input_text] * self.batch_size

        inputs = self.tokenizer(input_text, return_tensors="pt", padding='max_length',
                                max_length=self.max_input_length,
                                truncation=True)

        # Prefill latency: generate exactly one new token.
        prefill_start_time = time.time()
        with torch.no_grad():
            self.model.generate(
                inputs=inputs.input_ids.npu(),
                attention_mask=inputs.attention_mask.npu(),
                max_new_tokens=1
            )
        prefill_end_time = time.time()

        # End-to-end latency: generate the full requested output length.
        decode_start_time = time.time()
        with torch.no_grad():
            generate_ids = self.model.generate(
                inputs=inputs.input_ids.npu(),
                attention_mask=inputs.attention_mask.npu(),
                max_new_tokens=self.max_output_length
            )
        decode_end_time = time.time()

        generate_text = self.tokenizer.batch_decode(
            generate_ids[:, self.max_input_length:], skip_special_tokens=True,
            clean_up_tokenization_spaces=False)
        if self.rank == 0:
            logger.info(f'{inputs.input_ids.shape=}')

            input_tokens_num = len(inputs.input_ids[0])
            generate_tokens_num = len(generate_ids[0]) - len(inputs.input_ids[0])
            logger.info(f'Question: {input_text[0]}')
            logger.info(f'Answer: {generate_text[0]}')
            logger.info(f'Input token num: {input_tokens_num}')
            logger.info(f'Generate token num: {generate_tokens_num}')

            logger.info("---------------end inference---------------")

            prefill_time = prefill_end_time - prefill_start_time
            e2e_time = decode_end_time - decode_start_time
            try:
                decode_average_time = (e2e_time - prefill_time) / (self.max_output_length - 1)
            except ZeroDivisionError as e:
                raise ZeroDivisionError("max_output_length must be greater than 1") from e
            logger.info(
                f"Prefill time: {prefill_time * 1000}ms, "
                f"Decode average time: {decode_average_time * 1000}ms, "
                f"E2E time: {e2e_time}s")

    def infer_from_embeds(self, args):
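        """Generate from precomputed inputs_embeds (.pt files) and dump answers to JSON.

        Each .pt file in args.inputs_embeds_dir is loaded in turn, generation runs
        with the configured sampling and stopping settings, and the per-file
        answers are written to args.results_save_path.
        """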
        if self.rank == 0:
            logger.info("---------------begin inference---------------")

        # --stop_words_ids may be omitted on the command line.
        stop_words_ids = [torch.tensor(ids).npu() for ids in (args.stop_words_ids or [])]
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

        generation_args = {
            'inputs_embeds': None,
            'min_length': args.min_length,
            'max_new_tokens': args.max_output_length,
            'stopping_criteria': stopping_criteria,
            'do_sample': args.do_sample,
            'num_beams': args.num_beams,
            'top_p': args.top_p,
            'temperature': args.temperature,
            'repetition_penalty': args.repetition_penalty,
            'length_penalty': args.length_penalty,
        }

        image_answer_pairs = {}
        for inputs_embeds_file_path in sorted([os.path.join(args.inputs_embeds_dir, _)
                                               for _ in os.listdir(args.inputs_embeds_dir)]):

            if not inputs_embeds_file_path.endswith(".pt"):
                continue

            if self.rank == 0:
                logger.info(f'NO.{len(image_answer_pairs) + 1}')
                logger.info(f'inputs_embeds_file_path: {inputs_embeds_file_path}')

            inputs_embeds = torch.load(inputs_embeds_file_path).npu()
            generation_args["inputs_embeds"] = inputs_embeds

            with torch.no_grad():
                generate_ids = self.model.generate(**generation_args)

            output_text = self.tokenizer.decode(generate_ids[0], skip_special_tokens=True)
            output_text = output_text.split('###')[0]  # remove the stop sign '###'
            output_text = output_text.split('Assistant:')[-1].strip()
            image_answer_pairs[inputs_embeds_file_path] = output_text

            if self.rank == 0:
                logger.info(f'Answer: {output_text}')
                # Rewrite the full results file after every sample so partial runs are preserved.
                with safe_open(args.results_save_path, "w", encoding='utf-8') as f:
                    json.dump(image_answer_pairs, f)
                logger.info('json dump finished')

        if self.rank == 0:
            logger.info("---------------end inference---------------")


if __name__ == '__main__':
    arguments = parse_arguments()

    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    input_dict = {
        'rank': rank,
        'world_size': world_size,
        'local_rank': local_rank,
        **vars(arguments)
    }

    fa_runner = FARunner(**input_dict)

    if fa_runner.skip_word_embedding:
        fa_runner.infer_from_embeds(arguments)
    else:
        fa_runner.warm_up()
        fa_runner.infer(arguments.input_text)
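
# Example invocation (illustrative only; the script name, model path, and world size
# below are placeholders, not values defined by this repository):
#   RANK=0 LOCAL_RANK=0 WORLD_SIZE=1 python path/to/this_script.py \
#       --model_path /path/to/model \
#       --max_input_length 256 --max_output_length 64 \
#       --input_text "What's deep learning?"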