From a38d57ddd7fcbd2eb373e79f7236f8d2411c52d5 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Mon, 5 Jun 2023 17:58:29 +0800 Subject: [PATCH] fix bug in web demo --- src/utils/peft_trainer.py | 2 ++ src/web_demo.py | 9 +++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/utils/peft_trainer.py b/src/utils/peft_trainer.py index f951fd8c..160a1425 100644 --- a/src/utils/peft_trainer.py +++ b/src/utils/peft_trainer.py @@ -49,6 +49,8 @@ class LogCallback(TrainerCallback): r""" Event called after logging the last logs. """ + if "loss" not in state.log_history[-1]: + return cur_time = time.time() cur_steps = state.log_history[-1].get("step") elapsed_time = cur_time - self.start_time diff --git a/src/web_demo.py b/src/web_demo.py index 77dd76d7..96bfbcd1 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -12,7 +12,7 @@ from transformers import TextIteratorStreamer from transformers.utils.versions import require_version -require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems +require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0") model_args, data_args, finetuning_args = prepare_infer_args() model, tokenizer = load_pretrained(model_args, finetuning_args) @@ -93,6 +93,7 @@ def predict(input, chatbot, max_length, top_p, temperature, history): input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"] input_ids = input_ids.to(model.device) gen_kwargs = { + "input_ids": input_ids, "do_sample": True, "top_p": top_p, "temperature": temperature, @@ -107,9 +108,9 @@ def predict(input, chatbot, max_length, top_p, temperature, history): response = "" for new_text in streamer: response += new_text - history = history + [(input, response)] + new_history = history + [(input, response)] chatbot[-1] = (parse_text(input), parse_text(response)) - yield chatbot, history + yield chatbot, new_history def reset_user_input(): @@ -145,4 +146,4 @@ with gr.Blocks() as demo: emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) -demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=True) +demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=True)