import json

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def execute_model_prediction(model_path: str, input_file: str, output_file: str):
    # Initialize the model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True
    ).to("cuda", dtype=torch.bfloat16)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.unk_token_id
    tokenizer.add_eos_token = False

    def generate_responses(prompts: list[str]) -> list[str]:
        """Tokenize a batch of prompts and return the decoded model replies."""
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        # Pass the attention mask so padded positions are ignored during batched generation
        outputs = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                 max_new_tokens=256, pad_token_id=0)
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Read the input samples (one JSON object per line)
    samples = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            samples.append(json.loads(line))

    BATCH_SIZE = 16
    for batch_start in tqdm(range(0, len(samples), BATCH_SIZE)):
        batch = samples[batch_start:batch_start + BATCH_SIZE]

        # Build the input prompts
        prompts = []
        for sample in batch:
            if 'choices' in sample:
                # Multiple-choice question: list the options after the question
                prompt = f"<用户>{sample['question']}\n{''.join(sample['choices'])}<AI>"
            else:
                # Questions starting with "Write" are treated as code-generation prompts and
                # get the chat markers; everything else is continued as plain text
                prefix = "<用户>" if sample['question'].strip().startswith("Write") else ""
                suffix = "<AI> def" if sample['question'].strip().startswith("Write") else "\n "
                prompt = f"{prefix}{sample['question']}{suffix}"
            prompts.append(prompt)

        # Generate replies and keep only the text after the <AI> marker
        responses = generate_responses(prompts)
        processed_outputs = []

        for response in responses:
            # Take the last segment so prompts without an <AI> marker do not raise IndexError
            output = response.split('<AI>')[-1].strip()
            processed_outputs.append(output)

        # Attach the generations to their samples
        for idx, output in enumerate(processed_outputs):
            samples[batch_start + idx]['raw_outputs'] = [output]

    # Save the results as JSON Lines
    with open(output_file, 'w', encoding='utf-8') as f:
        for sample in samples:
            json.dump(sample, f, ensure_ascii=False)
            f.write('\n')


if __name__ == '__main__':
    MODEL_PATH = 'checkpoints/js/checkpoint-7200'
    INPUT_FILE = 'test_set/军事知识问答_round4.jsonl'
    OUTPUT_FILE = 'test_set/军事知识问答_round4_output.jsonl'
    execute_model_prediction(MODEL_PATH, INPUT_FILE, OUTPUT_FILE)