# LLaMA-Factory-310P3/mindie/examples/server/request.py

# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import math
from typing import List
from dataclasses import dataclass
import torch

class Request:
    req_id: int
    input_ids: torch.Tensor
    input_length: int
    need_blocks: int
    need_slots: int
    block_tables: None | torch.Tensor
    slot_tables: None | torch.Tensor
    out_token_list: List[int]

    def __init__(self, max_out_length: int, block_size: int, req_id: int, input_ids: torch.Tensor):
        self.req_id = req_id
        self.input_ids = input_ids.flatten()
        self.input_length = self.input_ids.numel()
        try:
            self.need_blocks = math.ceil((self.input_length + max_out_length) / block_size)
        except ZeroDivisionError as e:
            raise ZeroDivisionError("block_size must be a positive integer") from e
        self.need_slots = self.need_blocks * block_size
        self.block_tables: None | torch.Tensor = None
        self.slot_tables: None | torch.Tensor = None
        self.out_token_list = []
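
# A minimal sketch of the block accounting above, with illustrative values
# (the sizes are assumptions, not taken from the original file): a 10-token
# prompt with max_out_length=20 and block_size=128 needs
# ceil((10 + 20) / 128) = 1 block, i.e. 128 KV-cache slots reserved.
#
#   req = Request(max_out_length=20, block_size=128, req_id=0,
#                 input_ids=torch.arange(10))
#   assert req.need_blocks == 1 and req.need_slots == 128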

class MultiModalRequest:
    def __init__(self, max_out_length: int, block_size: int, req_id: int, input_ids: torch.Tensor):
        self.req_id = req_id
        self.input_ids = input_ids
        self.input_length = self.input_ids.shape[0]
        try:
            self.need_blocks = math.ceil((self.input_length + max_out_length) / block_size)
        except ZeroDivisionError as e:
            raise ZeroDivisionError("block_size must be a positive integer") from e
        self.need_slots = self.need_blocks * block_size
        self.block_tables = None
        self.slot_tables = None
        self.out_token_list = []
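
# Hedged sketch of the difference from Request: here input_ids holds prefill
# embeddings, so the sequence length is the first tensor dimension rather than
# numel(). The shape below is an assumption for illustration only.
#
#   embeds = torch.zeros(32, 4096)  # 32 positions, hidden size 4096 (assumed)
#   req = MultiModalRequest(max_out_length=20, block_size=128, req_id=0,
#                           input_ids=embeds)
#   assert req.input_length == 32 and req.need_blocks == math.ceil(52 / 128)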

def request_from_token(input_ids, max_out_length, block_size, req_idx=0) -> Request:
    # Accepts either a Python list of ids or a pre-built tensor; as_tensor
    # avoids the extra copy (and the copy warning) when a tensor is passed in.
    input_ids = torch.as_tensor(input_ids, dtype=torch.int64)
    request = Request(max_out_length, block_size, req_idx, input_ids)
    return request

def request_from_text(text, tokenizer, max_out_length, block_size, req_idx=0) -> Request:
    input_ids = tokenizer([text], return_tensors="pt")["input_ids"].flatten()
    request = request_from_token(input_ids, max_out_length, block_size, req_idx)
    return request
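
# Hedged usage sketch: assumes a Hugging Face tokenizer; the model path is a
# placeholder, not something defined in this file.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("/path/to/model")
#   req = request_from_text("hello world", tokenizer,
#                           max_out_length=20, block_size=128)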

@dataclass
class MultiModalRequestParams:
    text: str
    image: str
    max_out_length: int
    block_size: int
    req_idx: int

def request_from_text_and_image(processor, model, multimodalparams):
    text = multimodalparams.text
    image = multimodalparams.image
    max_out_length = multimodalparams.max_out_length
    block_size = multimodalparams.block_size
    req_idx = multimodalparams.req_idx
    inputs_embeds = model.model.prepare_prefill_token(text, image, processor)
    request = MultiModalRequest(max_out_length, block_size, req_idx, inputs_embeds)
    return request
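
# Hedged usage sketch: `processor` and `model` are assumed to come from the
# caller's model-loading code (not part of this file), with
# model.model.prepare_prefill_token returning the prefill embedding tensor.
#
#   params = MultiModalRequestParams(text="Describe the image.",
#                                    image="/path/to/image.jpg",
#                                    max_out_length=20, block_size=128,
#                                    req_idx=0)
#   req = request_from_text_and_image(processor, model, params)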

def request_from_token_file(input_path, max_out_length, block_size) -> List[Request]:
    # Each line holds one request as comma-separated token ids; blank lines
    # are skipped.
    req_list = []
    req_idx = 0
    with open(input_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            input_ids = [int(token_str) for token_str in line.split(',')]
            req_list.append(request_from_token(input_ids, max_out_length, block_size, req_idx))
            req_idx += 1
    return req_list
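
# Expected file format, one request per line as comma-separated token ids
# (the ids below are illustrative):
#
#   1,15043,29871
#   1,3579,123,456
#
# would yield two Requests with input_length 3 and 4.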

def request_from_text_file(input_path, tokenizer, max_out_length, block_size) -> List[Request]:
    # Each line holds one prompt; blank lines are skipped and the trailing
    # newline is stripped before tokenization. Passing req_idx through (the
    # original hard-coded req_idx=0) gives every request a distinct id.
    req_list = []
    req_idx = 0
    with open(input_path, 'r') as f:
        for line in f:
            text = line.rstrip('\n')
            if not text:
                continue
            req_list.append(request_from_text(text, tokenizer, max_out_length, block_size, req_idx=req_idx))
            req_idx += 1
    return req_list
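
if __name__ == "__main__":
    # Minimal smoke-test sketch; "prompts.txt" and the sizes are placeholders,
    # not from the original file. Builds requests from a token file and prints
    # their block accounting.
    demo_requests = request_from_token_file("prompts.txt",
                                            max_out_length=20, block_size=128)
    for req in demo_requests:
        print(req.req_id, req.input_length, req.need_blocks, req.need_slots)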