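"""Batched inference script: loads a causal LM checkpoint, formats prompts from a
JSONL test set, generates responses, and writes the results back as JSON lines."""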
from typing import List, Dict, Any
import json

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_responses(model: AutoModelForCausalLM,
                       tokenizer: AutoTokenizer,
                       texts: List[str]) -> List[str]:
    """Generate model responses for a batch of prompts."""
    inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
    # Pass the attention mask so padded positions are ignored during generation.
    outputs = model.generate(inputs['input_ids'],
                             attention_mask=inputs['attention_mask'],
                             max_new_tokens=256,
                             pad_token_id=tokenizer.pad_token_id)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


def format_prompt(sample: Dict[str, Any]) -> str:
    """Format the prompt text for an input sample."""
    # Multiple-choice samples: append the concatenated options after the question.
    if 'choices' in sample:
        return f"<用户>{sample['question']}\n{''.join(sample['choices'])}<AI>"

    # Code-writing questions: seed the completion with "def".
    if sample['question'].strip().startswith("Write"):
        return f"<用户>{sample['question']}<AI> def"

    # Everything else is treated as a plain completion prompt without chat markers.
    return f"{sample['question']}\n "


def process_response(response: str, question: str) -> str:
    """Post-process a decoded model output into the final answer."""
    if '<AI>' in response:
        return response.split('<AI>')[1].strip()
    # No <AI> marker (plain completion prompt): strip the echoed question prefix instead.
    return response[len(question):].strip() if response.startswith(question) else response.strip()


def run_inference(model_path: str, input_path: str, output_path: str):
    """Run the inference pipeline."""
    # Initialize the model and tokenizer.
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    model = model.to("cuda").to(dtype=torch.bfloat16)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Reuse the unk token as the padding token.
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.unk_token_id
    # Left-pad so batched generation works correctly with a decoder-only model.
    tokenizer.padding_side = "left"
    tokenizer.add_eos_token = False

    # Read the input data (one JSON object per line).
    with open(input_path) as f:
        samples = [json.loads(line) for line in f]

    # Process the samples in batches.
    batch_size = 8
    for start_idx in tqdm(range(0, len(samples), batch_size)):
        batch = samples[start_idx:start_idx + batch_size]

        # Prepare the inputs.
        prompts = [format_prompt(sample) for sample in batch]

        # Generate responses.
        raw_outputs = generate_responses(model, tokenizer, prompts)

        # Post-process the outputs and attach them to the samples.
        for i, (sample, output) in enumerate(zip(batch, raw_outputs)):
            processed_response = process_response(output, sample['question'])
            samples[start_idx + i]['raw_outputs'] = [processed_response]

    # Save the results as JSON lines.
    with open(output_path, 'w', encoding='utf-8') as f:
        for sample in samples:
            json.dump(sample, f, ensure_ascii=False)
            f.write('\n')


if __name__ == '__main__':
    MODEL_PATH = 'checkpoints/py/checkpoint-15000'
    INPUT_FILE = 'test_set/代码生成_round4.jsonl'
    OUTPUT_FILE = 'test_set/code_generation_result.jsonl'

    run_inference(MODEL_PATH, INPUT_FILE, OUTPUT_FILE)