86 lines
2.9 KiB
Python
Executable File
86 lines
2.9 KiB
Python
Executable File
from typing import List, Dict, Any
|
|
import json
|
|
import torch
|
|
from tqdm import tqdm
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
# Inference configuration shared by the functions in this module.
DEVICE: str = "cuda"            # device the model is moved to after loading
DTYPE = torch.bfloat16          # dtype the model is cast to after loading
BATCH_SIZE: int = 16            # number of prompts per model.generate() call
MAX_NEW_TOKENS: int = 256       # cap on newly generated tokens per prompt
|
|
|
|
def load_model_and_tokenizer(checkpoint_path: str):
    """Load the causal-LM checkpoint and its tokenizer for batched inference.

    Args:
        checkpoint_path: Checkpoint directory (or hub id) passed to
            ``from_pretrained`` for both the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is on ``DEVICE``, cast to
        ``DTYPE`` and in eval mode. The tokenizer pads on the left with its
        unknown token (falling back to EOS when it has none). It does not
        append EOS when encoding.
    """
    # Materialise weights directly in the target dtype instead of loading
    # float32 and casting afterwards (lower peak host memory). The trailing
    # .to(..., dtype=...) is kept so every parameter/buffer still ends up in
    # DTYPE exactly as before.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_path,
        torch_dtype=DTYPE,
        trust_remote_code=True,
    ).to(DEVICE, dtype=DTYPE)
    model.eval()

    # Mirror the model's trust_remote_code so checkpoints that ship a custom
    # tokenizer class load the same way as their model.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)

    # Decoder-only batched generation: pad on the left so every prompt ends
    # exactly where generation starts.
    tokenizer.padding_side = "left"

    # Reuse an existing token for padding. Without any pad token,
    # tokenizer(..., padding=True) raises, so fall back to EOS if there is
    # no unknown token.
    if tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token
        tokenizer.pad_token_id = tokenizer.unk_token_id
    else:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    tokenizer.add_eos_token = False

    return model, tokenizer
|
|
|
|
def format_prompt(sample: Dict[str, Any]) -> str:
    """Build the prompt string for a single sample.

    The prompt shape depends on the sample, checked in this order:

    1. A ``'choices'`` key is present: chat-style prompt made of the user
       marker, the question, a newline, the choices concatenated with no
       separator, then the ``<AI>`` marker.
    2. The question, ignoring surrounding whitespace, starts with ``"Write"``:
       chat-style prompt whose answer is pre-filled with ``" def"``.
    3. Otherwise: plain completion prompt made of the question, a newline and
       a space, with no chat markers.

    Args:
        sample: A record with at least a ``'question'`` string.

    Returns:
        The prompt text to feed to the tokenizer.
    """
    question = sample['question']

    if 'choices' in sample:
        # NOTE(review): choices are joined with no separator — presumably each
        # entry already carries its own delimiter; confirm against the data.
        options = ''.join(sample['choices'])
        return f"<用户>{question}\n{options}<AI>"

    wants_code = question.strip().startswith("Write")
    # Pre-filling " def" presumably steers the model toward a Python function.
    return f"<用户>{question}<AI> def" if wants_code else f"{question}\n "
|
|
|
|
def process_response(response_text: str, question: str) -> str:
    """Extract the model's answer from a fully decoded generation.

    For chat-style prompts the answer is the text between the first ``<AI>``
    marker and the next one, if the model emitted another. Completion-style
    prompts carry no ``<AI>`` marker. For those, the echoed ``question``
    prefix is removed when the decoded text starts with it. Otherwise the
    whole decoded text is returned.

    Args:
        response_text: Decoded generation (prompt followed by continuation).
        question: The sample's question; used only when there is no ``<AI>``
            marker, to drop the echoed prompt.

    Returns:
        The answer text with surrounding whitespace stripped.
    """
    segments = response_text.split('<AI>')
    if len(segments) > 1:
        return segments[1].strip()

    # No <AI> marker: indexing segments[1] here used to raise IndexError.
    # Removing the echoed prompt is best effort because decoding may not
    # reproduce it byte-for-byte.
    text = response_text.lstrip()
    prefix = question.strip()
    if prefix and text.startswith(prefix):
        return text[len(prefix):].strip()
    return response_text.strip()
