hiyouga 2024-03-09 02:01:26 +08:00
parent 516d0ddc66
commit e8dd38b7fd
7 changed files with 28 additions and 20 deletions

View File

@@ -60,7 +60,7 @@ def load_model(
"""
init_kwargs = _get_init_kwargs(model_args)
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
- patch_config(config, tokenizer, model_args,finetuning_args, init_kwargs, is_trainable)
+ patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
model = None
if is_trainable and model_args.use_unsloth:

View File

@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
- from ..hparams import ModelArguments,FinetuningArguments
+ from ..hparams import ModelArguments
logger = get_logger(__name__)
@@ -265,7 +265,6 @@ def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
- finetuning_args: "FinetuningArguments",
init_kwargs: Dict[str, Any],
is_trainable: bool,
) -> None:
@@ -290,7 +289,6 @@ def patch_config(
if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if "device_map" not in init_kwargs: # quant models cannot use auto device map
- if finetuning_args.stage not in ["ppo"]: #ppo stage should not set device map
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
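Note (not part of this commit): with the PPO-specific guard removed, patch_config now always pins the model to the current device whenever no device_map was supplied. The real helper lives in extras.misc; the following is only a minimal sketch of what get_current_device plausibly returns.

import os
import torch

def get_current_device() -> str:
    # Assumption: LOCAL_RANK is the env var set by torchrun/DeepSpeed launchers;
    # default to device 0 (or CPU) in single-process runs.
    if torch.cuda.is_available():
        return "cuda:{}".format(int(os.environ.get("LOCAL_RANK", "0")))
    return "cpu"

A string such as "cuda:0" is a valid value for the {"": device} device_map handed to from_pretrained.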

View File

@@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
- from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
+ from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -49,6 +49,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model_args = model_args
self.finetuning_args = finetuning_args
self.reward_model = reward_model
+ self.current_device = get_current_device() # patch for deepspeed training
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
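Note (not part of this commit): trl's PPOTrainer normally takes self.current_device from its Accelerator, and the comment on the added line indicates that value is unreliable under DeepSpeed, hence the override with get_current_device(). A rough, self-contained sketch of the intent, with placeholder names and values:

import torch

# Tensors built during a PPO step (e.g. reward scores) are placed on
# current_device so they land on the same device as the policy model.
current_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
scores = [1.25, -0.5]  # placeholder rewards from a reward model or server
rewards = [torch.tensor(s, device=current_device) for s in scores]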

View File

@@ -22,7 +22,7 @@ def create_top() -> Dict[str, "Component"]:
with gr.Row():
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
- adapter_path = gr.Dropdown(multiselect=True, scale=5, allow_custom_value=True)
+ adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
refresh_btn = gr.Button(scale=1)
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:

View File

@@ -27,7 +27,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
- training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
input_elems.update({training_stage, dataset_dir, dataset})
@@ -160,10 +159,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
- reward_model = gr.Dropdown(scale=2, allow_custom_value=True)
- refresh_btn = gr.Button(scale=1)
+ reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
- refresh_btn.click(
+ training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
[reward_model],
@@ -171,9 +169,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
input_elems.update({dpo_beta, dpo_ftx, reward_model})
- elem_dict.update(
- dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
- )
+ elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
with gr.Accordion(label="GaLore config", open=False) as galore_tab:
with gr.Row():
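Note (not part of this commit): the reward-model dropdown no longer gets its own refresh button; instead the adapter list is refreshed by chaining a second callback onto the existing training_stage.change event via .then(). A minimal sketch of that Gradio pattern, with placeholder components and callbacks:

import gradio as gr

def list_datasets(stage):
    # placeholder: return dataset choices for the selected training stage
    return gr.update(choices=["alpaca_en", "hh_rlhf"], value=[])

def list_adapters(stage):
    # placeholder: return adapters that can serve as reward models
    return gr.update(choices=["reward-lora-a", "reward-lora-b"], value=[])

with gr.Blocks() as demo:
    stage = gr.Dropdown(choices=["sft", "ppo"], value="sft")
    dataset = gr.Dropdown(multiselect=True)
    reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)

    # Update the dataset list first, then refresh the reward-model choices.
    stage.change(list_datasets, [stage], [dataset], queue=False).then(
        list_adapters, [stage], [reward_model], queue=False
    )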

View File

@@ -746,15 +746,15 @@ LOCALES = {
"reward_model": {
"en": {
"label": "Reward model",
- "info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)",
+ "info": "Adapter of the reward model for PPO training.",
},
"ru": {
"label": "Модель вознаграждения",
- "info": "Адаптер модели вознаграждения для обучения PPO. (Необходимо обновить адаптеры)",
+ "info": "Адаптер модели вознаграждения для обучения PPO.",
},
"zh": {
"label": "奖励模型",
- "info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)",
+ "info": "PPO 训练中奖励模型的适配器路径。",
},
},
"galore_tab": {
@@ -1199,7 +1199,12 @@ ALERTS = {
"err_no_adapter": {
"en": "Please select an adapter.",
"ru": "Пожалуйста, выберите адаптер.",
- "zh": "请选择一个适配器。",
+ "zh": "请选择适配器。",
},
+ "err_no_reward_model": {
+ "en": "Please select a reward model.",
+ "ru": "Пожалуйста, выберите модель вознаграждения.",
+ "zh": "请选择奖励模型。",
+ },
"err_no_export_dir": {
"en": "Please provide export dir.",

View File

@@ -52,6 +52,8 @@ class Runner:
get = lambda name: data[self.manager.get_elem_by_name(name)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
+ stage = TRAINING_STAGES[get("train.training_stage")]
+ reward_model = get("train.reward_model")
if self.running:
return ALERTS["err_conflict"][lang]
@@ -65,6 +67,9 @@ class Runner:
if len(dataset) == 0:
return ALERTS["err_no_dataset"][lang]
+ if stage == "ppo" and not reward_model:
+ return ALERTS["err_no_reward_model"][lang]
if not from_preview and self.demo_mode:
return ALERTS["err_demo"][lang]
@@ -163,8 +168,11 @@ class Runner:
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
if args["stage"] == "ppo":
- args["reward_model"] = get_save_dir(
- get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
+ args["reward_model"] = ",".join(
+ [
+ get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+ for adapter in get("train.reward_model")
+ ]
)
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
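Note (not part of this commit): because train.reward_model is now a multiselect, the runner resolves each selected adapter to its checkpoint directory and joins them into one comma-separated string for the reward_model argument. Illustrative only, with stand-in names and a made-up save-directory layout:

# stand-ins for get("train.reward_model") and get_save_dir(...)
selected_adapters = ["reward-a", "reward-b"]
reward_model_arg = ",".join(
    "saves/Llama-2-7B/lora/{}".format(name) for name in selected_adapters
)
# -> "saves/Llama-2-7B/lora/reward-a,saves/Llama-2-7B/lora/reward-b"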