LLaMA-Factory-310P3/mindie/examples/server/cache.py

# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import os

import torch

from atb_llm.utils.log import logger


class CacheConfig:
    def __init__(self, num_blocks=1024, block_size=128):
        self.num_blocks = int(os.getenv("NUM_BLOCKS", f'{num_blocks}'))
        self.block_size = int(os.getenv("BLOCK_SIZE", f'{block_size}'))
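
# Usage sketch (illustrative, not part of the original module): both fields can
# be overridden through environment variables before the config is constructed:
#   os.environ["NUM_BLOCKS"] = "2048"
#   os.environ["BLOCK_SIZE"] = "128"
#   cache_config = CacheConfig()  # -> num_blocks=2048, block_size=128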


class ModelConfig:
    def __init__(self, num_heads, num_kv_heads, head_size, num_layers, device, dtype, soc_info, kv_quant):
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.num_layers = num_layers
        self.device = device
        self.dtype = dtype
        self.soc_info = soc_info
        self.kv_quant = kv_quant

    def __repr__(self):
        return (
            "ModelConfig("
            + f"num_heads={self.num_heads}, "
            + f"num_kv_heads={self.num_kv_heads}, "
            + f"head_size={self.head_size}, "
            + f"num_layers={self.num_layers}, "
            + f"device={self.device}, "
            + f"dtype={self.dtype}, "
            + f"soc_info={self.soc_info}, "
            + f"kv_quant={self.kv_quant})"
        )


class CacheManager:
    def __init__(self, cache_config, model_config):
        self.block_size = cache_config.block_size
        self.num_blocks = cache_config.num_blocks
        self.num_heads = model_config.num_kv_heads
        self.head_size = model_config.head_size
        self.num_layers = model_config.num_layers
        self.device = model_config.device
        # KV-quantized models store the cache in int8; otherwise use the model dtype.
        self.dtype = torch.int8 if model_config.kv_quant is not None else model_config.dtype
        self.soc_info = model_config.soc_info
        # Total cache size: blocks * tokens per block * kv heads * head size * layers * 2 (K and V) * bytes per element.
        mem_need = self.num_blocks * self.block_size * self.num_heads * self.head_size * self.num_layers * 2 * \
            self.get_dtype_size(self.dtype) / 1024 / 1024 / 1024
        logger.info(f"kv cache will allocate {mem_need}GB memory")
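        # Worked example with illustrative values (not from the original code):
        # 1024 blocks * 128 tokens/block * 8 kv heads * 128 head_size * 40 layers
        # * 2 (K and V) * 2 bytes (float16) / 1024**3 = 20 GB.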
        if self.soc_info.need_nz:
            # NZ layout: the kv hidden dimension is packed into 16-element
            # fragments (hence the // 16 and the trailing 16 in the shape).
            self.kv_cache = [
                (
                    torch.empty(
                        (self.num_blocks, self.num_heads * self.head_size // 16, self.block_size, 16),
                        dtype=self.dtype,
                        device=self.device,
                    ),
                    torch.empty(
                        (self.num_blocks, self.num_heads * self.head_size // 16, self.block_size, 16),
                        dtype=self.dtype,
                        device=self.device,
                    ),
                )
                for _ in range(self.num_layers)
            ]
        else:
            # Default layout: one (K, V) pair per layer, each shaped
            # [num_blocks, block_size, num_kv_heads, head_size].
            self.kv_cache = [
                (
                    torch.empty(
                        (self.num_blocks, self.block_size, self.num_heads, self.head_size),
                        dtype=self.dtype,
                        device=self.device,
                    ),
                    torch.empty(
                        (self.num_blocks, self.block_size, self.num_heads, self.head_size),
                        dtype=self.dtype,
                        device=self.device,
                    ),
                )
                for _ in range(self.num_layers)
            ]

        # Optionally shuffle the logical-to-physical block mapping; contrary_block_map
        # is the inverse permutation, used when freeing blocks.
        random_block_allocate = os.getenv("RANDOM_BLOCK_ALLOCATE", '0') == '1'
        if random_block_allocate:
            self.block_map = torch.randperm(self.num_blocks, dtype=torch.long)
            self.contrary_block_map = torch.zeros(self.num_blocks, dtype=torch.long)
            for i in range(self.num_blocks):
                self.contrary_block_map[self.block_map[i]] = i
        else:
            self.block_map = torch.arange(self.num_blocks, dtype=torch.long)
            self.contrary_block_map = torch.arange(self.num_blocks, dtype=torch.long)
        # 1 marks a free block; total_slots maps each block to its token slots.
        self.free_block_mask = torch.ones(self.num_blocks, dtype=torch.long)
        self.total_slots = torch.arange(self.num_blocks * self.block_size, dtype=torch.long)
        self.total_slots = self.total_slots.view(self.num_blocks, self.block_size)

    @staticmethod
    def get_dtype_size(dtype):
        # Bytes per element; unknown dtypes fall back to 2 bytes.
        dtype_size_map = {torch.float16: 2, torch.float32: 4, torch.bfloat16: 2, torch.int8: 1}
        return dtype_size_map.get(dtype, 2)

    def allocate(self, batch):
        total_need_blocks = 0
        max_need_blocks = 0
        for req in batch.req_list:
            if req.block_tables:
                logger.error(f"req_id: {req.req_id} block has been allocated")
                raise AssertionError
            total_need_blocks += req.need_blocks
            max_need_blocks = max(max_need_blocks, req.need_blocks)

        free_block_indices = self.free_block_mask.nonzero().flatten()
        if free_block_indices.numel() < total_need_blocks:
            logger.error(f"Out of available cache blocks: asked {total_need_blocks}, "
                         f"only {free_block_indices.numel()} free blocks")
            raise AssertionError

        allocate_block_indices = free_block_indices[:total_need_blocks]
        allocate_blocks = self.block_map[allocate_block_indices]

        block_offset = 0
        block_tables_list = []
        slot_tables_list = []
        for req in batch.req_list:
            req.block_tables = allocate_blocks[block_offset:block_offset + req.need_blocks]
            req.slot_tables = self.total_slots[req.block_tables].flatten()
            block_tables = req.block_tables
            # Pad shorter block tables with zeros so all rows can be stacked into one tensor.
            if req.need_blocks < max_need_blocks:
                block_tables = torch.concat(
                    [block_tables, torch.zeros(max_need_blocks - req.need_blocks, dtype=torch.long)], dim=0)
            block_tables_list.append(block_tables.view(1, -1))
            slot_tables_list.append(req.slot_tables)
            block_offset += req.need_blocks

        batch.batch_block_tables = torch.concat(block_tables_list, dim=0)
        batch.batch_slots_tables = torch.concat(slot_tables_list, dim=0)
        # Mark the handed-out blocks as used.
        self.free_block_mask[allocate_block_indices] = 0

    def free(self, req):
        # Return a finished request's blocks to the free pool.
        if req.block_tables is not None:
            block_indices = self.contrary_block_map[req.block_tables]
            self.free_block_mask[block_indices] = 1

    def get_free_block_num(self):
        free_block_indices = self.free_block_mask.nonzero()
        return len(free_block_indices)
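

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module).
# The soc_info, batch and req objects are stand-ins for whatever the serving
# code passes in; they only need the attributes read above (need_nz, req_list,
# req_id, need_blocks, block_tables, slot_tables).
#
#   model_config = ModelConfig(num_heads=32, num_kv_heads=8, head_size=128,
#                              num_layers=40, device="npu", dtype=torch.float16,
#                              soc_info=soc_info, kv_quant=None)
#   cache_manager = CacheManager(CacheConfig(), model_config)
#   cache_manager.allocate(batch)   # fills batch.batch_block_tables / batch_slots_tables
#   print(cache_manager.get_free_block_num())
#   cache_manager.free(req)         # returns the request's blocks to the pool
# ---------------------------------------------------------------------------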