59 lines
2.2 KiB
Python
59 lines
2.2 KiB
Python
|
import json
|
||
|
import torch
|
||
|
from tqdm import tqdm
|
||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
|
|
||
|
def _read_jsonl(path):
    """Load a JSON Lines file into a list of objects, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as fin:
        return [json.loads(line) for line in fin if line.strip()]


def _write_jsonl(path, items):
    """Write each object as one UTF-8 JSON line, keeping non-ASCII text verbatim."""
    with open(path, 'w', encoding='utf-8') as fout:
        for item in items:
            json.dump(item, fout, ensure_ascii=False)
            fout.write('\n')


def _build_prompt(example):
    """Format one test example into the prompt string sent to the model.

    - Items with a ``choices`` key: question plus concatenated choices,
      wrapped in the user/``<AI>`` chat markers.
    - Questions starting with "Write": chat markers, with the answer pre-seeded
      with `` def`` so the model continues a function definition.
    - Anything else: a raw completion prompt ``{question}\\n `` with no chat
      markers.

    The decision is made per example, so a batch may mix prompt styles.
    """
    question = example['question']
    if 'choices' in example:
        return f"<用户>{question}\n{''.join(example['choices'])}<AI>"
    if question.strip().startswith("Write"):
        return f"<用户>{question}<AI> def"
    return f"{question}\n "


def _extract_response(full_text, continuation):
    """Pick the answer text for one sample out of the decoded generation.

    Args:
        full_text: Decoded prompt + generated tokens.
        continuation: Decoded generated tokens only (prompt excluded).

    Returns:
        For chat-style outputs (containing ``<AI>``): the stripped text between
        the first ``<AI>`` marker and the next one, if any. This keeps the
        `` def`` primer of "Write" prompts in the answer.
        Otherwise (raw completion prompts have no marker): the stripped newly
        generated text. Previously this case raised IndexError.
    """
    parts = full_text.split('<AI>')
    if len(parts) > 1:
        return parts[1].strip()
    return continuation.strip()


def run(checkpoint_dir, src_file, dest_file, batch_size=8, max_new_tokens=256, device="cuda"):
    """Generate an answer for every example in a JSONL test set and save them.

    Each input line must be a JSON object with a ``question`` string and,
    optionally, a ``choices`` list of strings. The generated answer is stored
    on each object under ``raw_outputs`` as a one-element list. All objects
    are then written to ``dest_file`` as UTF-8 JSON Lines, in input order.

    Args:
        checkpoint_dir: Directory passed to ``from_pretrained`` for both the
            model (loaded with ``trust_remote_code=True``) and the tokenizer.
        src_file: Input JSON Lines path (read as UTF-8; blank lines skipped).
        dest_file: Output JSON Lines path (overwritten).
        batch_size: Number of prompts generated together (default 8).
        max_new_tokens: Maximum number of generated tokens per prompt
            (default 256).
        device: Device the model is moved to (default ``"cuda"``).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Load weights directly in bf16 instead of materialising fp32 and casting.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir, trust_remote_code=True, torch_dtype=torch.bfloat16
    ).to(device)
    model.eval()

    tok = AutoTokenizer.from_pretrained(checkpoint_dir)
    # Reuse the unk token for padding; presumably the tokenizer ships without a
    # dedicated pad token — confirm against the checkpoint's tokenizer config.
    tok.pad_token = tok.unk_token
    tok.pad_token_id = tok.unk_token_id
    # Do not append EOS to prompts; the model must continue them, not stop.
    tok.add_eos_token = False
    # Batched decoder-only generation needs left padding so that every prompt
    # ends exactly where generation begins.
    tok.padding_side = "left"

    def process_batch(text_list):
        """Generate for a batch; return (full decoded texts, continuation-only texts)."""
        encoded = tok(text_list, return_tensors="pt", padding=True).to(model.device)
        prompt_len = encoded['input_ids'].shape[1]
        generated = model.generate(
            input_ids=encoded['input_ids'],
            # Explicit mask so pad positions (and real unk tokens) are handled
            # correctly instead of being inferred from pad_token_id.
            attention_mask=encoded['attention_mask'],
            max_new_tokens=max_new_tokens,
            pad_token_id=tok.pad_token_id,
        )
        full_texts = tok.batch_decode(generated, skip_special_tokens=True)
        # With left padding, every row's new tokens start at prompt_len.
        continuations = tok.batch_decode(generated[:, prompt_len:], skip_special_tokens=True)
        return full_texts, continuations

    data_list = _read_jsonl(src_file)

    for start_idx in tqdm(range(0, len(data_list), batch_size)):
        current_chunk = data_list[start_idx:start_idx + batch_size]
        prompts = [_build_prompt(example) for example in current_chunk]
        full_texts, continuations = process_batch(prompts)
        # current_chunk holds the same dict objects as data_list, so this
        # annotates data_list in place.
        for example, full_text, continuation in zip(current_chunk, full_texts, continuations):
            example['raw_outputs'] = [_extract_response(full_text, continuation)]

    _write_jsonl(dest_file, data_list)
|
||
|
|
||
|
if __name__ == '__main__':
    # Single fixed evaluation job: checkpoint directory, input JSONL, output JSONL.
    run(
        checkpoint_dir='checkpoints/py/checkpoint-5500',
        src_file='test_set/代码生成_round4.jsonl',
        dest_file='test_set/daima_gen.jsonl',
    )
|