fix bug in web demo

This commit is contained in:
hiyouga 2023-06-05 17:58:29 +08:00
parent 56eb99106a
commit a38d57ddd7
2 changed files with 7 additions and 4 deletions

View File

@ -49,6 +49,8 @@ class LogCallback(TrainerCallback):
r"""
Event called after logging the last logs.
"""
if "loss" not in state.log_history[-1]:
return
cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time

View File

@ -12,7 +12,7 @@ from transformers import TextIteratorStreamer
from transformers.utils.versions import require_version
require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
model_args, data_args, finetuning_args = prepare_infer_args()
model, tokenizer = load_pretrained(model_args, finetuning_args)
@ -93,6 +93,7 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)
gen_kwargs = {
"input_ids": input_ids,
"do_sample": True,
"top_p": top_p,
"temperature": temperature,
@ -107,9 +108,9 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
response = ""
for new_text in streamer:
response += new_text
history = history + [(input, response)]
new_history = history + [(input, response)]
chatbot[-1] = (parse_text(input), parse_text(response))
yield chatbot, history
yield chatbot, new_history
def reset_user_input():
@ -145,4 +146,4 @@ with gr.Blocks() as demo:
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=True)
demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=True)