LLaMA-Factory-310P3/mindie/examples/run_fa.py

# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import argparse
import json
import os
import time

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

from atb_llm.runner import ModelRunner
from atb_llm.utils.log import logger, print_log
from atb_llm.utils.file_utils import safe_open


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', help="model and tokenizer path")
    parser.add_argument(
        '--input_text',
        type=str,
        nargs='+',
        default="What's deep learning?")
    parser.add_argument(
        '--input_file',
        type=str,
        help='CSV or Numpy file containing tokenized input. Alternative to text input.',
        default=None)
    parser.add_argument('--max_input_length', type=int, default=512)
    parser.add_argument('--max_output_length', type=int, default=20)
    parser.add_argument('--max_position_embeddings', type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument('--is_flash_causal_lm', action='store_true')
    parser.add_argument('--num_beams',
                        type=int,
                        help="Use beam search if num_beams > 1",
                        default=1)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--top_p', type=float, default=0.0)
    parser.add_argument('--length_penalty', type=float, default=1.0)
    parser.add_argument('--repetition_penalty', type=float, default=1.0)
    parser.add_argument('--inputs_embeds_dir', type=str, default=None,
                        help='Directory of .pt files containing inputs_embeds.')
    parser.add_argument('--min_length', type=int, default=10)
    parser.add_argument('--stop_words_ids', type=json.loads, default=None)
    # Note: argparse's type=bool treats any non-empty string (even "False") as
    # True, so a store_true flag is used instead; the default remains False.
    parser.add_argument('--do_sample', action='store_true')
    parser.add_argument('--results_save_path', type=str, default=None,
                        help='File path to save inference results.')
    return parser.parse_args()
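
# Example invocation (illustrative; the model path is a placeholder):
#   python run_fa.py --model_path /path/to/model \
#       --input_text "What's deep learning?" --max_output_length 64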


class StoppingCriteriaSub(StoppingCriteria):
    """Stops generation once the tail of input_ids matches any stop-word id sequence."""

    def __init__(self, stops=None):
        super().__init__()
        if stops is None:
            stops = []
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        for stop in self.stops:
            if torch.all(torch.eq(input_ids[:, -len(stop):], stop)).item():
                return True
        return False
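
# Illustrative --stop_words_ids value: a JSON list of token-id sequences, e.g.
#   --stop_words_ids '[[2], [13, 13]]'
# Token ids are tokenizer-specific; verify them with tokenizer.encode() before use.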


class FARunner:
    def __init__(self, **kwargs):
        # rank/world_size default to single-card values as ints; infer() compares
        # self.rank == 0, so string defaults would silently skip rank-0 logging.
        self.rank = kwargs.get('rank', 0)
        self.local_rank = kwargs.get('local_rank', self.rank)
        self.world_size = kwargs.get('world_size', 1)
        self.model_path = kwargs.get('model_path', None)
        self.max_input_length = kwargs.get('max_input_length', None)
        self.max_output_length = kwargs.get('max_output_length', None)
        self.max_position_embeddings = kwargs.get('max_position_embeddings', None)
        self.is_flash_causal_lm = kwargs.get('is_flash_causal_lm', False)
        self.batch_size = kwargs.get('batch_size', None)

        self.model = ModelRunner(
            self.model_path, rank=self.rank, world_size=self.world_size,
            local_rank=self.local_rank,
            is_flash_causal_lm=self.is_flash_causal_lm,
            max_position_embeddings=self.max_position_embeddings,
        )
        self.tokenizer = self.model.tokenizer
        self.device = self.model.device
        self.dtype = self.model.dtype
        self.quantize = self.model.quantize
        self.kv_quant = self.model.kv_quant
        self.model.load_weights()

        # Models that consume precomputed embeddings (e.g. multimodal pipelines)
        # skip the word-embedding layer; such models are routed to infer_from_embeds().
        self.skip_word_embedding = False
        if hasattr(self.model.model, 'skip_word_embedding'):
            self.skip_word_embedding = self.model.model.skip_word_embedding

    def warm_up(self):
        print_log(self.rank, logger.info, "---------------begin warm_up---------------")
        # One short generation on random token ids triggers graph compilation and
        # memory allocation before any timed request.
        dummy_input_ids_full = torch.randint(
            0, 32000, [self.batch_size, self.max_input_length], dtype=torch.long).npu()
        self.model.generate(inputs=dummy_input_ids_full, do_sample=False, max_new_tokens=10)
        print_log(self.rank, logger.info, "---------------end warm_up---------------")

    def infer(self, input_text):
        print_log(self.rank, logger.info, "---------------begin inference---------------")
        if isinstance(input_text, str):
            input_text = [input_text] * self.batch_size
        inputs = self.tokenizer(input_text, return_tensors="pt", padding='max_length',
                                max_length=self.max_input_length,
                                truncation=True)

        # A single-token generation approximates the prefill cost.
        prefill_start_time = time.time()
        with torch.no_grad():
            self.model.generate(
                inputs=inputs.input_ids.npu(),
                attention_mask=inputs.attention_mask.npu(),
                max_new_tokens=1
            )
        prefill_end_time = time.time()

        decode_start_time = time.time()
        with torch.no_grad():
            generate_ids = self.model.generate(
                inputs=inputs.input_ids.npu(),
                attention_mask=inputs.attention_mask.npu(),
                max_new_tokens=self.max_output_length
            )
        decode_end_time = time.time()

        # generate_ids is sliced at the token level, so generate_text contains
        # only the newly generated answer.
        generate_text = self.tokenizer.batch_decode(
            generate_ids[:, self.max_input_length:], skip_special_tokens=True,
            clean_up_tokenization_spaces=False)
        if self.rank == 0:
            logger.info(f'{inputs.input_ids.shape=}')
            input_tokens_num = len(inputs.input_ids[0])
            generate_tokens_num = len(generate_ids[0]) - len(inputs.input_ids[0])
            logger.info(f'Question: {input_text[0]}')
            logger.info(f'Answer: {generate_text[0]}')
            logger.info(f'Input token num: {input_tokens_num}')
            logger.info(f'Generate token num: {generate_tokens_num}')
            logger.info("---------------end inference---------------")

            prefill_time = prefill_end_time - prefill_start_time
            e2e_time = decode_end_time - decode_start_time
            try:
                decode_average_time = (e2e_time - prefill_time) / (self.max_output_length - 1)
            except ZeroDivisionError as e:
                raise ZeroDivisionError("max_output_length must be greater than 1") from e
            logger.info(
                f"Prefill time: {prefill_time * 1000}ms, "
                f"Decode average time: {decode_average_time * 1000}ms, "
                f"E2E time: {e2e_time}s")

    def infer_from_embeds(self, args):
        if self.rank == 0:
            logger.info("---------------begin inference---------------")
        stop_words_ids = [torch.tensor(ids).npu() for ids in (args.stop_words_ids or [])]
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
        generation_args = {
            'inputs_embeds': None,
            'min_length': args.min_length,
            'max_new_tokens': args.max_output_length,
            'stopping_criteria': stopping_criteria,
            'do_sample': args.do_sample,
            'num_beams': args.num_beams,
            'top_p': args.top_p,
            'temperature': args.temperature,
            'repetition_penalty': args.repetition_penalty,
            'length_penalty': args.length_penalty,
        }
        image_answer_pairs = {}
        for inputs_embeds_file_path in sorted([os.path.join(args.inputs_embeds_dir, _)
                                               for _ in os.listdir(args.inputs_embeds_dir)]):
            if not inputs_embeds_file_path.endswith(".pt"):
                continue
            if self.rank == 0:
                logger.info(f'NO.{len(image_answer_pairs) + 1}')
                logger.info(f'inputs_embeds_file_path: {inputs_embeds_file_path}')
            inputs_embeds = torch.load(inputs_embeds_file_path).npu()
            generation_args['inputs_embeds'] = inputs_embeds
            with torch.no_grad():
                generate_ids = self.model.generate(**generation_args)
            output_text = self.tokenizer.decode(generate_ids[0], skip_special_tokens=True)
            output_text = output_text.split('###')[0]  # remove the stop sign '###'
            output_text = output_text.split('Assistant:')[-1].strip()
            image_answer_pairs[inputs_embeds_file_path] = output_text
            if self.rank == 0:
                logger.info(f'Answer: {output_text}')
            # Rewrite the full results file after each sample so partial progress
            # survives an interrupted run.
            with safe_open(args.results_save_path, "w", encoding='utf-8') as f:
                json.dump(image_answer_pairs, f)
                logger.info('json dump finished')
        if self.rank == 0:
            logger.info("---------------end inference---------------")


if __name__ == '__main__':
    arguments = parse_arguments()
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    input_dict = {
        'rank': rank,
        'world_size': world_size,
        'local_rank': local_rank,
        **vars(arguments)
    }

    fa_runner = FARunner(**input_dict)

    if fa_runner.skip_word_embedding:
        fa_runner.infer_from_embeds(arguments)
    else:
        fa_runner.warm_up()
        fa_runner.infer(arguments.input_text)
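
# Minimal launch sketch (paths and card counts are placeholders):
#   single card:  python run_fa.py --model_path /path/to/model
#   multi card:   torchrun --nproc_per_node=2 run_fa.py --model_path /path/to/model
# torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE, which the __main__ block
# above reads to configure each process.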