cc/inference.py

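# Batch inference over a JSONL test set with a fine-tuned causal LM: builds
# chat-style prompts, generates completions, and stores each response in the
# record's 'raw_outputs' field before writing the results back out as JSONL.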
import json
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def run(checkpoint_dir, src_file, dest_file):
    # Load the fine-tuned causal LM in bfloat16 on the GPU and switch to eval mode.
    model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, trust_remote_code=True).to("cuda", dtype=torch.bfloat16)
    model.eval()
    tok = AutoTokenizer.from_pretrained(checkpoint_dir)
    # Pad with the <unk> token and do not append an EOS token to the prompts.
    tok.pad_token = tok.unk_token
    tok.pad_token_id = tok.unk_token_id
    tok.add_eos_token = False

    def process_batch(text_list):
        # Tokenize the batch with padding and generate up to 256 new tokens,
        # passing the attention mask so padded positions are ignored.
        encoded = tok(text_list, return_tensors="pt", padding=True).to(model.device)
        generated = model.generate(encoded['input_ids'], attention_mask=encoded['attention_mask'],
                                   max_new_tokens=256, pad_token_id=0)
        decoded = tok.batch_decode(generated, skip_special_tokens=True)
        return decoded
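
    # Read the evaluation set: one JSON object per line.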
    data_list = []
    with open(src_file, 'r') as fin:
        for line in fin:
            data_list.append(json.loads(line))
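
    # Run inference over the data in fixed-size chunks of 8 prompts.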
    chunk_size = 8
    for start_idx in tqdm(range(0, len(data_list), chunk_size)):
        current_chunk = data_list[start_idx:start_idx + chunk_size]
        if 'choices' in current_chunk[0]:
            # Multiple-choice items: append the answer options after the question.
            prompts = [f"<用户>{x['question']}\n{''.join(x['choices'])}<AI>" for x in current_chunk]
        else:
            prompts = []
            for x in current_chunk:
                if x['question'].strip().startswith("Write"):
                    # Code-generation items: prime the completion with "def".
                    prompts.append(f"<用户>{x['question']}<AI> def")
                else:
                    prompts.append(f"{x['question']}\n ")
        output_text = process_batch(prompts)
        responses = []
        for idx, text in enumerate(output_text):
            # Keep only the text after the last <AI> marker; [-1] also handles
            # prompts that contain no marker instead of raising an IndexError.
            response = text.split('<AI>')[-1].strip()
            responses.append(response)
        for i, resp in enumerate(responses):
            data_list[start_idx + i]['raw_outputs'] = [resp]
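
    # Write the augmented records back out as JSONL.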
    with open(dest_file, 'w', encoding='utf-8') as fout:
        for item in data_list:
            json.dump(item, fout, ensure_ascii=False)
            fout.write('\n')
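

# Entry point: run inference with the checkpoint and test file configured below.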
if __name__ == '__main__':
    model_loc = 'checkpoints/py/checkpoint-5500'
    in_file = 'test_set/代码生成_round4.jsonl'  # 代码生成 = "code generation"
    out_file = 'test_set/daima_gen.jsonl'
    run(model_loc, in_file, out_file)