parent 516d0ddc66
commit e8dd38b7fd
@@ -60,7 +60,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-    patch_config(config, tokenizer, model_args,finetuning_args, init_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)

     model = None
     if is_trainable and model_args.use_unsloth:
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import ModelArguments,FinetuningArguments
+    from ..hparams import ModelArguments


 logger = get_logger(__name__)
@@ -265,7 +265,6 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    finetuning_args: "FinetuningArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -290,8 +289,7 @@ def patch_config(
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
         if "device_map" not in init_kwargs: # quant models cannot use auto device map
-            if finetuning_args.stage not in ["ppo"]: #ppo stage should not set device map
-                init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
+            init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}


 def patch_model(
@@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits

 from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
 from ...extras.logging import get_logger
-from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
+from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
 from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm


@@ -49,6 +49,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.model_args = model_args
         self.finetuning_args = finetuning_args
         self.reward_model = reward_model
+        self.current_device = get_current_device() # patch for deepspeed training

         self.generation_config = GenerationConfig(
             pad_token_id=self.tokenizer.pad_token_id,
@@ -22,7 +22,7 @@ def create_top() -> Dict[str, "Component"]:

     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
-        adapter_path = gr.Dropdown(multiselect=True, scale=5, allow_custom_value=True)
+        adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
         refresh_btn = gr.Button(scale=1)

     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
@@ -27,7 +27,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)

-    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
     dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)

     input_elems.update({training_stage, dataset_dir, dataset})
@@ -160,10 +159,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         with gr.Row():
             dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
             dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
-            reward_model = gr.Dropdown(scale=2, allow_custom_value=True)
-            refresh_btn = gr.Button(scale=1)
+            reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)

-        refresh_btn.click(
+        training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
             list_adapters,
             [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
             [reward_model],
@@ -171,9 +169,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )

         input_elems.update({dpo_beta, dpo_ftx, reward_model})
-        elem_dict.update(
-            dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
-        )
+        elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))

     with gr.Accordion(label="GaLore config", open=False) as galore_tab:
         with gr.Row():
@@ -746,15 +746,15 @@ LOCALES = {
     "reward_model": {
         "en": {
             "label": "Reward model",
-            "info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)",
+            "info": "Adapter of the reward model for PPO training.",
         },
         "ru": {
             "label": "Модель вознаграждения",
-            "info": "Адаптер модели вознаграждения для обучения PPO. (Необходимо обновить адаптеры)",
+            "info": "Адаптер модели вознаграждения для обучения PPO.",
         },
         "zh": {
             "label": "奖励模型",
-            "info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)",
+            "info": "PPO 训练中奖励模型的适配器路径。",
         },
     },
     "galore_tab": {
@@ -1199,7 +1199,12 @@ ALERTS = {
     "err_no_adapter": {
         "en": "Please select an adapter.",
         "ru": "Пожалуйста, выберите адаптер.",
-        "zh": "请选择一个适配器。",
+        "zh": "请选择适配器。",
+    },
+    "err_no_reward_model": {
+        "en": "Please select a reward model.",
+        "ru": "Пожалуйста, выберите модель вознаграждения.",
+        "zh": "请选择奖励模型。",
     },
     "err_no_export_dir": {
         "en": "Please provide export dir.",
@@ -52,6 +52,8 @@ class Runner:
         get = lambda name: data[self.manager.get_elem_by_name(name)]
         lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
         dataset = get("train.dataset") if do_train else get("eval.dataset")
+        stage = TRAINING_STAGES[get("train.training_stage")]
+        reward_model = get("train.reward_model")

         if self.running:
             return ALERTS["err_conflict"][lang]
@@ -65,6 +67,9 @@ class Runner:
         if len(dataset) == 0:
             return ALERTS["err_no_dataset"][lang]

+        if stage == "ppo" and not reward_model:
+            return ALERTS["err_no_reward_model"][lang]
+
         if not from_preview and self.demo_mode:
             return ALERTS["err_demo"][lang]

@@ -163,8 +168,11 @@ class Runner:
             args["num_layer_trainable"] = int(get("train.num_layer_trainable"))

         if args["stage"] == "ppo":
-            args["reward_model"] = get_save_dir(
-                get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
+            args["reward_model"] = ",".join(
+                [
+                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                    for adapter in get("train.reward_model")
+                ]
             )
             args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
