# LLaMA-Factory-310P3/mindie/examples/server/batch.py
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
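"""Batch assembly for the MindIE example server.

`Batch` packs a list of `Request` objects into the flat tensors consumed by
the model forward pass, merges running batches (`concatenate`), and evicts
finished requests while re-packing the cache tables (`filter`).
"""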
from typing import List, Optional

import torch

from atb_llm.utils.log import logger

from .request import Request

class Batch:
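    """One flattened batch of generation requests.

    Per-token tensors (input ids, position ids, slot indices) hold every
    request's tokens concatenated along dim 0; `context_length` and
    `lm_head_indices` keep one entry per request so the packed layout can
    be indexed back per request.
    """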
    req_ids: List[int]
    req_list: List[Request]
    batch_num: int
    cu_seqlen_prefill: torch.Tensor
    batch_input_ids: torch.Tensor
    batch_position_ids: torch.Tensor
    batch_block_tables: Optional[torch.Tensor]
    batch_slots_tables: Optional[torch.Tensor]
    batch_slot_indices: torch.Tensor
    context_length: torch.Tensor
    max_s: int
    lm_head_indices: torch.Tensor

    def __init__(self, req_list: List[Request]):
        self.req_list = req_list
        self.batch_num = len(req_list)
        self.req_ids = [req.req_id for req in req_list]
        input_ids_list = []
        position_ids_list = []
        slot_indices_list = []
        context_length_list = []
        self.max_s = 0
        slot_offset = 0
        for req in self.req_list:
            context_length = req.input_ids.size(0)
            input_ids_list.append(req.input_ids)
            position_ids = torch.arange(context_length, dtype=torch.long)
            position_ids_list.append(position_ids)
            # Each request owns a contiguous range of KV-cache slots, so its
            # prompt tokens map to the first `context_length` slots of that range.
            slot_indices = position_ids + slot_offset
            slot_indices_list.append(slot_indices)
            context_length_list.append(context_length)
            self.max_s = max(self.max_s, context_length)
            slot_offset += req.need_slots
        self.cu_seqlen_prefill = torch.tensor([1])
        self.batch_input_ids = torch.concat(input_ids_list, dim=0)
        self.batch_position_ids = torch.concat(position_ids_list, dim=0)
        # Block and slot tables are assigned later, once the cache manager
        # has allocated blocks for this batch.
        self.batch_block_tables: Optional[torch.Tensor] = None
        self.batch_slots_tables: Optional[torch.Tensor] = None
        self.batch_slot_indices = torch.concat(slot_indices_list, dim=0)
        self.context_length = torch.tensor(context_length_list, dtype=torch.int64)
        # Position of each request's last prompt token in the packed stream;
        # only these logits are needed to sample the next token.
        self.lm_head_indices = torch.cumsum(self.context_length, dim=0) - 1

    @classmethod
    def concatenate(cls, batches: List["Batch"]):
        """Merge all batches into `batches[0]` in place and drop the rest."""
        req_ids = []
        req_list = []
        batch_num = 0
        input_ids_list = [batch.batch_input_ids for batch in batches]
        position_ids_list = [batch.batch_position_ids for batch in batches]
        block_tables_list = []
        slots_tables_list = [batch.batch_slots_tables for batch in batches]
        slot_indices_list = []
        context_length_list = [batch.context_length for batch in batches]
        max_s = 0
        max_block = 0
        for batch in batches:
            req_ids.extend(batch.req_ids)
            req_list.extend(batch.req_list)
            batch_num += batch.batch_num
            max_s = max(max_s, batch.max_s)
            max_block = max(max_block, batch.batch_block_tables.size(1))
        slot_offset = 0
        for batch in batches:
            # Right-pad narrower block tables with zeros so every row is
            # `max_block` wide and the tables can be concatenated.
            cur_block = batch.batch_block_tables.size(1)
            if cur_block < max_block:
                zero = torch.zeros(batch.batch_num, max_block - cur_block, dtype=torch.long)
                batch.batch_block_tables = torch.concat([batch.batch_block_tables, zero], dim=-1)
            block_tables_list.append(batch.batch_block_tables)
            # Slot indices address the merged slots table, so every batch
            # after the first is shifted by the slots that precede it.
            slot_indices_list.append(batch.batch_slot_indices + slot_offset)
            slot_offset += batch.batch_slots_tables.size(0)
        batches[0].req_ids = req_ids
        batches[0].req_list = req_list
        batches[0].batch_num = batch_num
        batches[0].batch_input_ids = torch.concat(input_ids_list, dim=0)
        batches[0].batch_position_ids = torch.concat(position_ids_list, dim=0)
        batches[0].batch_block_tables = torch.concat(block_tables_list, dim=0)
        batches[0].batch_slots_tables = torch.concat(slots_tables_list, dim=0)
        batches[0].batch_slot_indices = torch.concat(slot_indices_list, dim=0)
        batches[0].context_length = torch.concat(context_length_list, dim=0)
        batches[0].max_s = max_s
        # Drop the merged batches so only batches[0] remains alive.
        while len(batches) > 1:
            del batches[1]

    def filter(self, postprocessor, cache_manager):
        """Evict finished requests in place; return how many finished."""
        if self.batch_num == 0:
            logger.error("batch.batch_num is 0")
            raise AssertionError("batch.batch_num is 0")
        finish_num = 0
        finish_list = []
        for i, req in enumerate(self.req_list):
            # A request finishes when it hits a stop condition or exhausts
            # its new-token budget; its cache blocks return to the pool.
            if postprocessor.stopping_criteria(req.out_token_list) or \
                    len(req.out_token_list) >= postprocessor.max_new_tokens:
                cache_manager.free(req)
                finish_num += 1
                finish_list.append(i)
        if finish_num == 0:
            return 0
        # Gather the rows of the surviving requests. This indexing assumes
        # the decode phase, where each tensor holds one row per request.
        batch_mask = torch.ones(self.batch_num, dtype=torch.int64)
        batch_mask[finish_list] = 0
        remain_batch = batch_mask.nonzero().flatten()
        self.batch_num -= finish_num
        if self.batch_num == 0:
            return finish_num
        self.batch_input_ids = self.batch_input_ids[remain_batch]
        self.batch_position_ids = self.batch_position_ids[remain_batch]
        self.batch_block_tables = self.batch_block_tables[remain_batch]
        context_length = self.context_length[remain_batch]
        self.max_s = int(context_length.max())
        req_ids = []
        req_list = []
        slots_tables_list = []
        slot_indices_list = []
        slot_offset = 0
        for i, req in enumerate(self.req_list):
            if i in finish_list:
                continue
            req_ids.append(req.req_id)
            req_list.append(req)
            slots_tables_list.append(req.slot_tables)
            # Re-base each survivor's slot index onto the re-packed slots table.
            slot_indices_list.append(int(self.context_length[i]) - 1 + slot_offset)
            slot_offset += req.need_slots
        self.req_ids = req_ids
        self.req_list = req_list
        self.batch_slots_tables = torch.concat(slots_tables_list, dim=0)
        self.batch_slot_indices = torch.tensor(slot_indices_list, dtype=torch.long)
        self.context_length = context_length
        return finish_num
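

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the serving path
    # (run via `python -m` so the relative import above resolves).
    # The real Request carries tokenizer output and cache bookkeeping; this
    # hypothetical stand-in fakes only the attributes Batch.__init__ reads:
    # req_id, input_ids, and need_slots.
    class _FakeRequest:
        def __init__(self, req_id: int, length: int, need_slots: int):
            self.req_id = req_id
            self.input_ids = torch.arange(length, dtype=torch.long)
            self.need_slots = need_slots

    demo = Batch([_FakeRequest(0, 4, 8), _FakeRequest(1, 6, 8)])
    # Prompts of lengths 4 and 6 flatten into one 10-token stream; the second
    # request's slot indices start at its owner's slot offset (8).
    assert demo.batch_input_ids.size(0) == 10
    assert demo.max_s == 6
    assert demo.batch_slot_indices.tolist() == [0, 1, 2, 3, 8, 9, 10, 11, 12, 13]
    # lm_head_indices marks each request's last prompt token: [3, 9].
    assert demo.lm_head_indices.tolist() == [3, 9]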