diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 169b31d3..22039920 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -63,18 +63,19 @@ def _encode_supervised_example( total_length += source_len + target_len if data_args.train_on_prompt: - source_mask = source_ids + source_label = source_ids elif turn_idx != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1) + source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1) else: - source_mask = [IGNORE_INDEX] * source_len + source_label = [IGNORE_INDEX] * source_len + + if data_args.mask_history and turn_idx != len(encoded_pairs) - 1: + target_label = [IGNORE_INDEX] * target_len + else: + target_label = target_ids input_ids += source_ids + target_ids - - if data_args.train_last_turn_only and turn_idx != len(encoded_pairs) - 1: - labels += source_mask + [IGNORE_INDEX] * len(target_ids) - else: - labels += source_mask + target_ids + labels += source_label + target_label if template.efficient_eos: input_ids += [tokenizer.eos_token_id] diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 10630019..cd762c75 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -41,17 +41,17 @@ class DataArguments: default="data", metadata={"help": "Path to the folder containing the datasets."}, ) - train_last_turn_only: Optional[bool] = field( - default=False, - metadata={"help": "Whether or not to train the last turn only."}, - ) cutoff_len: int = field( default=1024, metadata={"help": "The cutoff length of the tokenized inputs in the dataset."}, ) train_on_prompt: bool = field( default=False, - metadata={"help": "Whether to disable the mask on the prompt or not."}, + metadata={"help": "Whether or not to disable the mask on the prompt."}, + ) + 
mask_history: bool = field( + default=False, + metadata={"help": "Whether or not to mask the history and train on the last turn only."}, ) streaming: bool = field( default=False, diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 65e26e6a..b3c87b76 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -162,9 +162,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: # Check arguments if finetuning_args.stage != "pt" and data_args.template is None: raise ValueError("Please specify which `template` to use.") - - if finetuning_args.stage == "pt" and data_args.train_last_turn_only: - raise ValueError("PT stage does not support `train_last_turn_only`.") if finetuning_args.stage != "sft" and training_args.predict_with_generate: raise ValueError("`predict_with_generate` cannot be set as True except SFT.") diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 30a929c3..e5dc92c3 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -44,11 +44,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: ) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1) dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4) - train_last_turn_only = gr.Checkbox() preview_elems = create_preview_box(dataset_dir, dataset) - input_elems.update({training_stage, dataset_dir, dataset,train_last_turn_only}) - elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset,train_last_turn_only=train_last_turn_only, **preview_elems)) + input_elems.update({training_stage, dataset_dir, dataset}) + elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems)) with gr.Row(): learning_rate = gr.Textbox(value="5e-5") @@ -99,6 +98,10 @@ def create_train_tab(engine: "Engine") -> 
Dict[str, "Component"]: packing = gr.Checkbox() neat_packing = gr.Checkbox() + with gr.Column(): + train_on_prompt = gr.Checkbox() + mask_history = gr.Checkbox() + with gr.Column(): resize_vocab = gr.Checkbox() use_llama_pro = gr.Checkbox() @@ -116,6 +119,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: optim, packing, neat_packing, + train_on_prompt, + mask_history, resize_vocab, use_llama_pro, shift_attn, @@ -132,6 +137,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: optim=optim, packing=packing, neat_packing=neat_packing, + train_on_prompt=train_on_prompt, + mask_history=mask_history, resize_vocab=resize_vocab, use_llama_pro=use_llama_pro, shift_attn=shift_attn, diff --git a/src/llamafactory/webui/interface.py b/src/llamafactory/webui/interface.py index d25f4d38..1ca152c4 100644 --- a/src/llamafactory/webui/interface.py +++ b/src/llamafactory/webui/interface.py @@ -32,7 +32,7 @@ if is_gradio_available(): import gradio as gr -def create_ui(demo_mode: bool = False) -> gr.Blocks: +def create_ui(demo_mode: bool = False) -> "gr.Blocks": engine = Engine(demo_mode=demo_mode, pure_chat=False) with gr.Blocks(title="LLaMA Board", css=CSS) as demo: @@ -67,7 +67,7 @@ def create_ui(demo_mode: bool = False) -> gr.Blocks: return demo -def create_web_demo() -> gr.Blocks: +def create_web_demo() -> "gr.Blocks": engine = Engine(pure_chat=True) with gr.Blocks(title="Web Demo", css=CSS) as demo: diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 2211a37f..b1f2a802 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -522,6 +522,34 @@ LOCALES = { "info": "避免打包后的序列产生交叉注意力。", }, }, + "train_on_prompt": { + "en": { + "label": "Train on prompt", + "info": "Disable the label mask on the prompt (only for SFT).", + }, + "ru": { + "label": "Тренировка на подсказке", + "info": "Отключить маску меток на подсказке (только для SFT).", + }, + "zh": { + "label": "学习提示词", + 
"info": "不在提示词的部分添加掩码(仅适用于 SFT)。", + }, + }, + "mask_history": { + "en": { + "label": "Mask history", + "info": "Train on the last turn only (only for SFT).", + }, + "ru": { + "label": "Маскировать историю", + "info": "Тренироваться только на последнем шаге (только для SFT).", + }, + "zh": { + "label": "不学习历史对话", + "info": "仅学习最后一轮对话(仅适用于 SFT)。", + }, + }, "resize_vocab": { "en": { "label": "Resize token embeddings", "info": "Resize the tokenizer vocab and the embedding layers.", }, @@ -536,20 +564,6 @@ LOCALES = { "info": "更改分词器词表和嵌入层的大小。", }, }, - "train_last_turn_only": { - "en": { - "label": "Train last turn only", - "info": "Train the model with the last turn only in multi turn.", - }, - "ru": { - "label": "Обучать только последний поворот", - "info": "Обучать модель только последним поворотом в многоповоротном диалоге.", - }, - "zh": { - "label": "仅最后一轮参与训练", - "info": "多轮对话仅使用最后一轮计算loss。", - }, - }, "use_llama_pro": { "en": { "label": "Enable LLaMA Pro", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 6a766abe..9bc69679 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -125,7 +125,6 @@ class Runner: visual_inputs=get("top.visual_inputs"), dataset_dir=get("train.dataset_dir"), dataset=",".join(get("train.dataset")), - train_last_turn_only=get("train.train_last_turn_only"), cutoff_len=get("train.cutoff_len"), learning_rate=float(get("train.learning_rate")), num_train_epochs=float(get("train.num_train_epochs")), @@ -141,6 +140,8 @@ class Runner: optim=get("train.optim"), packing=get("train.packing") or get("train.neat_packing"), neat_packing=get("train.neat_packing"), + train_on_prompt=get("train.train_on_prompt"), + mask_history=get("train.mask_history"), resize_vocab=get("train.resize_vocab"), use_llama_pro=get("train.use_llama_pro"), shift_attn=get("train.shift_attn"),