follow #5115
commit c87023d539 (parent 51542cb15f)
examples/train_full/llama3_full_predict.yaml

@@ -7,7 +7,7 @@ do_predict: true
 finetuning_type: full
 
 ### dataset
-eval_dataset: alpaca_en_demo
+eval_dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 1024
 max_samples: 50
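`eval_dataset` accepts a comma-separated list of dataset names registered in `dataset_info.json`, so the example now evaluates on both `identity` and `alpaca_en_demo`. A minimal sketch of how such a field splits into names (`parse_dataset_names` is a hypothetical helper for illustration, not LLaMA-Factory's own code):

def parse_dataset_names(value: str) -> list[str]:
    # Split a comma-separated spec such as "identity,alpaca_en_demo"
    # into individual dataset names, dropping stray whitespace.
    return [name.strip() for name in value.split(",") if name.strip()]

assert parse_dataset_names("identity,alpaca_en_demo") == ["identity", "alpaca_en_demo"]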
src/llamafactory/data/loader.py

@@ -206,8 +206,6 @@ def get_dataset(
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
-    if stage != "sft" and data_args.mask_history:
-        raise ValueError("`Train on the last turn only` is only valid for sft training.")
 
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
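Note: this stage guard is not dropped outright; the parser hunk further down reintroduces an equivalent (and broader) check in `get_train_args`, raising when `train_on_prompt` or `mask_history` is set outside the SFT stage.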
src/llamafactory/data/processors/supervised.py

@@ -53,8 +53,11 @@ def _encode_supervised_example(
         input_ids += [image_token_id] * getattr(processor, "image_seq_length")
         labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
 
-    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools, mask_history)
+    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
     total_length = 1 if template.efficient_eos else 0
+    if mask_history:
+        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
+
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
         if total_length >= cutoff_len:
             break
@@ -66,17 +69,20 @@ def _encode_supervised_example(
 
         if train_on_prompt:
             source_label = source_ids
-        elif turn_idx != 0 and template.efficient_eos:
+        elif template.efficient_eos:
             source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
         else:
             source_label = [IGNORE_INDEX] * source_len
 
-        if mask_history:
-            target_label = target_ids if turn_idx == 0 else [IGNORE_INDEX] * target_len
+        if mask_history and turn_idx != 0:  # train on the last turn only
+            target_label = [IGNORE_INDEX] * target_len
+        else:
+            target_label = target_ids
+
+        if mask_history:  # reversed sequences
             input_ids = source_ids + target_ids + input_ids
             labels = source_label + target_label + labels
         else:
-            target_label = target_ids
             input_ids += source_ids + target_ids
             labels += source_label + target_label
 
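Taken together, the two hunks above move the turn reversal out of the template and into `_encode_supervised_example`: with `mask_history` enabled, pairs are iterated last-turn-first, only the final turn keeps its target labels, and each processed pair is prepended so the emitted sequence stays chronological. A minimal standalone sketch with toy token ids (hypothetical values, not the project's API):

IGNORE_INDEX = -100
# (prompt_ids, response_ids) per turn, oldest turn first
encoded_pairs = [([1, 2], [3]), ([4, 5], [6, 7])]

input_ids, labels = [], []
pairs = encoded_pairs[::-1]  # high priority for last turns
for turn_idx, (source_ids, target_ids) in enumerate(pairs):
    source_label = [IGNORE_INDEX] * len(source_ids)
    if turn_idx != 0:  # train on the last turn only
        target_label = [IGNORE_INDEX] * len(target_ids)
    else:
        target_label = target_ids
    # reversed sequences: prepend so the final order is chronological
    input_ids = source_ids + target_ids + input_ids
    labels = source_label + target_label + labels

assert input_ids == [1, 2, 3, 4, 5, 6, 7]
assert labels == [-100, -100, -100, -100, -100, 6, 7]

Only the second (last) turn's response tokens carry labels; if `cutoff_len` is hit while walking the reversed pairs, it is the oldest turns that get truncated first.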
src/llamafactory/data/template.py

@@ -69,16 +69,12 @@ class Template:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        mask_history: bool = False,
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
         encoded_messages = self._encode(tokenizer, messages, system, tools)
-        if not mask_history:
-            return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
-        else:
-            return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(len(encoded_messages) - 2, -1, -2)]
+        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
 
     def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
         r"""
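With the `mask_history` parameter removed, `encode_multiturn` always returns chronological (prompt, response) pairs: `_encode` yields one token list per message, alternating user/assistant, and the comprehension zips consecutive lists. A tiny illustration with made-up token ids:

# Toy stand-in for the output of Template._encode: one token list per
# message, alternating user/assistant for a two-turn conversation.
encoded_messages = [[101, 7], [8], [102, 9], [10, 11]]
pairs = [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
assert pairs == [([101, 7], [8]), ([102, 9], [10, 11])]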
src/llamafactory/data/template.py

@@ -594,10 +590,10 @@ _register_template(
     format_separator=EmptyFormatter(slots=["\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     default_system=(
-        "You are an AI programming assistant, utilizing the Deepseek Coder model, "
-        "developed by Deepseek Company, and you only answer questions related to computer science. "
+        "You are an AI programming assistant, utilizing the DeepSeek Coder model, "
+        "developed by DeepSeek Company, and you only answer questions related to computer science. "
         "For politically sensitive questions, security and privacy issues, "
-        "and other non-computer science questions, you will refuse to answer\n"
+        "and other non-computer science questions, you will refuse to answer.\n"
     ),
 )
 
src/llamafactory/hparams/data_args.py

@@ -143,4 +143,4 @@ class DataArguments:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
 
         if self.mask_history and self.train_on_prompt:
-            raise ValueError("`Train on the last turn only` does not support `train_on_prompt`.")
+            raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
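The reworded message names the actual argument: `mask_history` restricts supervision to the last assistant turn, while `train_on_prompt` extends supervision to prompt tokens, so the two options pull label masking in opposite directions and are rejected together.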
src/llamafactory/hparams/parser.py

@@ -163,11 +163,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
 
-    if finetuning_args.stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
+    if finetuning_args.stage != "sft":
+        if training_args.predict_with_generate:
+            raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
 
-    if finetuning_args.stage != "sft" and data_args.neat_packing:
-        raise ValueError("`neat_packing` cannot be set as True except SFT.")
+        if data_args.neat_packing:
+            raise ValueError("`neat_packing` cannot be set as True except SFT.")
+
+        if data_args.train_on_prompt or data_args.mask_history:
+            raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.")
 
     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
@@ -175,21 +179,18 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
         raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
 
-    if finetuning_args.stage == "ppo" and not training_args.do_train:
-        raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
+    if finetuning_args.stage == "ppo":
+        if not training_args.do_train:
+            raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
 
-    if finetuning_args.stage == "ppo" and model_args.shift_attn:
-        raise ValueError("PPO training is incompatible with S^2-Attn.")
+        if model_args.shift_attn:
+            raise ValueError("PPO training is incompatible with S^2-Attn.")
 
-    if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
-        raise ValueError("Unsloth does not support lora reward model.")
+        if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
+            raise ValueError("Unsloth does not support lora reward model.")
 
-    if (
-        finetuning_args.stage == "ppo"
-        and training_args.report_to
-        and training_args.report_to[0] not in ["wandb", "tensorboard"]
-    ):
-        raise ValueError("PPO only accepts wandb or tensorboard logger.")
+        if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
+            raise ValueError("PPO only accepts wandb or tensorboard logger.")
 
     if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
         raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")