From fd709eacff09d52636541cba009a1ded6aac22dc Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 31 May 2023 18:11:53 +0800
Subject: [PATCH] fix bug at inference

---
 src/cli_demo.py | 8 ++++----
 src/web_demo.py | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/cli_demo.py b/src/cli_demo.py
index fd24e998..441e6aba 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -29,8 +29,8 @@ def main():
         return prompt
 
     def predict(query, history: list):
-        inputs = tokenizer([format_example(query)], return_tensors="pt")
-        inputs = inputs.to(model.device)
+        input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"]
+        input_ids = input_ids.to(model.device)
         gen_kwargs = {
             "do_sample": True,
             "top_p": 0.9,
@@ -41,8 +41,8 @@ def main():
             "repetition_penalty": 1.5
         }
         with torch.no_grad():
-            generation_output = model.generate(**inputs, **gen_kwargs)
-            outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+            generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
+            outputs = generation_output.tolist()[0][len(input_ids[0]):]
         response = tokenizer.decode(outputs, skip_special_tokens=True)
         history = history + [(query, response)]
         return response, history
diff --git a/src/web_demo.py b/src/web_demo.py
index 5129ea85..ca766592 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -83,8 +83,8 @@ def format_example(query):
 
 def predict(input, chatbot, max_length, top_p, temperature, history):
     chatbot.append((parse_text(input), ""))
-    inputs = tokenizer([format_example(input)], return_tensors="pt")
-    inputs = inputs.to(model.device)
+    input_ids = tokenizer([format_example(input)], return_tensors="pt")["input_ids"]
+    input_ids = input_ids.to(model.device)
     gen_kwargs = {
         "do_sample": True,
         "top_p": top_p,
@@ -94,8 +94,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
         "repetition_penalty": 1.0
     }
     with torch.no_grad():
-        generation_output = model.generate(**inputs, **gen_kwargs)
-        outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+        generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
+        outputs = generation_output.tolist()[0][len(input_ids[0]):]
         response = tokenizer.decode(outputs, skip_special_tokens=True)
         history = history + [(input, response)]
         chatbot[-1] = (parse_text(input), parse_text(response))
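
Note: the patch itself only says "fix bug at inference", but the change in both files is the same: the tokenizer call returns a BatchEncoding that can hold extra keys besides input_ids (attention_mask and, for some tokenizers, token_type_ids), and unpacking the whole mapping with model.generate(**inputs, ...) forwards all of them; newer transformers versions reject model_kwargs the model's forward does not accept, which fails at generation time. Passing input_ids explicitly sidesteps that. Below is a minimal standalone sketch of the buggy and fixed call patterns; the checkpoint name and max_new_tokens value are placeholders for illustration, not taken from this patch.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical checkpoint, for illustration only.
    name = "some/causal-lm"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)

    # BatchEncoding: may contain input_ids, attention_mask, token_type_ids, ...
    batch = tokenizer(["Hello"], return_tensors="pt")

    # Buggy pattern: every key in `batch` is forwarded to generate();
    # a model whose forward() lacks token_type_ids can raise an error here.
    # out = model.generate(**batch, max_new_tokens=8)

    # Fixed pattern, as in this patch: pass only input_ids.
    input_ids = batch["input_ids"].to(model.device)
    with torch.no_grad():
        out = model.generate(input_ids=input_ids, max_new_tokens=8)

    # Strip the prompt tokens before decoding, mirroring the patch.
    response = tokenizer.decode(out[0][len(input_ids[0]):], skip_special_tokens=True)

An alternative fix would keep the full BatchEncoding and drop only the offending key (e.g. batch.pop("token_type_ids", None)), preserving attention_mask; the patch instead takes the simpler route of passing input_ids alone.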