LLaMA-Factory-310P3/mindie/examples/run_pa.py

364 lines
17 KiB
Python
Raw Normal View History

2024-09-10 15:38:33 +08:00
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import argparse
import copy
import json
import math
import os
import time
import torch
import torch_npu
from atb_llm.runner import ModelRunner
from atb_llm.utils.cpu_binding import NpuHbmInfo
from atb_llm.utils.env import ENV
from atb_llm.utils.log import logger, print_log
from atb_llm.utils.file_utils import safe_open
from examples.server.cache import CacheConfig, ModelConfig, CacheManager
from examples.server.generate import decode_token, generate_req
from examples.server.request import request_from_token
class PARunner:
    """Inference runner that drives an ``atb_llm`` ``ModelRunner`` with a block-based KV cache.

    The constructor loads the model weights and queries device HBM capacity and usage.
    ``warm_up`` runs one maximum-size prefill pass to measure memory. ``infer``
    tokenizes the inputs, builds requests, runs ``generate_req`` and decodes the output.

    Recognised keyword arguments (all optional, read in ``__init__``):
    ``rank``, ``local_rank``, ``world_size``, ``model_path``, ``input_text``
    (falls back to ``input_texts``), ``max_batch_size``, ``max_input_length``,
    ``max_output_length``, ``max_position_embeddings``, ``max_prefill_tokens``
    (``-1`` means "derive from batch size and lengths"), ``block_size``,
    ``chat_template`` (file path or inline template string) and
    ``is_flash_model``. Any other keys are ignored.
    """

    def __init__(self, **kwargs):
        # Distributed topology. Defaults are ints so comparisons such as
        # ``self.local_rank == 0`` behave the same whether or not they are passed.
        self.rank = kwargs.get('rank', 0)
        self.local_rank = kwargs.get('local_rank', self.rank)
        self.world_size = kwargs.get('world_size', 1)
        self.model_path = kwargs.get('model_path', None)
        # The argparse namespace uses the plural ``input_texts``; accept it as a
        # fallback so the value is not silently dropped.
        self.input_text = kwargs.get('input_text', kwargs.get('input_texts'))
        self.max_batch_size = kwargs.get('max_batch_size', None)
        self.max_input_length = kwargs.get('max_input_length', None)
        self.max_output_length = kwargs.get('max_output_length', None)
        self.max_position_embeddings = kwargs.get('max_position_embeddings', None)
        self.max_prefill_tokens = kwargs.get('max_prefill_tokens', None)
        self.block_size = kwargs.get('block_size', None)
        self.chat_template = kwargs.get('chat_template', None)
        self.is_flash_model = kwargs.get('is_flash_model', None)

        self.model = ModelRunner(
            self.model_path, rank=self.rank, world_size=self.world_size,
            local_rank=self.local_rank,
            max_position_embeddings=self.max_position_embeddings
        )
        self.tokenizer = self.model.tokenizer
        if self.chat_template:
            self.tokenizer.chat_template = self._load_chat_template(self.chat_template)
        self.dtype = self.model.dtype
        self.quantize = self.model.quantize
        self.kv_quant = self.model.kv_quant
        self.model.load_weights()

        self.device = self.model.device
        self.model_config = ModelConfig(self.model.num_heads,
                                        self.model.num_kv_heads,
                                        self.model.head_size,
                                        self.model.num_layers,
                                        self.model.device,
                                        self.model.dtype,
                                        self.model.soc_info,
                                        self.kv_quant)

        # Device memory in bytes (the log below divides by 1024 ** 3 to report GB).
        self.max_memory = NpuHbmInfo.get_hbm_capacity(self.local_rank, self.world_size, self.model.soc_info.need_nz)
        self.init_memory = int(
            self.max_memory * NpuHbmInfo.get_hbm_usage(self.local_rank, self.world_size, self.model.soc_info.need_nz))
        print_log(self.rank, logger.info, f'hbm_capacity(GB): {self.max_memory / (1024 ** 3)}, '
                                          f'init_memory(GB): {self.init_memory / (1024 ** 3)}')

        self.warm_up_memory = 0
        self.warm_up_num_blocks = 0
        self.cache_manager = None

    def __repr__(self):
        fields = (
            ("model_path", self.model_path),
            ("input_text", self.input_text),
            ("max_position_embeddings", self.max_position_embeddings),
            ("max_input_length", self.max_input_length),
            ("max_output_length", self.max_output_length),
            ("max_prefill_tokens", self.max_prefill_tokens),
            ("is_flash_model", self.is_flash_model),
            ("max_batch_size", self.max_batch_size),
            ("dtype", self.dtype),
            ("block_size", self.block_size),
            ("model_config", self.model_config),
            ("max_memory", self.max_memory),
        )
        # Join with separators and close the parenthesis (the previous version
        # left a trailing ", " and no closing ")").
        body = ", ".join(f"{name}={value}" for name, value in fields)
        return f"PARunner({body})"

    @staticmethod
    def _load_chat_template(chat_template: str):
        """Return the chat template text.

        If ``chat_template`` names an existing path, the file is read as UTF-8.
        Otherwise the argument itself is treated as the template string.
        """
        if os.path.exists(chat_template):
            with open(chat_template, "r", encoding="utf-8") as f:
                chat_template_content = f.read()
        else:
            chat_template_content = chat_template
        if chat_template_content:
            print_log(int(os.getenv("RANK", "0")), logger.info, f"Using chat template:\n{chat_template_content}")
        return chat_template_content

    def _resolve_max_prefill_tokens(self):
        """Replace the ``-1`` sentinel in ``max_prefill_tokens`` with the full batch token budget."""
        if self.max_prefill_tokens == -1:
            self.max_prefill_tokens = self.max_batch_size * (self.max_input_length + self.max_output_length)

    def _num_blocks_for(self, num_tokens):
        """Return how many cache blocks of ``self.block_size`` tokens are needed for ``num_tokens``.

        Raises:
            ZeroDivisionError: if ``block_size`` is 0, with a descriptive message.
        """
        if self.block_size == 0:
            raise ZeroDivisionError(f"block_size must be a positive integer, got {self.block_size!r}")
        return math.ceil(num_tokens / self.block_size)

    def warm_up(self):
        """Run one maximum-size prefill forward pass and record the resulting memory usage.

        The dummy batch has ``max_batch_size`` sequences of ``max_input_length`` tokens.
        The method allocates a ``CacheManager`` with enough blocks for
        ``max_batch_size`` requests of ``max_input_length + max_output_length`` tokens
        and calls ``self.model.forward`` once. It then stores the measured HBM usage in
        ``self.warm_up_memory``. The cache manager created here stays in
        ``self.cache_manager`` and is reused by ``infer``.

        Raises:
            ZeroDivisionError: if ``block_size`` is 0.
        """
        self._resolve_max_prefill_tokens()
        all_input_length = self.max_batch_size * self.max_input_length
        # Validate block_size before allocating anything on the device.
        block_num = self._num_blocks_for(all_input_length)

        input_ids = torch.ones(all_input_length, dtype=torch.int64).to(self.device)
        position_ids = torch.arange(self.max_input_length, dtype=torch.int32).repeat(self.max_batch_size).to(
            self.device)
        block_tables_tensor = torch.arange(block_num, dtype=torch.int32).view(1, -1).to(self.device)
        slots = torch.arange(all_input_length, dtype=torch.int32).to(self.device)
        input_lengths_tensor = torch.tensor(
            [self.max_input_length] * self.max_batch_size, dtype=torch.int64
        ).to(self.device)
        # Only the last position of the flattened batch is passed as an lm-head index.
        prefill_head_indices = torch.tensor([all_input_length - 1], dtype=torch.int64).to(self.device)

        print_log(self.rank, logger.info, "---------------begin warm_up---------------")
        self.warm_up_num_blocks = self._num_blocks_for(
            self.max_input_length + self.max_output_length) * self.max_batch_size
        cache_config = CacheConfig(self.warm_up_num_blocks, self.block_size)
        self.cache_manager = CacheManager(cache_config, self.model_config)
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            is_prefill=True,
            block_tables=block_tables_tensor,
            kv_cache=self.cache_manager.kv_cache,
            slots=slots,
            input_lengths=input_lengths_tensor,
            max_seq_len=self.max_input_length,
            lm_head_indices=prefill_head_indices
        )
        self.warm_up_memory = int(
            self.max_memory * NpuHbmInfo.get_hbm_usage(self.local_rank, self.world_size, self.model.soc_info.need_nz))
        print_log(self.rank, logger.info, f'warmup_memory(GB): {self.warm_up_memory / (1024 ** 3): .2f}')
        print_log(self.rank, logger.info, "---------------end warm_up---------------")

    def infer(self, inputs, batch_size, max_output_length, ignore_eos, is_chat_model=False, **kwargs):
        """Generate completions for ``inputs``; return the decoded text, token counts and timing.

        Args:
            inputs: ``List[str]`` prompts, ``List[torch.Tensor]`` / ``List[List[int]]``
                token ids, or ``List[List[Dict]]`` conversations (requires
                ``is_chat_model=True``).
            batch_size: number of copies of the request to run when ``inputs``
                yields exactly one sequence. Ignored otherwise.
            max_output_length: per-request output limit passed to
                ``request_from_token`` and ``postprocessor.max_new_tokens``.
            ignore_eos: if true, clear ``postprocessor.eos_token_id``. This persists
                for later calls on this runner too.
            is_chat_model: build token ids via ``self.model.build_inputs``.
            **kwargs: ``truncation`` (tokenizer truncation, default False) and
                ``skip_special_tokens`` (used by ``decode_token``, default False).

        Returns:
            ``(generate_text_list, token_num_list, e2e_time)``. The first two come
            from ``decode_token``. ``e2e_time`` is wall-clock seconds; it includes
            decoding on the non-profiling path and excludes it when profiling.
        """
        print_log(self.rank, logger.info, "---------------begin inference---------------")
        if ignore_eos:
            self.model.postprocessor.eos_token_id = []
        is_truncation = kwargs.get("truncation", False)
        input_ids = self._build_model_inputs(inputs, is_chat_model, is_truncation)
        if len(input_ids) == 1:
            req_list = [request_from_token(input_ids[0], max_output_length, self.block_size, req_idx=idx)
                        for idx in range(batch_size)]
        else:
            req_list = [request_from_token(input_ids_ins, max_output_length, self.block_size, req_idx=idx)
                        for idx, input_ids_ins in enumerate(input_ids)]
        print_log(self.rank, logger.debug, f'req_list[0].input_ids: {req_list[0].input_ids}')

        self._ensure_cache_manager()

        if ENV.benchmark_enable:
            # Short throw-away run on a copy of the requests before the measured run.
            req_list_dummy = copy.deepcopy(req_list)
            self.model.postprocessor.max_new_tokens = 2
            generate_req(req_list_dummy, self.model, self.max_batch_size, self.max_prefill_tokens, self.cache_manager)

        self.model.postprocessor.max_new_tokens = max_output_length
        skip_special_tokens = kwargs.get("skip_special_tokens", False)
        if not ENV.profiling_enable:
            print_log(self.rank, logger.debug, "no profiling")
            torch.npu.synchronize()
            e2e_start = time.time()
            generate_req(req_list, self.model, self.max_batch_size, self.max_prefill_tokens, self.cache_manager)
            # Decoding is counted in the end-to-end time. Keep its result instead of
            # decoding the same requests a second time afterwards.
            generate_text_list, token_num_list = decode_token(req_list, self.tokenizer, skip_special_tokens)
            torch.npu.synchronize()
            e2e_time = time.time() - e2e_start
        else:
            e2e_time = self._generate_with_profiling(req_list)
            generate_text_list, token_num_list = decode_token(req_list, self.tokenizer, skip_special_tokens)

        if ENV.token_ids_save_enable and self.local_rank == 0:
            self._save_token_ids(req_list)
        print_log(self.rank, logger.info, "---------------end inference---------------")
        return generate_text_list, token_num_list, e2e_time

    def _ensure_cache_manager(self):
        """Create ``self.cache_manager`` sized from the memory budget unless one already exists.

        After ``warm_up`` a cache manager already exists, so this is a no-op.

        Raises:
            RuntimeError: if the memory budget leaves room for no cache block at all.
        """
        if self.cache_manager:
            return
        self._resolve_max_prefill_tokens()
        cache_block_size = self.block_size * self.model.num_kv_heads * self.model.head_size
        dtype_size = CacheManager.get_dtype_size(self.dtype)
        # Bytes per block across all layers; the factor 2 presumably covers the K and V caches.
        total_cache_size = self.model.num_layers * cache_block_size * 2 * dtype_size
        max_memory = ENV.memory_fraction * self.max_memory \
            if not ENV.max_memory_gb else int(ENV.max_memory_gb) * (1 << 30)
        free_memory = max_memory - ENV.reserved_memory_gb * (1 << 30) - (
            self.warm_up_memory if self.warm_up_memory != 0 else self.init_memory)
        print_log(self.rank, logger.info,
                  f"infer max_memory(GB): {max_memory / (1024 ** 3): .2f}, "
                  f"warm_up_memory(GB): {self.warm_up_memory / (1024 ** 3): .2f}, "
                  f"free_memory(GB): {free_memory / (1024 ** 3): .2f}")
        num_blocks = int(free_memory // total_cache_size)
        print_log(self.rank, logger.info, f"num_blocks: {num_blocks}, free_memory: {free_memory}")
        if num_blocks <= 0:
            raise RuntimeError(
                f"No memory left for the KV cache (free_memory={free_memory} bytes, "
                f"bytes per block={total_cache_size}); check the memory fraction, "
                f"max/reserved memory settings and max_batch_size.")
        cache_config = CacheConfig(num_blocks, self.block_size)
        self.cache_manager = CacheManager(cache_config, self.model_config)

    def _generate_with_profiling(self, req_list):
        """Run ``generate_req`` under ``torch_npu.profiler`` and return the elapsed seconds.

        Traces are written to ``ENV.profiling_filepath``, which is created if missing.

        Raises:
            NotImplementedError: if ``"Level" + ENV.profiling_level`` is not an
                attribute of ``torch_npu.profiler.ProfilerLevel``.
        """
        print_log(self.rank, logger.debug, "enter profiling")
        profiling_path = ENV.profiling_filepath
        os.makedirs(profiling_path, exist_ok=True)
        profiler_level = torch_npu.profiler.ProfilerLevel
        target_level = "Level" + ENV.profiling_level
        if not hasattr(profiler_level, target_level):
            raise NotImplementedError(f"target_level: {target_level} is not implemented"
                                      f" in torch_npu.profiler.ProfilerLevel")
        actual_profiler_level = getattr(profiler_level, target_level)
        torch.npu.synchronize()
        e2e_start = time.time()
        experimental_config = torch_npu.profiler._ExperimentalConfig(
            aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
            profiler_level=actual_profiler_level,
            l2_cache=False,
            data_simplification=False
        )
        with torch_npu.profiler.profile(
                activities=[
                    torch_npu.profiler.ProfilerActivity.CPU,
                    torch_npu.profiler.ProfilerActivity.NPU
                ],
                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profiling_path),
                record_shapes=True,
                profile_memory=True,
                with_stack=False,
                with_flops=False,
                with_modules=False,
                experimental_config=experimental_config):
            generate_req(req_list, self.model, self.max_batch_size, self.max_prefill_tokens, self.cache_manager)
        torch.npu.synchronize()
        return time.time() - e2e_start

    @staticmethod
    def _save_token_ids(req_list):
        """Write each request's ``input_ids`` (``torch.save``) and output token ids (space-separated text).

        Files go to ``ENV.token_ids_save_folder``, which is created if missing.
        """
        save_folder = ENV.token_ids_save_folder
        os.makedirs(save_folder, exist_ok=True)
        for idx, req in enumerate(req_list):
            torch.save(req.input_ids.cpu(), os.path.join(save_folder, f"input_ids_{idx}.pth"))
            output_path = os.path.join(save_folder, f"output_ids_{idx}.txt")
            with safe_open(output_path, "w", encoding='utf-8') as f:
                f.write(' '.join(map(str, req.out_token_list)))

    def _build_model_inputs(self, inputs, is_chat_model, is_truncation=False):
        """Convert ``inputs`` into a list of token-id sequences.

        Args:
            inputs: non-empty ``List[str]``, ``List[torch.Tensor]``,
                ``List[List[int]]`` or ``List[List[Dict]]``.
            is_chat_model: route texts or conversations through ``self.model.build_inputs``.
            is_truncation: forwarded as ``truncation=`` to the tokenizer for plain texts.

        Returns:
            The token-id sequences. Token-id inputs are returned as given.

        Raises:
            ValueError: if ``inputs`` has none of the accepted shapes or is empty,
                or if conversations are given without ``is_chat_model``.
        """
        input_texts, input_ids, input_conversations = [], [], []
        if isinstance(inputs, list) and inputs:
            if isinstance(inputs[0], str):
                input_texts = inputs
            elif isinstance(inputs[0], torch.Tensor):
                input_ids = inputs
            elif isinstance(inputs[0], list) and inputs[0]:
                if isinstance(inputs[0][0], int):
                    input_ids = inputs
                elif isinstance(inputs[0][0], dict):
                    input_conversations = inputs
        if not (input_texts or input_ids or input_conversations):
            raise ValueError(f"The inputs of `PARunner.infer` must be as List[str], List[torch.Tensor], List[List[int]]"
                             f" or List[List[Dict]]. Now the inputs ({inputs}) is not acceptable or is empty.")
        if is_chat_model:
            if input_conversations:
                input_ids = self.model.build_inputs(input_conversations)
            elif input_texts:
                input_conversations = [[{"role": "user", "content": t}] for t in input_texts]
                input_ids = self.model.build_inputs(input_conversations)
            else:
                print_log(self.rank, logger.warning, "Neither conversations nor input_texts exist, "
                                                     "'chat' parameter is not effective.")
        elif input_conversations:
            # Without this check an empty id list was returned and failed later with
            # an opaque IndexError on ``req_list[0]``.
            raise ValueError("Conversation inputs (List[List[Dict]]) require `is_chat_model=True`.")
        elif input_texts:
            input_ids = [self.tokenizer([text], return_tensors="pt", truncation=is_truncation)["input_ids"].flatten()
                         for text in input_texts]
        return input_ids
