forked from jiuyuan/CPM-9G-8B
Update inference.py
This commit is contained in:
parent
cd1bdcf117
commit
58a7967a98
548
inference.py
548
inference.py
|
@ -1,270 +1,278 @@
|
||||||
import json, torch, re, sys, subprocess
|
import json, torch, re, sys, subprocess
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, StoppingCriteria
|
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, StoppingCriteria
|
||||||
device = "cuda" # the device to load the model onto
|
device = "cuda" # the device to load the model onto
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def exec_code(test):
|
def exec_code(test):
|
||||||
with open("test_case.json", "r") as f:
|
with open("test_case.json", "r") as f:
|
||||||
test_cases = json.load(f)
|
test_cases = json.load(f)
|
||||||
|
|
||||||
right_num = 0
|
right_num = 0
|
||||||
all_num = 0
|
all_num = 0
|
||||||
package = "import os, sys, math, re, json, random\n"
|
package = "import os, sys, math, re, json, random\n"
|
||||||
|
|
||||||
for item, test_case in zip(test, test_cases):
|
for item, test_case in zip(test, test_cases):
|
||||||
|
|
||||||
|
|
||||||
if "```python\n" in item["raw_outputs"]:
|
if "```python\n" in item["raw_outputs"]:
|
||||||
matches = re.findall('```python(.*?)```', item["raw_outputs"], re.DOTALL)
|
matches = re.findall('```python(.*?)```', item["raw_outputs"], re.DOTALL)
|
||||||
if len(matches) == 1:
|
if len(matches) == 1:
|
||||||
item["raw_outputs"] = matches[0]
|
item["raw_outputs"] = matches[0]
|
||||||
else:
|
else:
|
||||||
matches = re.findall('```python(.*?)assert', item["raw_outputs"], re.DOTALL)
|
matches = re.findall('```python(.*?)assert', item["raw_outputs"], re.DOTALL)
|
||||||
if len(matches) == 1:
|
if len(matches) == 1:
|
||||||
item["raw_outputs"] = matches[0]
|
item["raw_outputs"] = matches[0]
|
||||||
else:
|
else:
|
||||||
item["raw_outputs"] = item["raw_outputs"][item["raw_outputs"].index("python\n") + len("python\n"):]
|
item["raw_outputs"] = item["raw_outputs"][item["raw_outputs"].index("python\n") + len("python\n"):]
|
||||||
print(item)
|
print(item)
|
||||||
#break
|
#break
|
||||||
|
|
||||||
|
|
||||||
code = item["raw_outputs"].replace("<|im_end|>", "").replace("</s>", "").replace("```", "").strip().rstrip("\n")
|
code = item["raw_outputs"].replace("<|im_end|>", "").replace("</s>", "").replace("```", "").strip().rstrip("\n")
|
||||||
|
|
||||||
raw_code = code
|
raw_code = code
|
||||||
|
|
||||||
codes = raw_code.split("\n")
|
codes = raw_code.split("\n")
|
||||||
last_line = 0
|
last_line = 0
|
||||||
for index, line in enumerate(codes):
|
for index, line in enumerate(codes):
|
||||||
if " return" in line:
|
if " return" in line:
|
||||||
last_line = index
|
last_line = index
|
||||||
|
|
||||||
code = "\n".join(codes[:last_line+1])
|
code = "\n".join(codes[:last_line+1])
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
if raw_code != code:
|
if raw_code != code:
|
||||||
print("\n--------------------------------------------------------\n", [raw_code], "\n--------------------------------------------------------\n")
|
print("\n--------------------------------------------------------\n", [raw_code], "\n--------------------------------------------------------\n")
|
||||||
print("clean:\n", [code], "\n+++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n\n\n")
|
print("clean:\n", [code], "\n+++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n\n\n")
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
with open('code_.py', 'w') as fout:
|
with open('code_.py', 'w') as fout:
|
||||||
fout.write(package + code + "\n" + "\n".join(test_case["test_case"]))
|
fout.write(package + code + "\n" + "\n".join(test_case["test_case"]))
|
||||||
|
|
||||||
batcmd = 'timeout 3 ' + sys.executable + ' code_.py'
|
batcmd = 'timeout 3 ' + sys.executable + ' code_.py'
|
||||||
try:
|
try:
|
||||||
shell_output = subprocess.check_output(batcmd, shell=True).decode('utf8')
|
shell_output = subprocess.check_output(batcmd, shell=True).decode('utf8')
|
||||||
right_num += 1
|
right_num += 1
|
||||||
item["result"] = "True"
|
item["result"] = "True"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("++++++++++++++++++++++++++++++++++++++++++++++++++++\n", raw_code, "\n-----------------------------------------\n\n\n", package + code + "\n--------------------------\n" + "\n".join(test_case["test_case"]))
|
print("++++++++++++++++++++++++++++++++++++++++++++++++++++\n", raw_code, "\n-----------------------------------------\n\n\n", package + code + "\n--------------------------\n" + "\n".join(test_case["test_case"]))
|
||||||
print("--------------------------------------------------------\n\n\nitem:", item)
|
print("--------------------------------------------------------\n\n\nitem:", item)
|
||||||
print("e: ", e, "\n================================================\n")#, e, )
|
print("e: ", e, "\n================================================\n")#, e, )
|
||||||
item["result"] = "False"
|
item["result"] = "False"
|
||||||
|
|
||||||
all_num += 1
|
all_num += 1
|
||||||
|
|
||||||
item["raw_outputs"] = [code]
|
item["raw_outputs"] = [code]
|
||||||
|
|
||||||
print(len(test), right_num, all_num, right_num / all_num)
|
print(len(test), right_num, all_num, right_num / all_num)
|
||||||
with open(f'wjf_{model_path.replace("/", "-")}{right_num / all_num}.json', "w") as f:
|
with open(f'wjf_{model_path.replace("/", "-")}{right_num / all_num}.json', "w") as f:
|
||||||
json.dump(test, f, indent=4)
|
json.dump(test, f, indent=4)
|
||||||
|
|
||||||
return test, right_num / all_num
|
return test, right_num / all_num
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_result(model, tokenizer):
|
def get_result(model, tokenizer):
|
||||||
|
|
||||||
test = []
|
test = []
|
||||||
with open("/mnt/disk2/home/wujianfeng/com/code/code_round4.jsonl", "r") as f:
|
with open("/mnt/disk2/home/wujianfeng/com/code/code_round4.jsonl", "r") as f:
|
||||||
#test = json.load(f)
|
#test = json.load(f)
|
||||||
for line in f:
|
for line in f:
|
||||||
test.append(json.loads(line))
|
test.append(json.loads(line))
|
||||||
|
|
||||||
all_score = 0
|
all_score = 0
|
||||||
all_num = 0
|
all_num = 0
|
||||||
test_num = 1000
|
test_num = 1000
|
||||||
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
for example in tqdm(test[:]):
|
for example in tqdm(test[:]):
|
||||||
#print(example["question"])
|
#print(example["question"])
|
||||||
example["question"] = example["question"].replace("'''", '"""')
|
example["question"] = example["question"].replace("'''", '"""')
|
||||||
|
|
||||||
ai_prefix = ""
|
ai_prefix = ""
|
||||||
if example["question"].split(" ")[0] == "Write":
|
if example["question"].split(" ")[0] == "Write":
|
||||||
question = example["question"][:example["question"].index("\n")].strip().rstrip()
|
question = example["question"][:example["question"].index("\n")].strip().rstrip()
|
||||||
test_case = example["question"][example["question"].index("\n"):].split("\n")
|
test_case = example["question"][example["question"].index("\n"):].split("\n")
|
||||||
print("test_case: ", test_case)
|
print("test_case: ", test_case)
|
||||||
|
|
||||||
function_name = test_case[1].split(" ")[1].split("(")[0]
|
function_name = test_case[1].split(" ")[1].split("(")[0]
|
||||||
ai_prefix = "def " + function_name
|
ai_prefix = "def " + function_name
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": question + "\n\n" + ("\n".join(test_case))}
|
{"role": "user", "content": question + "\n\n" + ("\n".join(test_case))}
|
||||||
]
|
]
|
||||||
text = tokenizer.apply_chat_template(
|
text = tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
add_generation_prompt=True
|
add_generation_prompt=True
|
||||||
)
|
)
|
||||||
text += ai_prefix
|
text += ai_prefix
|
||||||
example["test_case"] = test_case
|
example["test_case"] = test_case
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
tmp = re.findall(r'"""(.*?)"""', example["question"], flags=re.DOTALL)[0].split("\n")
|
tmp = re.findall(r'"""(.*?)"""', example["question"], flags=re.DOTALL)[0].split("\n")
|
||||||
question = ""
|
question = ""
|
||||||
for line in tmp:
|
for line in tmp:
|
||||||
line = line.strip().rstrip()
|
line = line.strip().rstrip()
|
||||||
if len(line) == 0:
|
if len(line) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
#if "xample" in line and len(line) < 20:
|
#if "xample" in line and len(line) < 20:
|
||||||
# break
|
# break
|
||||||
|
|
||||||
question += line + " "
|
question += line + " "
|
||||||
|
|
||||||
code = re.sub(r'"""(.*?)"""', '', example["question"], flags=re.DOTALL).strip().rstrip()
|
code = re.sub(r'"""(.*?)"""', '', example["question"], flags=re.DOTALL).strip().rstrip()
|
||||||
ai_prefix = code
|
ai_prefix = code
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": question}
|
{"role": "user", "content": question}
|
||||||
]
|
]
|
||||||
text = tokenizer.apply_chat_template(
|
text = tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
add_generation_prompt=True
|
add_generation_prompt=True
|
||||||
)
|
)
|
||||||
|
|
||||||
text += ai_prefix
|
text += ai_prefix
|
||||||
|
|
||||||
|
|
||||||
example["prompt"] = text
|
example["prompt"] = text
|
||||||
print("text: " , [text])
|
print("text: " , [text])
|
||||||
input_ids = tokenizer([text], return_tensors="pt").to(device).input_ids
|
input_ids = tokenizer([text], return_tensors="pt").to(device).input_ids
|
||||||
|
|
||||||
|
|
||||||
output = model.generate(input_ids,
|
output = model.generate(input_ids,
|
||||||
#top_p=1.0,
|
#top_p=1.0,
|
||||||
max_new_tokens=600,
|
max_new_tokens=600,
|
||||||
#repetition_penalty=1.1 + t*0.01,
|
#repetition_penalty=1.1 + t*0.01,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
#no_repeat_ngram_size = 5,
|
#no_repeat_ngram_size = 5,
|
||||||
).squeeze()
|
).squeeze()
|
||||||
|
|
||||||
output_str = tokenizer.decode(output[input_ids.shape[1]:])
|
output_str = tokenizer.decode(output[input_ids.shape[1]:])
|
||||||
output_str = ai_prefix + output_str
|
output_str = ai_prefix + output_str
|
||||||
print("output_str:\n", output_str, "\n-----------------------------------------------------------------")
|
print("output_str:\n", output_str, "\n-----------------------------------------------------------------")
|
||||||
|
|
||||||
example["raw_outputs"] = output_str#re.findall(r'```python(.*?)```', output_str)
|
example["raw_outputs"] = output_str#re.findall(r'```python(.*?)```', output_str)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_result_1(model, tokenizer):
|
def get_result_1(model, tokenizer):
|
||||||
|
|
||||||
test = []
|
test = []
|
||||||
with open("/mnt/disk2/home/wujianfeng/com/code/code_round4.jsonl", "r") as f:
|
with open("/mnt/disk2/home/wujianfeng/com/code/code_round4.jsonl", "r") as f:
|
||||||
#test = json.load(f)
|
#test = json.load(f)
|
||||||
for line in f:
|
for line in f:
|
||||||
test.append(json.loads(line))
|
test.append(json.loads(line))
|
||||||
|
|
||||||
all_score = 0
|
all_score = 0
|
||||||
all_num = 0
|
all_num = 0
|
||||||
test_num = 1000
|
test_num = 1000
|
||||||
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
for example in tqdm(test[:]):
|
for example in tqdm(test[:]):
|
||||||
#print(example["question"])
|
#print(example["question"])
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": example["question"]}
|
{"role": "user", "content": example["question"]}
|
||||||
]
|
]
|
||||||
text = tokenizer.apply_chat_template(
|
text = tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
add_generation_prompt=True
|
add_generation_prompt=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
example["prompt"] = text
|
example["prompt"] = text
|
||||||
print("text: " , [text])
|
print("text: " , [text])
|
||||||
input_ids = tokenizer([text], return_tensors="pt").to(device).input_ids
|
input_ids = tokenizer([text], return_tensors="pt").to(device).input_ids
|
||||||
|
|
||||||
|
|
||||||
output = model.generate(input_ids,
|
output = model.generate(input_ids,
|
||||||
#top_p=1.0,
|
#top_p=1.0,
|
||||||
max_new_tokens=600,
|
max_new_tokens=600,
|
||||||
#repetition_penalty=1.1 + t*0.01,
|
#repetition_penalty=1.1 + t*0.01,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
#no_repeat_ngram_size = 5,
|
#no_repeat_ngram_size = 5,
|
||||||
).squeeze()
|
).squeeze()
|
||||||
output_str = tokenizer.decode(output[input_ids.shape[1]:])
|
output_str = tokenizer.decode(output[input_ids.shape[1]:])
|
||||||
print("output_str:\n", output_str, "\n-----------------------------------------------------------------")
|
print("output_str:\n", output_str, "\n-----------------------------------------------------------------")
|
||||||
|
|
||||||
example["raw_outputs"] = output_str#re.findall(r'```python(.*?)```', output_str)
|
example["raw_outputs"] = output_str#re.findall(r'```python(.*?)```', output_str)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
answers = {}
|
answers = {}
|
||||||
for model_path in [
|
|
||||||
"/mnt/disk2/home/wujianfeng/LLaMA-Factory/all_new_1/checkpoint-600",
|
for model_path in [
|
||||||
"/mnt/disk2/home/wujianfeng/LLaMA-Factory/all_new/checkpoint-600/",
|
"/mnt/disk2/home/wujianfeng/LLaMA-Factory/all/TACO/",
|
||||||
|
"/mnt/disk2/home/wujianfeng/LLaMA-Factory/all_new_2/CodeNet4Repair/",
|
||||||
]:
|
"/mnt/disk2/home/wujianfeng/LLaMA-Factory/all_new_1/CodeExercise-Python-27k/",
|
||||||
print("model_path: ", model_path)
|
]:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
print("model_path: ", model_path)
|
||||||
model_path,
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
torch_dtype="auto",
|
model_path,
|
||||||
device_map=device,
|
torch_dtype="auto",
|
||||||
trust_remote_code=True,
|
device_map=device,
|
||||||
)
|
trust_remote_code=True,
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
test = get_result(model, tokenizer)
|
|
||||||
test, score = exec_code(test)
|
test = get_result(model, tokenizer)
|
||||||
answers[score] = test
|
test, score = exec_code(test)
|
||||||
|
answers[score] = test
|
||||||
|
|
||||||
test = get_result_1(model, tokenizer)
|
|
||||||
test, score = exec_code(test)
|
test = get_result_1(model, tokenizer)
|
||||||
answers[score] = test
|
test, score = exec_code(test)
|
||||||
|
answers[score] = test
|
||||||
|
|
||||||
|
'''
|
||||||
answers = list(dict(sorted(answers.items())).values())
|
import os
|
||||||
print("answers: ", answers)
|
for path in os.listdir("./"):
|
||||||
right = 0
|
if "home-wujianfeng" in path:
|
||||||
jiuge_right = 0
|
with open(path, "r") as f:
|
||||||
merge = []
|
test = json.load(f)
|
||||||
for i in range(len(answers)):
|
answers[float(path.split(".")[-2].split("-")[-1])] = test
|
||||||
#for i in range(2):
|
'''
|
||||||
flag = 0
|
|
||||||
for answer in answers:
|
answers = list(dict(sorted(answers.items())).values())
|
||||||
if answer[i]["result"] == "True":
|
print("answers: ", answers)
|
||||||
right += 1
|
right = 0
|
||||||
jiuge_right += 1
|
jiuge_right = 0
|
||||||
flag = 1
|
merge = []
|
||||||
merge.append(answer[i])
|
for i in range(len(answers[0])):
|
||||||
break
|
#for i in range(2):
|
||||||
|
flag = 0
|
||||||
if flag == 0:
|
for answer in answers:
|
||||||
merge.append(answers[0][i])
|
if answer[i]["result"] == "True":
|
||||||
|
right += 1
|
||||||
|
jiuge_right += 1
|
||||||
|
flag = 1
|
||||||
print(right / len(answers), jiuge_right / len(answers))
|
merge.append(answer[i])
|
||||||
with open("wjf_jiuge.jsonl", "w") as f:
|
break
|
||||||
for item in merge:
|
|
||||||
item.pop("result")
|
if flag == 0:
|
||||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
merge.append(answers[0][i])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
print(right / len(answers[0]), jiuge_right / len(answers[0]))
|
||||||
|
with open("wjf_jiuge.jsonl", "w") as f:
|
||||||
|
for item in merge:
|
||||||
|
item.pop("result")
|
||||||
|
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue