import json

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def run(checkpoint_dir, src_file, dest_file):
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir, trust_remote_code=True
    ).to("cuda", dtype=torch.bfloat16)
    model.eval()

    tok = AutoTokenizer.from_pretrained(checkpoint_dir)
    # The checkpoint ships without a pad token; reuse <unk> for batch padding,
    # and do not append EOS to the prompts.
    tok.pad_token = tok.unk_token
    tok.pad_token_id = tok.unk_token_id
    tok.add_eos_token = False
    # Left-pad so every prompt ends at the same position; with right padding,
    # a decoder-only model would generate after the pad tokens of the shorter
    # prompts in the batch.
    tok.padding_side = "left"

    @torch.inference_mode()
    def process_batch(text_list):
        encoded = tok(text_list, return_tensors="pt", padding=True).to(model.device)
        # Pass the attention mask along with the input ids so pad tokens are
        # ignored, and use the tokenizer's pad id rather than a hard-coded 0.
        generated = model.generate(
            **encoded, max_new_tokens=256, pad_token_id=tok.pad_token_id
        )
        # Decode only the newly generated tokens, so the prompt text never has
        # to be split back out of the decoded string.
        completions = generated[:, encoded["input_ids"].shape[1]:]
        return tok.batch_decode(completions, skip_special_tokens=True)

    # Load the JSONL test set, one record per line.
    data_list = []
    with open(src_file, "r", encoding="utf-8") as fin:
        for line in fin:
            data_list.append(json.loads(line))

    chunk_size = 8
    for start_idx in tqdm(range(0, len(data_list), chunk_size)):
        current_chunk = data_list[start_idx:start_idx + chunk_size]
        if "choices" in current_chunk[0]:
            # Multiple-choice items: append the candidate answers to the question.
            prompts = [
                f"<用户>{x['question']}\n{''.join(x['choices'])}"
                for x in current_chunk
            ]
        else:
            prompts = []
            for x in current_chunk:
                if x["question"].strip().startswith("Write"):
                    # Code-generation items: seed the completion with "def".
                    prompts.append(f"<用户>{x['question']} def")
                else:
                    prompts.append(f"{x['question']}\n ")
        for i, text in enumerate(process_batch(prompts)):
            data_list[start_idx + i]["raw_outputs"] = [text.strip()]

    # Write the annotated records back out as JSONL.
    with open(dest_file, "w", encoding="utf-8") as fout:
        for item in data_list:
            json.dump(item, fout, ensure_ascii=False)
            fout.write("\n")


if __name__ == "__main__":
    model_loc = "checkpoints/py/checkpoint-5500"
    in_file = "test_set/代码生成_round4.jsonl"
    out_file = "test_set/daima_gen.jsonl"
    run(model_loc, in_file, out_file)
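
# Illustrative I/O for run(), as a sketch: the sample values below are
# invented; only the "question" / optional "choices" input fields and the
# "raw_outputs" output field come from the code above.
#   input line:  {"question": "Write a function to reverse a list."}
#   output line: {"question": "Write a function to reverse a list.",
#                 "raw_outputs": ["def reverse_list(xs):\n    return xs[::-1]"]}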