# vc/inference.py
#
# Repository-viewer metadata captured with this file: 70 lines, 2.6 KiB, Python;
# last change recorded at 2024-11-12 00:41:13 +08:00.
import json
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
def _load_samples(input_file: str) -> list[dict]:
    """Read a JSON Lines file into a list of dicts, skipping blank lines."""
    # Explicit UTF-8 so reading matches the encoding used when writing results.
    with open(input_file, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _build_prompt(sample: dict) -> str:
    """Format one sample as a model prompt.

    Samples with a ``choices`` key become a chat-style multiple-choice prompt.
    Other samples whose question starts with "Write" become a chat-style prompt
    that pre-seeds the answer with " def"; everything else is passed through
    as a plain completion prompt without chat markers.
    """
    question = sample['question']
    if 'choices' in sample:
        return f"<用户>{question}\n{''.join(sample['choices'])}<AI>"
    if question.strip().startswith("Write"):
        return f"<用户>{question}<AI> def"
    return f"{question}\n "


def _extract_answer(response: str) -> str:
    """Return the text following the first ``<AI>`` marker in *response*.

    Only the segment between the first and second marker is kept. Prompts
    built without an ``<AI>`` marker (plain completion prompts) would make an
    unconditional ``split(...)[1]`` raise IndexError, so in that case the whole
    decoded text is returned instead.
    """
    segments = response.split('<AI>')
    if len(segments) > 1:
        return segments[1].strip()
    return response.strip()


def execute_model_prediction(
    model_path: str,
    input_file: str,
    output_file: str,
    *,
    batch_size: int = 16,
    max_new_tokens: int = 256,
):
    """Run batched generation over a JSONL file and save the model outputs.

    Each input line is a JSON object with a ``question`` key and optionally a
    ``choices`` key. Every sample is written back to *output_file* (JSONL,
    UTF-8) with an added ``raw_outputs`` key holding a one-element list with
    the extracted model answer.

    Args:
        model_path: Checkpoint directory passed to ``from_pretrained``.
        input_file: Path of the JSONL file to read.
        output_file: Path of the JSONL file to write.
        batch_size: Number of prompts generated per forward batch.
        max_new_tokens: Generation budget per prompt.

    Raises:
        ValueError: If *batch_size* is not a positive integer.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Read the data first: a malformed or empty input is detected before
    # paying for the model load.
    samples = _load_samples(input_file)

    if samples:
        # Initialise the model and tokenizer.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True
        ).to("cuda", dtype=torch.bfloat16)
        model.eval()

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        tokenizer.pad_token = tokenizer.unk_token
        tokenizer.pad_token_id = tokenizer.unk_token_id
        tokenizer.add_eos_token = False
        # Decoder-only models continue from the last position, so padding must
        # sit on the left for batched generation.
        tokenizer.padding_side = "left"

        def generate_responses(prompts: list[str]) -> list[str]:
            """Generate a reply for each prompt in one padded batch."""
            inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
            # The attention mask tells the model to ignore padded positions;
            # the pad id comes from the tokenizer instead of a literal 0.
            outputs = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
            )
            return tokenizer.batch_decode(outputs, skip_special_tokens=True)

        for batch_start in tqdm(range(0, len(samples), batch_size)):
            batch = samples[batch_start:batch_start + batch_size]
            responses = generate_responses([_build_prompt(sample) for sample in batch])
            # The slice holds the same dict objects as `samples`, so updating
            # them here updates the records that get saved below.
            for sample, response in zip(batch, responses):
                sample['raw_outputs'] = [_extract_answer(response)]

    # Save the results, one JSON object per line.
    with open(output_file, 'w', encoding='utf-8') as f:
        for sample in samples:
            json.dump(sample, f, ensure_ascii=False)
            f.write('\n')
if __name__ == '__main__':
    # Script entry point: run inference for the round-4 question set with the
    # step-7200 checkpoint and write the results next to the input file.
    execute_model_prediction(
        model_path='checkpoints/js/checkpoint-7200',
        input_file='test_set/军事知识问答_round4.jsonl',
        output_file='test_set/军事知识问答_round4_output.jsonl',
    )