|
|
|
|
def generate_responses(model, tokenizer, text_inputs: List[str]) -> List[str]:
    """Generate continuations for a batch of prompts.

    Args:
        model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Its tokenizer. It needs a pad token because the batch is
            padded to a common length. NOTE(review): assumes left padding;
            right padding would put pad tokens between a prompt and its
            continuation.
        text_inputs: Prompts to complete.

    Returns:
        One decoded string per prompt, with special tokens removed. For
        decoder-only models ``generate`` returns prompt + continuation, so
        each string starts with (roughly) its prompt.
    """
    if not text_inputs:
        return []

    encoded = tokenizer(text_inputs, return_tensors="pt", padding=True).to(model.device)

    # Prefer the tokenizer's own pad id; 0 was the previous hard-coded value.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    generated = model.generate(
        encoded['input_ids'],
        # Without the mask, shorter prompts in a padded batch attend to their
        # padding tokens, which changes their generations.
        attention_mask=encoded.get('attention_mask'),
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=pad_id,
    )

    return tokenizer.batch_decode(generated, skip_special_tokens=True)
|
|
|
|
def execute_model(checkpoint_dir: str, src_file: str, dest_file: str) -> None:
    """Run batched generation over a JSONL file and write annotated records.

    Each non-blank line of ``src_file`` is parsed with ``json.loads``. Records
    need a ``'question'`` key (read by the prompt and post-processing steps)
    and may have ``'choices'``. The extracted answer is stored on each record
    under ``'raw_outputs'`` as a one-element list. All records are then
    written, in input order, to ``dest_file`` as UTF-8 JSONL.

    Args:
        checkpoint_dir: Checkpoint passed to ``load_model_and_tokenizer``.
        src_file: Input JSONL path (read as UTF-8).
        dest_file: Output JSONL path (overwritten).
    """
    # Read the input before loading the model so a bad path or malformed
    # file fails fast instead of after an expensive model load.
    samples = _read_jsonl(src_file)

    model, tokenizer = load_model_and_tokenizer(checkpoint_dir)

    for start_idx in tqdm(range(0, len(samples), BATCH_SIZE)):
        batch = samples[start_idx:start_idx + BATCH_SIZE]
        prompts = [format_prompt(sample) for sample in batch]
        outputs = generate_responses(model, tokenizer, prompts)

        # The slice holds the same dict objects as ``samples``, so updating
        # them in place annotates the records that are written out below.
        for sample, output in zip(batch, outputs):
            sample['raw_outputs'] = [process_response(output, sample['question'])]

    _write_jsonl(dest_file, samples)


def _read_jsonl(path: str) -> List[Dict[str, Any]]:
    """Parse a UTF-8 JSONL file into a list of objects, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as fin:
        return [json.loads(line) for line in fin if line.strip()]


def _write_jsonl(path: str, records: List[Dict[str, Any]]) -> None:
    """Write records as UTF-8 JSONL, leaving non-ASCII characters unescaped."""
    with open(path, 'w', encoding='utf-8') as fout:
        for record in records:
            fout.write(json.dumps(record, ensure_ascii=False) + '\n')
|
|
|
|
if __name__ == '__main__':
    # Local import: only the command-line entry point needs argument parsing.
    import argparse

    # Defaults reproduce the previously hard-coded paths, so running the
    # script without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Run batched generation over a JSONL question file.")
    parser.add_argument('--model-path', default='checkpoints/js/checkpoint-6200',
                        help="checkpoint directory to load")
    parser.add_argument('--input-file', default='test_set/军事知识问答_round4.jsonl',
                        help="input JSONL file")
    parser.add_argument('--output-file', default='test_set/军事知识问答_round4_result.jsonl',
                        help="output JSONL file (overwritten)")
    args = parser.parse_args()

    execute_model(args.model_path, args.input_file, args.output_file)