sugar/inference.py

86 lines
2.9 KiB
Python
Raw Permalink Normal View History

2024-11-12 10:20:46 +08:00
from typing import List, Dict, Any
import json
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
# Constants
MAX_NEW_TOKENS = 256  # upper bound on tokens generated per prompt
BATCH_SIZE = 16  # number of prompts sent to model.generate() at once
# Prefer the GPU; fall back to CPU so the script still runs (slowly) without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16  # dtype the model weights are loaded in
def load_model_and_tokenizer(checkpoint_path: str):
    """Load the causal LM and its tokenizer for batched inference.

    The model is loaded in ``DTYPE``, moved to ``DEVICE`` and put in eval
    mode. The tokenizer is configured to pad with its unk token on the left
    and not to append an EOS token to prompts.

    Args:
        checkpoint_path: Directory (or hub id) holding the model and tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_path,
        # Load straight into DTYPE instead of materialising the weights in the
        # default dtype first and casting afterwards.
        torch_dtype=DTYPE,
        trust_remote_code=True,
    ).to(DEVICE)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    # Reuse the unk token for padding (presumably the tokenizer defines no
    # dedicated pad token — confirm for the checkpoint in use).
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.unk_token_id
    # Decoder-only models continue from the last position of each row, so a
    # padded batch must be padded on the left for generation to be correct.
    tokenizer.padding_side = "left"
    # Prompts must end exactly where the model should continue, not with EOS.
    tokenizer.add_eos_token = False
    return model, tokenizer
def format_prompt(sample: Dict[str, Any]) -> str:
    """Build the generation prompt for a single sample.

    Three prompt shapes are produced:

    * samples with ``choices``: chat format, the question followed by the
      concatenated choices, ending at the ``<AI>`` marker;
    * questions starting with "Write" (after stripping whitespace): chat
      format that pre-seeds the answer with ``" def"``;
    * everything else: the raw question followed by a newline and a space,
      treated as a plain completion prompt.
    """
    question = sample['question']
    if 'choices' not in sample:
        if question.strip().startswith("Write"):
            return f"<用户>{question}<AI> def"
        return f"{question}\n "
    options = ''.join(sample['choices'])
    return f"<用户>{question}\n{options}<AI>"
def process_response(response_text: str, question: str) -> str:
    """Extract the model's answer from a decoded generation.

    The decoded text starts with the prompt. For chat-style prompts (those
    containing the ``<AI>`` marker) the answer is the text between the first
    ``<AI>`` marker and the next one, if any. Plain completion prompts carry
    no marker; for those the echoed ``question`` prefix is removed instead of
    raising ``IndexError`` as before.

    Args:
        response_text: Decoded prompt + continuation.
        question: The sample's original question text.

    Returns:
        The answer text with surrounding whitespace stripped.
    """
    segments = response_text.split('<AI>')
    if len(segments) > 1:
        # Anything after a second <AI> is presumably an extra turn the model
        # generated on its own; keep only the first answer.
        return segments[1].strip()
    # No marker: a plain completion prompt. Drop the echoed question if the
    # decoded text starts with it verbatim; otherwise keep the full text.
    if question and response_text.startswith(question):
        response_text = response_text[len(question):]
    return response_text.strip()
def generate_responses(model, tokenizer, text_inputs: List[str]) -> List[str]:
    """Generate completions for a batch of prompts.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``; it must define a pad token
            because the batch is padded to a common length.
        text_inputs: Prompt strings.

    Returns:
        One decoded string per prompt, with special tokens removed. For a
        decoder-only model ``generate`` returns the prompt tokens followed by
        the new tokens, so each string still contains its prompt.
    """
    if not text_inputs:
        return []
    encoded = tokenizer(text_inputs, return_tensors="pt", padding=True).to(model.device)
    # Use the tokenizer's own pad id so padding is consistent with how the
    # batch was built; 0 (the previously hard-coded value) is the fallback.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    generated = model.generate(
        input_ids=encoded['input_ids'],
        # Pass the mask explicitly so pad positions are ignored rather than
        # relying on generate() to infer them from pad_token_id.
        attention_mask=encoded.get('attention_mask'),
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=pad_id,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
def execute_model(checkpoint_dir: str, src_file: str, dest_file: str) -> None:
    """Run batched generation over a JSONL file and write annotated records.

    Each non-blank line of ``src_file`` must be a JSON object with a
    ``question`` key (and optionally ``choices``). The extracted answer is
    stored under ``raw_outputs`` as a one-element list, and all records are
    written to ``dest_file`` in input order, one JSON object per line.

    Args:
        checkpoint_dir: Model/tokenizer checkpoint to load.
        src_file: Input JSONL path.
        dest_file: Output JSONL path (overwritten).
    """
    model, tokenizer = load_model_and_tokenizer(checkpoint_dir)

    # Read as UTF-8 explicitly: the data may be non-ASCII (the output is
    # written with ensure_ascii=False) and the platform default may differ.
    with open(src_file, 'r', encoding='utf-8') as fin:
        # Skip blank lines, which json.loads would reject.
        samples = [json.loads(line) for line in fin if line.strip()]

    for start_idx in tqdm(range(0, len(samples), BATCH_SIZE)):
        batch = samples[start_idx:start_idx + BATCH_SIZE]
        prompts = [format_prompt(sample) for sample in batch]
        outputs = generate_responses(model, tokenizer, prompts)
        # `batch` holds the same dict objects as `samples`, so updating them in
        # place annotates the records that are written out below.
        for output, sample in zip(outputs, batch):
            sample['raw_outputs'] = [process_response(output, sample['question'])]

    with open(dest_file, 'w', encoding='utf-8') as fout:
        for sample in samples:
            json.dump(sample, fout, ensure_ascii=False)
            fout.write('\n')
if __name__ == '__main__':
    import argparse

    # Defaults reproduce the previously hard-coded run; the flags allow other
    # checkpoints or test sets without editing the file.
    parser = argparse.ArgumentParser(
        description="Run batched generation over a JSONL test set.")
    parser.add_argument('--model-path', default='checkpoints/py/checkpoint-7000',
                        help="model/tokenizer checkpoint directory")
    parser.add_argument('--input-file', default='test_set/代码生成_round4.jsonl',
                        help="input JSONL file")
    parser.add_argument('--output-file', default='test_set/代码生成_round4_result.jsonl',
                        help="output JSONL file (overwritten)")
    args = parser.parse_args()
    execute_model(args.model_path, args.input_file, args.output_file)