fix bug at inference

hiyouga 2023-05-31 18:11:53 +08:00
parent 38ca429228
commit fd709eacff
2 changed files with 8 additions and 8 deletions
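
Both files receive the same two-line change to the inference path: the tokenizer call is indexed with ["input_ids"] so that only the token IDs are kept, and model.generate() is called with an explicit input_ids= argument instead of unpacking the whole tokenizer output via **inputs. A plausible reading of the bug is that the tokenizer's BatchEncoding carries extra keys besides input_ids (attention_mask, and for some tokenizers token_type_ids), and unpacking all of them into generate() can fail for models that do not accept those keyword arguments. A runnable sketch of the corrected pattern follows the diffs below.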

View File

@@ -29,8 +29,8 @@ def main():
         return prompt

     def predict(query, history: list):
-        inputs = tokenizer([format_example(query)], return_tensors="pt")
-        inputs = inputs.to(model.device)
+        input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"]
+        input_ids = input_ids.to(model.device)
         gen_kwargs = {
             "do_sample": True,
             "top_p": 0.9,
@@ -41,8 +41,8 @@ def main():
             "repetition_penalty": 1.5
         }
         with torch.no_grad():
-            generation_output = model.generate(**inputs, **gen_kwargs)
-        outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+            generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
+        outputs = generation_output.tolist()[0][len(input_ids[0]):]
         response = tokenizer.decode(outputs, skip_special_tokens=True)
         history = history + [(query, response)]
         return response, history
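
The second file applies the identical fix to the web demo's predict(); only the sampling settings differ there (top_p, temperature, and max_length come from the UI controls in the function signature, and repetition_penalty is 1.0).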

View File

@@ -83,8 +83,8 @@ def format_example(query):

 def predict(input, chatbot, max_length, top_p, temperature, history):
     chatbot.append((parse_text(input), ""))
-    inputs = tokenizer([format_example(input)], return_tensors="pt")
-    inputs = inputs.to(model.device)
+    input_ids = tokenizer([format_example(input)], return_tensors="pt")["input_ids"]
+    input_ids = input_ids.to(model.device)
     gen_kwargs = {
         "do_sample": True,
         "top_p": top_p,
@@ -94,8 +94,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
         "repetition_penalty": 1.0
     }
     with torch.no_grad():
-        generation_output = model.generate(**inputs, **gen_kwargs)
-    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+        generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
+    outputs = generation_output.tolist()[0][len(input_ids[0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)
     history = history + [(input, response)]
     chatbot[-1] = (parse_text(input), parse_text(response))
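
A minimal, self-contained sketch of the corrected pattern, assuming a standard Hugging Face causal LM. The checkpoint name ("gpt2") and the format_example prompt template here are placeholders for illustration, not the repository's own.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

def format_example(query):  # hypothetical stand-in for the repo's template
    return "Human: {}\nAssistant: ".format(query)

def predict(query, history: list):
    # Keep only the token IDs; unpacking the full BatchEncoding into
    # generate() would also forward keys like attention_mask (and, for some
    # tokenizers, token_type_ids), which is the likely source of the bug.
    input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model.device)
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.9,
        "repetition_penalty": 1.5,
        "max_new_tokens": 128,
    }
    with torch.no_grad():
        generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
    # Slice off the prompt tokens so only the newly generated text is decoded.
    outputs = generation_output.tolist()[0][len(input_ids[0]):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)
    history = history + [(query, response)]
    return response, history

print(predict("Hello!", [])[0])

The prompt-stripping slice works because generate() returns the prompt followed by the continuation, so dropping the first len(input_ids[0]) tokens leaves only the model's reply.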