"""Generate code answers with fine-tuned causal LMs and score them.

For every checkpoint in ``MODEL_PATHS`` the script prompts the model with each
question in ``QUESTIONS_PATH`` (using two prompting strategies), executes the
generated code against the assertions stored in ``test_case.json`` and finally
merges all runs into ``wjf_jiuge.jsonl``, preferring for every question an
answer whose assertions passed.
"""

import json
import re
import subprocess
import sys

device = "cuda"  # the device to load the model onto

QUESTIONS_PATH = "/mnt/disk2/home/wujianfeng/com/code/code_round4.jsonl"
MODEL_PATHS = [
    "/mnt/disk2/home/wujianfeng/LLaMA-Factory/all/TACO/",
    "/mnt/disk2/home/wujianfeng/LLaMA-Factory/all_new_2/CodeNet4Repair/",
    "/mnt/disk2/home/wujianfeng/LLaMA-Factory/all_new_1/CodeExercise-Python-27k/",
]

# Imports prepended to every generated program before it is executed.
PACKAGE_PREAMBLE = "import os, sys, math, re, json, random\n"
SCRATCH_FILE = "code_.py"
TIMEOUT_SECONDS = 3


def _clean_generated_code(raw_output):
    """Return the runnable code extracted from one raw model completion.

    If the completion contains a ```` ```python ```` fence, the fenced part is
    kept (falling back to "fence up to the first ``assert``", then to
    "everything after the fence opener"). Chat markers are removed and the
    text is cut after the last line containing `` return``.
    """
    text = raw_output
    if "```python\n" in text:
        fenced = re.findall(r"```python(.*?)```", text, re.DOTALL)
        if len(fenced) == 1:
            text = fenced[0]
        else:
            before_assert = re.findall(r"```python(.*?)assert", text, re.DOTALL)
            if len(before_assert) == 1:
                text = before_assert[0]
            else:
                text = text[text.index("python\n") + len("python\n"):]

    # NOTE(review): the previous version also replaced the empty string with
    # the empty string (a no-op) -- presumably an end-of-sequence token literal
    # that got lost; add it here once the intended token is confirmed.
    text = text.replace("<|im_end|>", "").replace("```", "").strip()

    lines = text.split("\n")
    return_rows = [row for row, line in enumerate(lines) if " return" in line]
    if return_rows:
        # Drop whatever the model appended after the last return statement.
        lines = lines[:return_rows[-1] + 1]
    # Without any return line the code is kept whole; truncating it to its
    # first line (the old behaviour) could only ever yield a syntax error.
    return "\n".join(lines)


def _run_program(program):
    """Write *program* to ``SCRATCH_FILE`` and run it; return the error or None."""
    with open(SCRATCH_FILE, "w", encoding="utf-8") as fout:
        fout.write(program)
    try:
        subprocess.run(
            [sys.executable, SCRATCH_FILE],
            stdout=subprocess.PIPE,
            check=True,
            timeout=TIMEOUT_SECONDS,
        )
    except (subprocess.SubprocessError, OSError) as err:
        # Non-zero exit (failed assert / exception), timeout, or launch error.
        return err
    return None


def exec_code(test, model_path=""):
    """Run every generated answer against its test case and record the result.

    Args:
        test: list of dicts, each with a string under ``"raw_outputs"``.
            Items are updated in place: ``"result"`` becomes ``"True"`` or
            ``"False"`` and ``"raw_outputs"`` becomes ``[cleaned_code]``.
        model_path: label used in the name of the per-run JSON dump
            (``/`` is replaced by ``-``).

    Returns:
        ``(test, accuracy)`` where accuracy is passed / executed, or ``0.0``
        when nothing was executed.

    Items are paired positionally with the entries of ``test_case.json``;
    surplus items on either side are ignored by ``zip``.
    """
    with open("test_case.json", "r", encoding="utf-8") as f:
        test_cases = json.load(f)

    right_num = 0
    all_num = 0
    for item, test_case in zip(test, test_cases):
        code = _clean_generated_code(item["raw_outputs"])
        assertions = "\n".join(test_case["test_case"])
        program = PACKAGE_PREAMBLE + code + "\n" + assertions

        error = _run_program(program)
        if error is None:
            right_num += 1
            item["result"] = "True"
        else:
            print("++++++++++++++++++++++++++++++++++++++++++++++++++++\n", program)
            print("--------------------------------------------------------\n\n\nitem:", item)
            print("e: ", error, "\n================================================\n")
            item["result"] = "False"

        all_num += 1
        item["raw_outputs"] = [code]

    score = right_num / all_num if all_num else 0.0
    print(len(test), right_num, all_num, score)
    with open(f'wjf_{model_path.replace("/", "-")}{score}.json', "w", encoding="utf-8") as f:
        json.dump(test, f, indent=4)

    return test, score


def _load_questions():
    """Read ``QUESTIONS_PATH`` (one JSON object per line) into a list."""
    with open(QUESTIONS_PATH, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f]


def _chat_prompt(tokenizer, content):
    """Render a single user turn with the tokenizer's chat template."""
    messages = [{"role": "user", "content": content}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )


def _generate(model, tokenizer, prompt):
    """Return the decoded continuation the model produces for *prompt*."""
    input_ids = tokenizer([prompt], return_tensors="pt").to(device).input_ids
    # NOTE(review): temperature normally only matters when sampling is on --
    # confirm against the checkpoints' generation config.
    output = model.generate(
        input_ids,
        max_new_tokens=600,
        temperature=0.1,
    ).squeeze()
    # Keep only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(output[input_ids.shape[1]:])


