From 0b4a5bf509a6fbf18337a29a6a498f33d0cbca76 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 13 Mar 2024 12:42:03 +0800 Subject: [PATCH] fix #2817 --- src/llmtuner/webui/runner.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 1d5396a9..0cf50f6a 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -52,8 +52,6 @@ class Runner: get = lambda name: data[self.manager.get_elem_by_name(name)] lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path") dataset = get("train.dataset") if do_train else get("eval.dataset") - stage = TRAINING_STAGES[get("train.training_stage")] - reward_model = get("train.reward_model") if self.running: return ALERTS["err_conflict"][lang] @@ -67,15 +65,18 @@ class Runner: if len(dataset) == 0: return ALERTS["err_no_dataset"][lang] - if stage == "ppo" and not reward_model: - return ALERTS["err_no_reward_model"][lang] - if not from_preview and self.demo_mode: return ALERTS["err_demo"][lang] if not from_preview and get_device_count() > 1: return ALERTS["err_device_count"][lang] + if do_train: + stage = TRAINING_STAGES[get("train.training_stage")] + reward_model = get("train.reward_model") + if stage == "ppo" and not reward_model: + return ALERTS["err_no_reward_model"][lang] + if not from_preview and not is_torch_cuda_available(): gr.Warning(ALERTS["warn_no_cuda"][lang])