import json

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def execute_model_prediction(model_path: str, input_file: str, output_file: str):
    # Initialize the model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True
    ).to("cuda", dtype=torch.bfloat16)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.unk_token_id
    tokenizer.padding_side = "left"  # left-pad so generation continues directly from each prompt
    tokenizer.add_eos_token = False

    def generate_responses(prompts: list[str]) -> list[str]:
        """Tokenize a batch of prompts and generate one response per prompt."""
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=256,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens so the prompt is not echoed back
        new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
        return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

    # Load the input samples (one JSON object per line)
    samples = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            samples.append(json.loads(line))

    BATCH_SIZE = 16
    for batch_start in tqdm(range(0, len(samples), BATCH_SIZE)):
        batch = samples[batch_start:batch_start + BATCH_SIZE]

        # Build the input prompts
        prompts = []
        for sample in batch:
            if 'choices' in sample:
                # Multiple-choice question: append the answer options to the question
                prompt = f"<用户>{sample['question']}\n{''.join(sample['choices'])}"
            else:
                # Questions starting with "Write" are treated as code-generation
                # prompts: they get the chat prefix and a " def" continuation hint;
                # everything else is formatted as a plain completion prompt
                is_code_prompt = sample['question'].strip().startswith("Write")
                prefix = "<用户>" if is_code_prompt else ""
                suffix = " def" if is_code_prompt else "\n "
                prompt = f"{prefix}{sample['question']}{suffix}"
            prompts.append(prompt)

        # Generate responses; generate_responses already strips the prompt,
        # so the decoded text is just the model's answer
        responses = generate_responses(prompts)

        # Attach each output to its sample
        for idx, response in enumerate(responses):
            samples[batch_start + idx]['raw_outputs'] = [response.strip()]

    # Write the results back out as JSONL
    with open(output_file, 'w', encoding='utf-8') as f:
        for sample in samples:
            json.dump(sample, f, ensure_ascii=False)
            f.write('\n')


if __name__ == '__main__':
    MODEL_PATH = 'checkpoints/js/checkpoint-7200'
    INPUT_FILE = 'test_set/军事知识问答_round4.jsonl'
    OUTPUT_FILE = 'test_set/军事知识问答_round4_output.jsonl'
    execute_model_prediction(MODEL_PATH, INPUT_FILE, OUTPUT_FILE)
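
# A minimal sketch of the input JSONL format this script assumes. These are
# hypothetical records for illustration: the field names "question" and
# "choices" come from the code above, but the values are not taken from the
# actual test set.
#
#   {"question": "Which of the following is a main battle tank?", "choices": ["A. ...", "B. ..."]}
#   {"question": "Write a function that returns the nth Fibonacci number."}
#
# Each line of the output file is the same object with a "raw_outputs" list appended:
#
#   {"question": "...", "raw_outputs": ["<generated answer>"]}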