def _split_question(question):
    """Split a question into ``(user_content, ai_prefix, test_case)``.

    * Questions whose first word is ``Write``: the first line is the task and
      the remaining lines are kept as ``test_case``; the answer is primed with
      ``def <name>`` where the name is taken from the second word of the first
      remaining line (presumably an ``assert name(...)`` line -- TODO confirm).
    * Other questions: the text of the first ``\"\"\"`` docstring becomes the
      task and the code around it becomes the answer prefix.

    ``test_case`` is ``None`` for the second kind. Questions that do not have
    the expected shape fall back to an empty prefix instead of raising.
    """
    if question.split(" ")[0] == "Write":
        first_line, newline, remainder = question.partition("\n")
        test_case = (newline + remainder).split("\n")
        try:
            function_name = test_case[1].split(" ")[1].split("(")[0]
            ai_prefix = "def " + function_name
        except IndexError:
            ai_prefix = ""
        content = first_line.strip() + "\n\n" + "\n".join(test_case)
        return content, ai_prefix, test_case

    docstrings = re.findall(r'"""(.*?)"""', question, flags=re.DOTALL)
    if not docstrings:
        # Nothing to separate: send the question as-is.
        return question, "", None
    content = "".join(
        line.strip() + " " for line in docstrings[0].split("\n") if line.strip()
    )
    ai_prefix = re.sub(r'"""(.*?)"""', "", question, flags=re.DOTALL).strip()
    return content, ai_prefix, None


def get_result(model, tokenizer):
    """Answer every question with the assistant turn primed by a code prefix.

    Each returned example gains ``"prompt"`` and ``"raw_outputs"`` (prefix +
    completion); ``"question"`` has ``'''`` normalised to ``\"\"\"`` and
    ``Write ...`` questions additionally gain ``"test_case"``.
    """
    from tqdm import tqdm

    test = _load_questions()
    for example in tqdm(test):
        example["question"] = example["question"].replace("'''", '"""')
        content, ai_prefix, test_case = _split_question(example["question"])
        if test_case is not None:
            print("test_case: ", test_case)
            example["test_case"] = test_case

        text = _chat_prompt(tokenizer, content) + ai_prefix
        example["prompt"] = text
        print("text: ", [text])

        output_str = ai_prefix + _generate(model, tokenizer, text)
        print("output_str:\n", output_str, "\n-----------------------------------------------------------------")
        example["raw_outputs"] = output_str
    return test


def get_result_1(model, tokenizer):
    """Answer every question by sending it verbatim as the user turn.

    Each returned example gains ``"prompt"`` and ``"raw_outputs"``.
    """
    from tqdm import tqdm

    test = _load_questions()
    for example in tqdm(test):
        text = _chat_prompt(tokenizer, example["question"])
        example["prompt"] = text
        print("text: ", [text])

        output_str = _generate(model, tokenizer, text)
        print("output_str:\n", output_str, "\n-----------------------------------------------------------------")
        example["raw_outputs"] = output_str
    return test


def merge_answers(runs):
    """Merge several scored runs into one answer per question.

    Args:
        runs: iterable of ``(score, items)`` pairs as produced by
            ``exec_code`` (score first so runs can be ordered).

    Returns:
        ``(merged, solved)``. Runs are visited in ascending score order (ties
        keep their input order); for each question the first item whose
        ``"result"`` is ``"True"`` is taken, otherwise the item of the first
        run. ``solved`` counts the questions with a passing answer.
    """
    ordered = [items for _, items in sorted(runs, key=lambda run: run[0])]
    if not ordered:
        return [], 0

    merged = []
    solved = 0
    for i, fallback in enumerate(ordered[0]):
        passing = next(
            (
                items[i]
                for items in ordered
                if i < len(items) and items[i].get("result") == "True"
            ),
            None,
        )
        if passing is None:
            merged.append(fallback)
        else:
            solved += 1
            merged.append(passing)
    return merged, solved


def main():
    """Evaluate every checkpoint with both strategies and write the merge."""
    # Heavy third-party imports are deferred so the helpers above stay
    # importable on a machine without the GPU stack.
    import torch  # noqa: F401  (not referenced directly; fails fast if missing)
    from transformers import (  # noqa: F401
        AutoModel,
        AutoModelForCausalLM,
        AutoTokenizer,
        StoppingCriteria,
    )

    # A list of (score, items) rather than a dict keyed by score: two runs
    # with the same accuracy must not overwrite each other.
    runs = []
    for model_path in MODEL_PATHS:
        print("model_path: ", model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map=device,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        for strategy in (get_result, get_result_1):
            test, score = exec_code(strategy(model, tokenizer), model_path)
            runs.append((score, test))

    runs.sort(key=lambda run: run[0])
    print("answers: ", [items for _, items in runs])

    merged, solved = merge_answers(runs)
    ratio = solved / len(merged) if merged else 0.0
    print(ratio, ratio)

    # ensure_ascii=False can emit non-ASCII text, so pin the file encoding.
    with open("wjf_jiuge.jsonl", "w", encoding="utf-8") as f:
        for item in merged:
            item.pop("result", None)
            f.write(json.dumps(item, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    main()