punk02/inference.py

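"""Batched inference over a JSONL test set.

Each input line is a JSON object with a "question" field (and optionally a
"choices" list). Prompts are built with the <用户>/<AI> chat markers, answers
are generated with a causal LM, and each record is written back out with a
"raw_outputs" field added.
"""
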
from typing import List, Dict, Any
import json
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_responses(model: AutoModelForCausalLM,
                       tokenizer: AutoTokenizer,
                       texts: List[str]) -> List[str]:
    """Generate model responses for a batch of prompts."""
    inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
    # Pass the attention mask so padded positions are ignored during generation.
    outputs = model.generate(inputs['input_ids'],
                             attention_mask=inputs['attention_mask'],
                             max_new_tokens=256,
                             pad_token_id=tokenizer.pad_token_id)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


def format_prompt(sample: Dict[str, Any]) -> str:
    """Format the prompt text for an input sample."""
    # Multiple-choice questions: append the options after the question.
    if 'choices' in sample:
        return f"<用户>{sample['question']}\n{''.join(sample['choices'])}<AI>"
    # Code-generation questions: nudge the model to start a function definition.
    if sample['question'].strip().startswith("Write"):
        return f"<用户>{sample['question']}<AI> def"
    # Everything else: plain continuation without the chat markers.
    return f"{sample['question']}\n "


def process_response(response: str, question: str) -> str:
    """Extract the generated answer from the decoded model output."""
    # Keep only the text after '<AI>'; prompts built without the marker fall back to stripping the question.
    if '<AI>' in response:
        return response.split('<AI>', 1)[1].strip()
    return response.replace(question, '', 1).strip()


def run_inference(model_path: str, input_path: str, output_path: str):
    """Run the end-to-end batched inference pipeline."""
    # Initialize the model and tokenizer.
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    model = model.to("cuda").to(dtype=torch.bfloat16)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # The base tokenizer has no pad token, so reuse the UNK token for padding.
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.unk_token_id
    tokenizer.add_eos_token = False
    # Decoder-only models should be left-padded for batched generation.
    tokenizer.padding_side = "left"

    # Load the samples (one JSON object per line).
    with open(input_path, encoding='utf-8') as f:
        samples = [json.loads(line) for line in f]

    # Process in batches.
    batch_size = 8
    for start_idx in tqdm(range(0, len(samples), batch_size)):
        batch = samples[start_idx:start_idx + batch_size]
        # Prepare the inputs.
        prompts = [format_prompt(sample) for sample in batch]
        # Generate responses.
        raw_outputs = generate_responses(model, tokenizer, prompts)
        # Post-process the outputs.
        for i, (sample, output) in enumerate(zip(batch, raw_outputs)):
            processed_response = process_response(output, sample['question'])
            samples[start_idx + i]['raw_outputs'] = [processed_response]

    # Save the results as JSONL.
    with open(output_path, 'w', encoding='utf-8') as f:
        for sample in samples:
            json.dump(sample, f, ensure_ascii=False)
            f.write('\n')


if __name__ == '__main__':
    MODEL_PATH = 'checkpoints/py/checkpoint-15000'
    INPUT_FILE = 'test_set/代码生成_round4.jsonl'
    OUTPUT_FILE = 'test_set/code_generation_result.jsonl'
    run_inference(MODEL_PATH, INPUT_FILE, OUTPUT_FILE)