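#!/usr/bin/env python
"""Batched inference with a causal LM over a JSONL test set; raw model outputs are written back per sample."""
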
import json
from typing import List

import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_responses(
    prompts: List[str],
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM
) -> List[str]:
    """Process a batch of prompts and generate responses."""
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    outputs = model.generate(
        inputs['input_ids'],
        # Pass the attention mask so padded positions are ignored during generation
        attention_mask=inputs['attention_mask'],
        max_new_tokens=256,
        pad_token_id=tokenizer.pad_token_id
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
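
# Hypothetical standalone usage of generate_responses (the model path is a placeholder):
#   tok = AutoTokenizer.from_pretrained("path/to/model")
#   tok.pad_token, tok.pad_token_id = tok.unk_token, tok.unk_token_id
#   mdl = AutoModelForCausalLM.from_pretrained("path/to/model").to("cuda")
#   print(generate_responses(["<用户>你好<AI>"], tok, mdl))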

def run_inference(
    model_path: str,
    input_file: str,
    output_file: str
) -> None:
    """Run the inference pipeline."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 8

    # Initialize the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.unk_token_id
    # Decoder-only models need left padding for batched generation
    tokenizer.padding_side = "left"
    tokenizer.add_eos_token = False

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True
    ).to(device, dtype=torch.bfloat16)
    model.eval()

    # Read the data
    with open(input_file, 'r', encoding='utf-8') as f:
        samples = [json.loads(line) for line in f]
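    # Each input line is assumed to be a JSON object with a 'question' field and,
    # for multiple-choice items, a 'choices' list (inferred from the field
    # accesses below), e.g.:
    #   {"question": "...", "choices": ["A. ...", "B. ...", "C. ...", "D. ..."]}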

    # Batched inference
    for batch_start in tqdm(range(0, len(samples), batch_size)):
        batch = samples[batch_start:batch_start + batch_size]

        # Build the input prompts
        prompts = []
        for sample in batch:
            if 'choices' in sample:
                prompt = f"<用户>{sample['question']}\n{''.join(sample['choices'])}<AI>"
            else:
                is_code_generation = "Write" in sample['question']
                prompt = f"<用户>{sample['question']}<AI> def" if is_code_generation else f"{sample['question']}\n "
            prompts.append(prompt)

        # Generate responses
        responses = generate_responses(prompts, tokenizer, model)

        # Process the outputs: take the text after the last '<AI>' tag,
        # falling back to the whole response for prompts without one
        for i, response in enumerate(responses):
            result = response.split('<AI>')[-1].strip()
            samples[batch_start + i]['raw_outputs'] = [result]

    # Save the results
    with open(output_file, 'w', encoding='utf-8') as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + '\n')


if __name__ == '__main__':
    MODEL_PATH = 'checkpoints/js/checkpoint-5800'
    INPUT_FILE = 'test_set/军事知识问答_round4.jsonl'
    OUTPUT_FILE = 'test_set/result_military.jsonl'
    run_inference(MODEL_PATH, INPUT_FILE, OUTPUT_FILE)
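
# Run this file directly; MODEL_PATH, INPUT_FILE and OUTPUT_FILE above are the only knobs to adjust.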