Merge pull request #4878 from ly863/main

Train on the last turn of the conversation only.
hoshi-hiyouga 2024-07-18 22:03:41 +08:00 committed by GitHub
commit 2516763d69
6 changed files with 30 additions and 3 deletions


@@ -70,7 +70,11 @@ def _encode_supervised_example(
             source_mask = [IGNORE_INDEX] * source_len

         input_ids += source_ids + target_ids
-        labels += source_mask + target_ids
+
+        if data_args.train_last_turn_only and turn_idx != len(encoded_pairs) - 1:
+            labels += source_mask + [IGNORE_INDEX] * len(target_ids)
+        else:
+            labels += source_mask + target_ids

     if template.efficient_eos:
         input_ids += [tokenizer.eos_token_id]
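For reference, a minimal standalone sketch of what this masking does to a two-turn example (toy token ids and a local IGNORE_INDEX; this is not the repository's actual helper):

# Toy illustration of the label masking above (not the repo's real function).
IGNORE_INDEX = -100

# (source_ids, target_ids) per turn; values are made-up token ids.
encoded_pairs = [([1, 2], [3, 4]), ([5], [6, 7, 8])]
train_last_turn_only = True

input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
    source_mask = [IGNORE_INDEX] * len(source_ids)
    input_ids += source_ids + target_ids
    if train_last_turn_only and turn_idx != len(encoded_pairs) - 1:
        labels += source_mask + [IGNORE_INDEX] * len(target_ids)  # earlier turns: no loss
    else:
        labels += source_mask + target_ids  # last turn: loss on the response tokens

print(input_ids)  # [1, 2, 3, 4, 5, 6, 7, 8]
print(labels)     # [-100, -100, -100, -100, -100, 6, 7, 8]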


@@ -41,6 +41,10 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
+    train_last_turn_only: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to train on the last turn only."},
+    )
     cutoff_len: int = field(
         default=1024,
         metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
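A self-contained sketch of how a dataclass flag like this gets picked up from the command line, assuming the usual HfArgumentParser pattern; the trimmed ToyDataArguments class is illustrative, not the real DataArguments:

# Illustrative only: a stand-in dataclass showing how the new boolean flag
# would be parsed from CLI-style arguments via HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ToyDataArguments:
    train_last_turn_only: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to train on the last turn only."},
    )
    cutoff_len: int = field(
        default=1024,
        metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
    )


parser = HfArgumentParser(ToyDataArguments)
(args,) = parser.parse_args_into_dataclasses(["--train_last_turn_only", "true"])
print(args.train_last_turn_only, args.cutoff_len)  # True 1024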


@@ -162,6 +162,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     # Check arguments
     if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

+    if finetuning_args.stage == "pt" and data_args.train_last_turn_only:
+        raise ValueError("PT stage does not support `train_last_turn_only`.")
+
     if finetuning_args.stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
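A toy illustration of the guard above (a simplified stand-in, not the real get_train_args): pre-training has no conversation turns, so the combination is rejected.

# Simplified stand-in for the argument check above.
def check_args(stage: str, train_last_turn_only: bool) -> None:
    if stage == "pt" and train_last_turn_only:
        raise ValueError("PT stage does not support `train_last_turn_only`.")


check_args("sft", True)     # fine: supervised fine-tuning with last-turn-only loss
try:
    check_args("pt", True)  # rejected: pre-training has no conversation turns
except ValueError as err:
    print(err)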


@@ -44,10 +44,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             )
             dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
             dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
+            train_last_turn_only = gr.Checkbox()
             preview_elems = create_preview_box(dataset_dir, dataset)

-        input_elems.update({training_stage, dataset_dir, dataset})
-        elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
+        input_elems.update({training_stage, dataset_dir, dataset, train_last_turn_only})
+        elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, train_last_turn_only=train_last_turn_only, **preview_elems))

         with gr.Row():
             learning_rate = gr.Textbox(value="5e-5")
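A rough Gradio sketch of the pattern above, with assumed labels and layout rather than the real WebUI code: a checkbox is created inside the dataset row and registered in a component dict so its value can be collected later.

# Minimal sketch (assumed layout, not the real create_train_tab).
import gradio as gr

elem_dict = {}
with gr.Blocks() as demo:
    with gr.Row():
        dataset_dir = gr.Textbox(value="data", scale=1)
        dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
        train_last_turn_only = gr.Checkbox(label="Train last turn only")
    elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset,
                          train_last_turn_only=train_last_turn_only))
# demo.launch()  # uncomment to render the row locally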


@@ -536,6 +536,20 @@ LOCALES = {
             "info": "更改分词器词表和嵌入层的大小。",
         },
     },
+    "train_last_turn_only": {
+        "en": {
+            "label": "Train last turn only",
+            "info": "Train the model on the last turn only in multi-turn conversations.",
+        },
+        "ru": {
+            "label": "Обучать только последнюю реплику",
+            "info": "Обучать модель только на последней реплике в многоходовом диалоге.",
+        },
+        "zh": {
+            "label": "仅最后一轮参与训练",
+            "info": "多轮对话仅使用最后一轮计算loss。",
+        },
+    },
     "use_llama_pro": {
         "en": {
             "label": "Enable LLaMA Pro",


@@ -125,6 +125,7 @@ class Runner:
             visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),
+            train_last_turn_only=get("train.train_last_turn_only"),
             cutoff_len=get("train.cutoff_len"),
             learning_rate=float(get("train.learning_rate")),
             num_train_epochs=float(get("train.num_train_epochs")),
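A rough sketch of the collection pattern above (assumed keys and sample values; not the real Runner class): the WebUI stores component values keyed by "tab.elem_id", and the runner flattens them into keyword arguments for the training entry point.

# Illustrative only: how the checkbox value would flow from the UI state
# into the training arguments alongside the existing fields.
data = {
    "train.dataset_dir": "data",
    "train.dataset": ["alpaca_en", "identity"],       # sample dataset names
    "train.train_last_turn_only": True,
    "train.cutoff_len": 1024,
    "train.learning_rate": "5e-5",
}

def get(key: str):
    return data[key]

args = dict(
    dataset_dir=get("train.dataset_dir"),
    dataset=",".join(get("train.dataset")),
    train_last_turn_only=get("train.train_last_turn_only"),
    cutoff_len=get("train.cutoff_len"),
    learning_rate=float(get("train.learning_rate")),
)
print(args["train_last_turn_only"], args["learning_rate"])  # True 5e-05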