punk/inference.py
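"""Batch inference script.

Reads samples from a JSONL input file (one JSON object per line, with a
"question" field and, for multiple-choice items, a "choices" list), generates
completions with a causal language model, and writes the samples back to a
JSONL output file with the generations stored under "raw_outputs".
"""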

from typing import List
import json
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_responses(
    prompts: List[str],
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM
) -> List[str]:
    """Generate replies for a batch of prompts."""
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=inputs['input_ids'].shape[1] + 256,
        pad_token_id=tokenizer.pad_token_id
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


def run_inference(
    model_path: str,
    input_file: str,
    output_file: str
) -> None:
    """Run the full inference pipeline."""
    device = "cuda"
    batch_size = 8
    # Initialize the tokenizer; the unk token doubles as the pad token
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.unk_token_id
    tokenizer.add_eos_token = False
    # Left-pad so batched generation appends new tokens directly after each prompt
    tokenizer.padding_side = "left"
    # Load the model in bfloat16 and switch to evaluation mode
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True
    ).to(device, dtype=torch.bfloat16)
    model.eval()
    # Read the input samples (JSONL: one JSON object per line)
    with open(input_file, 'r', encoding='utf-8') as f:
        samples = [json.loads(line) for line in f]
    # Run inference batch by batch
    for batch_start in tqdm(range(0, len(samples), batch_size)):
        batch = samples[batch_start:batch_start + batch_size]
        # Build the prompts using the model's <用户>/<AI> chat markers
        prompts = []
        for sample in batch:
            if 'choices' in sample:
                # Multiple-choice question: append the answer options to the question
                prompt = f"<用户>{sample['question']}\n{''.join(sample['choices'])}<AI>"
            else:
                # Heuristic: questions containing "Write" are treated as code
                # generation and primed with a leading "def"
                is_code_generation = "Write" in sample['question'].strip()
                prompt = f"<用户>{sample['question']}<AI> def" if is_code_generation else f"{sample['question']}\n "
            prompts.append(prompt)
        # Generate replies and strip the prompt from each decoded output;
        # split('<AI>')[-1] also copes with prompts that contain no <AI> marker
        responses = generate_responses(prompts, tokenizer, model)
        for i, response in enumerate(responses):
            result = response.split('<AI>')[-1].strip()
            samples[batch_start + i]['raw_outputs'] = [result]
    # Save the results as JSONL
    with open(output_file, 'w', encoding='utf-8') as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + '\n')
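

# The checkpoint and test-set paths below are hard-coded; adjust MODEL_PATH,
# INPUT_FILE and OUTPUT_FILE before running the script directly.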
if __name__ == '__main__':
    MODEL_PATH = 'checkpoints/js/checkpoint-5800'
    INPUT_FILE = 'test_set/军事知识问答_round4.jsonl'
    OUTPUT_FILE = 'test_set/result_military.jsonl'
    run_inference(MODEL_PATH, INPUT_FILE, OUTPUT_FILE)