This commit is contained in:
hiyouga 2024-03-13 12:42:03 +08:00
parent b9f87cdc11
commit 0b4a5bf509
1 changed files with 6 additions and 5 deletions

View File

@ -52,8 +52,6 @@ class Runner:
get = lambda name: data[self.manager.get_elem_by_name(name)] get = lambda name: data[self.manager.get_elem_by_name(name)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path") lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset") dataset = get("train.dataset") if do_train else get("eval.dataset")
stage = TRAINING_STAGES[get("train.training_stage")]
reward_model = get("train.reward_model")
if self.running: if self.running:
return ALERTS["err_conflict"][lang] return ALERTS["err_conflict"][lang]
@ -67,15 +65,18 @@ class Runner:
if len(dataset) == 0: if len(dataset) == 0:
return ALERTS["err_no_dataset"][lang] return ALERTS["err_no_dataset"][lang]
if stage == "ppo" and not reward_model:
return ALERTS["err_no_reward_model"][lang]
if not from_preview and self.demo_mode: if not from_preview and self.demo_mode:
return ALERTS["err_demo"][lang] return ALERTS["err_demo"][lang]
if not from_preview and get_device_count() > 1: if not from_preview and get_device_count() > 1:
return ALERTS["err_device_count"][lang] return ALERTS["err_device_count"][lang]
if do_train:
stage = TRAINING_STAGES[get("train.training_stage")]
reward_model = get("train.reward_model")
if stage == "ppo" and not reward_model:
return ALERTS["err_no_reward_model"][lang]
if not from_preview and not is_torch_cuda_available(): if not from_preview and not is_torch_cuda_available():
gr.Warning(ALERTS["warn_no_cuda"][lang]) gr.Warning(ALERTS["warn_no_cuda"][lang])