From defb7a8bdddb48fcc8c5ad2eb1bc94217e13e1ad Mon Sep 17 00:00:00 2001 From: p83651209 <1550266278@qq.com> Date: Sat, 2 Nov 2024 16:15:33 +0800 Subject: [PATCH] ADD file via upload --- inference.py | 270 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 inference.py diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..52dc5b2 --- /dev/null +++ b/inference.py @@ -0,0 +1,270 @@ +import json, torch, re, sys, subprocess + +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, StoppingCriteria +device = "cuda" # the device to load the model onto +from tqdm import tqdm + + +def exec_code(test): + with open("test_case.json", "r") as f: + test_cases = json.load(f) + + right_num = 0 + all_num = 0 + package = "import os, sys, math, re, json, random\n" + + for item, test_case in zip(test, test_cases): + + + if "```python\n" in item["raw_outputs"]: + matches = re.findall('```python(.*?)```', item["raw_outputs"], re.DOTALL) + if len(matches) == 1: + item["raw_outputs"] = matches[0] + else: + matches = re.findall('```python(.*?)assert', item["raw_outputs"], re.DOTALL) + if len(matches) == 1: + item["raw_outputs"] = matches[0] + else: + item["raw_outputs"] = item["raw_outputs"][item["raw_outputs"].index("python\n") + len("python\n"):] + print(item) + #break + + + code = item["raw_outputs"].replace("<|im_end|>", "").replace("", "").replace("```", "").strip().rstrip("\n") + + raw_code = code + + codes = raw_code.split("\n") + last_line = 0 + for index, line in enumerate(codes): + if " return" in line: + last_line = index + + code = "\n".join(codes[:last_line+1]) + + + ''' + if raw_code != code: + print("\n--------------------------------------------------------\n", [raw_code], "\n--------------------------------------------------------\n") + print("clean:\n", [code], "\n+++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n\n\n") + ''' + + + with open('code_.py', 'w') as fout: + fout.write(package + code + "\n" + "\n".join(test_case["test_case"])) + + batcmd = 'timeout 3 ' + sys.executable + ' code_.py' + try: + shell_output = subprocess.check_output(batcmd, shell=True).decode('utf8') + right_num += 1 + item["result"] = "True" + except Exception as e: + print("++++++++++++++++++++++++++++++++++++++++++++++++++++\n", raw_code, "\n-----------------------------------------\n\n\n", package + code + "\n--------------------------\n" + "\n".join(test_case["test_case"])) + print("--------------------------------------------------------\n\n\nitem:", item) + print("e: ", e, "\n================================================\n")#, e, ) + item["result"] = "False" + + all_num += 1 + + item["raw_outputs"] = [code] + + print(len(test), right_num, all_num, right_num / all_num) + with open(f'wjf_{model_path.replace("/", "-")}{right_num / all_num}.json', "w") as f: + json.dump(test, f, indent=4) + + return test, right_num / all_num + + + +def get_result(model, tokenizer): + + test = [] + with open("/mnt/disk2/home/wujianfeng/com/code/code_round4.jsonl", "r") as f: + #test = json.load(f) + for line in f: + test.append(json.loads(line)) + + all_score = 0 + all_num = 0 + test_num = 1000 + + + from tqdm import tqdm + for example in tqdm(test[:]): + #print(example["question"]) + example["question"] = example["question"].replace("'''", '"""') + + ai_prefix = "" + if example["question"].split(" ")[0] == "Write": + question = example["question"][:example["question"].index("\n")].strip().rstrip() + test_case = example["question"][example["question"].index("\n"):].split("\n") + print("test_case: ", test_case) + + function_name = test_case[1].split(" ")[1].split("(")[0] + ai_prefix = "def " + function_name + + messages = [ + {"role": "user", "content": question + "\n\n" + ("\n".join(test_case))} + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + text += ai_prefix + example["test_case"] = test_case + + + else: + tmp = re.findall(r'"""(.*?)"""', example["question"], flags=re.DOTALL)[0].split("\n") + question = "" + for line in tmp: + line = line.strip().rstrip() + if len(line) == 0: + continue + + #if "xample" in line and len(line) < 20: + # break + + question += line + " " + + code = re.sub(r'"""(.*?)"""', '', example["question"], flags=re.DOTALL).strip().rstrip() + ai_prefix = code + + messages = [ + {"role": "user", "content": question} + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + text += ai_prefix + + + example["prompt"] = text + print("text: " , [text]) + input_ids = tokenizer([text], return_tensors="pt").to(device).input_ids + + + output = model.generate(input_ids, + #top_p=1.0, + max_new_tokens=600, + #repetition_penalty=1.1 + t*0.01, + temperature=0.1, + #no_repeat_ngram_size = 5, + ).squeeze() + + output_str = tokenizer.decode(output[input_ids.shape[1]:]) + output_str = ai_prefix + output_str + print("output_str:\n", output_str, "\n-----------------------------------------------------------------") + + example["raw_outputs"] = output_str#re.findall(r'```python(.*?)```', output_str) + return test + + + +def get_result_1(model, tokenizer): + + test = [] + with open("/mnt/disk2/home/wujianfeng/com/code/code_round4.jsonl", "r") as f: + #test = json.load(f) + for line in f: + test.append(json.loads(line)) + + all_score = 0 + all_num = 0 + test_num = 1000 + + + from tqdm import tqdm + for example in tqdm(test[:]): + #print(example["question"]) + messages = [ + {"role": "user", "content": example["question"]} + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + + example["prompt"] = text + print("text: " , [text]) + input_ids = tokenizer([text], return_tensors="pt").to(device).input_ids + + + output = model.generate(input_ids, + #top_p=1.0, + max_new_tokens=600, + #repetition_penalty=1.1 + t*0.01, + temperature=0.1, + #no_repeat_ngram_size = 5, + ).squeeze() + output_str = tokenizer.decode(output[input_ids.shape[1]:]) + print("output_str:\n", output_str, "\n-----------------------------------------------------------------") + + example["raw_outputs"] = output_str#re.findall(r'```python(.*?)```', output_str) + return test + + +answers = {} +for model_path in [ + "/mnt/disk2/home/wujianfeng/LLaMA-Factory/all_new_1/checkpoint-600", + "/mnt/disk2/home/wujianfeng/LLaMA-Factory/all_new/checkpoint-600/", + +]: + print("model_path: ", model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype="auto", + device_map=device, + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + test = get_result(model, tokenizer) + test, score = exec_code(test) + answers[score] = test + + + test = get_result_1(model, tokenizer) + test, score = exec_code(test) + answers[score] = test + + + +answers = list(dict(sorted(answers.items())).values()) +print("answers: ", answers) +right = 0 +jiuge_right = 0 +merge = [] +for i in range(len(answers)): +#for i in range(2): + flag = 0 + for answer in answers: + if answer[i]["result"] == "True": + right += 1 + jiuge_right += 1 + flag = 1 + merge.append(answer[i]) + break + + if flag == 0: + merge.append(answers[0][i]) + + + +print(right / len(answers), jiuge_right / len(answers)) +with open("wjf_jiuge.jsonl", "w") as f: + for item in merge: + item.pop("result") + f.write(json.dumps(item, ensure_ascii=False) + '\n') + + + + + +