from typing import Any, Dict, List
import json

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Constants
MAX_NEW_TOKENS = 256
BATCH_SIZE = 16
DEVICE = "cuda"
DTYPE = torch.bfloat16


def load_model_and_tokenizer(checkpoint_path: str):
    """Load the model and tokenizer from a checkpoint directory."""
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_path, trust_remote_code=True
    ).to(DEVICE, dtype=DTYPE)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    # The checkpoint defines no pad token, so reuse <unk> for padding.
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.unk_token_id
    # Left-pad so that, in a batch, each generated continuation directly
    # follows its prompt (required for decoder-only generation).
    tokenizer.padding_side = "left"
    tokenizer.add_eos_token = False
    return model, tokenizer


def format_prompt(sample: Dict[str, Any]) -> str:
    """Format the input prompt.

    '<用户>' and '<AI>' mark the user and assistant turns in the chat format
    this checkpoint was trained on.
    """
    if 'choices' in sample:
        return f"<用户>{sample['question']}\n{''.join(sample['choices'])}<AI>"
    if sample['question'].strip().startswith("Write"):
        return f"<用户>{sample['question']}<AI> def"
    return f"{sample['question']}\n "


def process_response(response_text: str, question: str) -> str:
    """Post-process the model output: keep only the assistant's reply."""
    # Keep the text after the '<AI>' marker; fall back to the full decoded
    # text for the plain-completion branch of format_prompt, which has none.
    return response_text.split('<AI>')[-1].strip()


def generate_responses(model, tokenizer, text_inputs: List[str]) -> List[str]:
    """Generate model replies for a batch of prompts."""
    encoded = tokenizer(text_inputs, return_tensors="pt", padding=True).to(model.device)
    generated = model.generate(
        encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)


def execute_model(checkpoint_dir: str, src_file: str, dest_file: str) -> None:
    """Main entry point: run batched inference over a JSONL file."""
    # Load the model and tokenizer
    model, tokenizer = load_model_and_tokenizer(checkpoint_dir)

    # Read the input data: one JSON object per line
    with open(src_file, 'r', encoding='utf-8') as fin:
        samples = [json.loads(line) for line in fin]

    # Process the samples in batches
    for start_idx in tqdm(range(0, len(samples), BATCH_SIZE)):
        batch = samples[start_idx:start_idx + BATCH_SIZE]

        # Prepare the inputs
        prompts = [format_prompt(sample) for sample in batch]

        # Generate replies
        outputs = generate_responses(model, tokenizer, prompts)

        # Post-process the outputs
        for i, (output, sample) in enumerate(zip(outputs, batch)):
            response = process_response(output, sample['question'])
            samples[start_idx + i]['raw_outputs'] = [response]

    # Save the results
    with open(dest_file, 'w', encoding='utf-8') as fout:
        for sample in samples:
            json.dump(sample, fout, ensure_ascii=False)
            fout.write('\n')


if __name__ == '__main__':
    model_path = 'checkpoints/py/checkpoint-7000'
    input_file = 'test_set/代码生成_round4.jsonl'
    output_file = 'test_set/代码生成_round4_result.jsonl'
    execute_model(model_path, input_file, output_file)
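
# ---------------------------------------------------------------------------
# Expected JSONL schema: a sketch inferred from format_prompt/execute_model
# above. The concrete records below are hypothetical examples, not real data.
#
# Input line (free-form question):
#   {"question": "Write a function that returns the n-th Fibonacci number."}
# Input line with multiple-choice options (concatenated after the question):
#   {"question": "Which keyword defines a function in Python?",
#    "choices": ["A. func\n", "B. def\n", "C. fn\n", "D. lambda\n"]}
#
# Output line: every input field is preserved, and the generated reply is
# appended as a one-element list under "raw_outputs":
#   {"question": "...", "raw_outputs": ["<model reply>"]}
# ---------------------------------------------------------------------------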