From 3ae735ffe8acbe1df05324daa9e18ec37f33d594 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Mon, 8 Jan 2024 21:42:25 +0800 Subject: [PATCH] fix #2125 --- requirements.txt | 2 +- src/llmtuner/webui/runner.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 2af2fd6c..ce3c92a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ peft>=0.7.0 trl>=0.7.6 gradio>=3.38.0,<4.0.0 scipy +einops sentencepiece protobuf tiktoken @@ -17,4 +18,3 @@ pydantic fastapi sse-starlette matplotlib -einops diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 483d709d..a6863c36 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -118,13 +118,13 @@ class Runner: logging_steps=get("train.logging_steps"), save_steps=get("train.save_steps"), warmup_steps=get("train.warmup_steps"), - neftune_noise_alpha=get("train.neftune_alpha"), + neftune_noise_alpha=get("train.neftune_alpha") or None, train_on_prompt=get("train.train_on_prompt"), upcast_layernorm=get("train.upcast_layernorm"), lora_rank=get("train.lora_rank"), lora_dropout=get("train.lora_dropout"), lora_target=get("train.lora_target") or get_module(get("top.model_name")), - additional_target=get("train.additional_target") if get("train.additional_target") else None, + additional_target=get("train.additional_target") or None, create_new_adapter=get("train.create_new_adapter"), output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")) ) @@ -164,7 +164,6 @@ class Runner: args = dict( stage="sft", - do_eval=True, model_name_or_path=get("top.model_path"), adapter_name_or_path=adapter_name_or_path, cache_dir=user_config.get("cache_dir", None), @@ -187,8 +186,9 @@ class Runner: ) if get("eval.predict"): - args.pop("do_eval", None) args["do_predict"] = True + else: + args["do_eval"] = True return args