def cmd_bool(cmd_arg):
    """Map the exact command-line strings ``"True"`` / ``"False"`` to booleans.

    Raises:
        ValueError: for any other value, including other capitalisations.
    """
    for literal, flag in (("True", True), ("False", False)):
        if cmd_arg == literal:
            return flag
    raise ValueError(f"{cmd_arg} should be a boolean")
def parse_ids(list_str):
    """Split a comma-separated string such as ``"1,2,3"`` into a list of ints."""
    pieces = list_str.split(',')
    return list(map(int, pieces))
def parse_arguments():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', help="model and tokenizer path")
    parser.add_argument(
        '--input_texts',
        type=str,
        nargs='+',
        default=["What's deep learning?"],
        help="one or more plain-text prompts (used when --input_ids is not given)")
    parser.add_argument(
        '--input_ids',
        type=parse_ids,
        nargs='+',
        default=None,
        help="one or more comma-separated token-id lists, e.g. '1,2,3 4,5'; takes precedence over --input_texts")
    parser.add_argument(
        '--input_file',
        type=str,
        help='JSONL file with one chat conversation (a JSON list of messages) per line. '
             'Only read together with --is_chat_model.',
        default=None)
    parser.add_argument("--max_batch_size", type=int, default=1)
    parser.add_argument('--max_input_length', type=int, default=1024)
    parser.add_argument('--max_output_length', type=int, default=20)
    parser.add_argument('--max_position_embeddings', type=int, default=None)
    parser.add_argument('--max_prefill_tokens', type=int, default=-1,
                        help="-1 derives the value from max_batch_size and the input/output lengths")
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument('--chat_template', type=str, default=None,
                        help="chat template file path, or the template text itself")
    parser.add_argument('--ignore_eos', action='store_true',
                        help="clear the EOS token ids so generation is not stopped by EOS")
    parser.add_argument('--is_chat_model', action='store_true',
                        help="build model inputs through the model's chat formatting")
    # store_false: the value defaults to True and passing the flag sets it to False.
    parser.add_argument('--is_flash_model', action='store_false',
                        help="defaults to True; passing this flag sets it to False")
    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point: parse the CLI, build a PARunner, warm it up, run one
    # inference pass and log each question/answer pair.
    args = parse_arguments()
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    input_dict = {
        'rank': rank,
        'world_size': world_size,
        'local_rank': local_rank,
        **vars(args)
    }

    if args.input_ids:
        infer_inputs = args.input_ids
    else:
        infer_inputs = args.input_texts
    if args.is_chat_model and args.input_file:
        # One JSON conversation per line. Blank lines are skipped so a stray empty
        # line does not abort the run with a JSONDecodeError.
        with open(args.input_file, 'r', encoding='utf-8') as file:
            infer_inputs = [json.loads(line) for line in file if line.strip()]
    elif args.input_file:
        print_log(rank, logger.warning, "--input_file is only read together with --is_chat_model; ignoring it.")

    pa_runner = PARunner(**input_dict)
    print_log(rank, logger.info, f'pa_runner: {pa_runner}')
    pa_runner.warm_up()

    infer_params = {
        "inputs": infer_inputs,
        "batch_size": args.max_batch_size,
        "max_output_length": args.max_output_length,
        "ignore_eos": args.ignore_eos,
        "is_chat_model": args.is_chat_model
    }
    generate_texts, token_nums, _ = pa_runner.infer(**infer_params)

    # A single input may be replicated max_batch_size times, so there can be
    # more answers than questions.
    length = len(infer_inputs)
    for i, generate_text in enumerate(generate_texts):
        if i < length:
            print_log(rank, logger.info, f'Question[{i}]: {infer_inputs[i]}')
        print_log(rank, logger.info, f'Answer[{i}]: {generate_text}')
        print_log(rank, logger.info, f'Generate[{i}] token num: {token_nums[i]}')