diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index b8eafd00..b67781d7 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -22,6 +22,9 @@ def preprocess_dataset(
     column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")
+
     def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
             query, response = examples["prompt"][i], examples["response"][i]
@@ -32,13 +35,12 @@ def preprocess_dataset(
 
     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build grouped texts with format `X1 X2 X3 ...`
-        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
-            kwargs = dict(allowed_special="all") # for tiktoken tokenizer (Qwen)
+        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
         else:
             kwargs = dict(add_special_tokens=True)
 
-        if hasattr(tokenizer, "add_bos_token") and hasattr(tokenizer, "add_eos_token"):
-            setattr(tokenizer, "add_bos_token", True) # for LLaMA tokenizer
+        if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
             setattr(tokenizer, "add_eos_token", True)
 
         tokenized_examples = tokenizer(examples["prompt"], **kwargs)
@@ -74,7 +76,9 @@ def preprocess_dataset(
                 if len(target_ids) > max_target_len:
                     target_ids = target_ids[:max_target_len]
 
-                if turn_idx != 0 and template.efficient_eos:
+                if data_args.train_on_prompt:
+                    source_mask = source_ids
+                elif turn_idx != 0 and template.efficient_eos:
                     source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
                 else:
                     source_mask = [IGNORE_INDEX] * len(source_ids)
@@ -97,15 +101,17 @@ def preprocess_dataset(
         return model_inputs
 
     def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build inputs with format `<bos> X Y <eos>` and labels with format `<bos> X Y <eos>`
-        # we do not mask the inputs in packed training.
+        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
+        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         input_ids, labels = [], []
         for query, response, history, system in construct_example(examples):
             for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                 tokenizer, query, response, history, system
             )):
-                if turn_idx != 0 and template.efficient_eos:
+                if data_args.train_on_prompt:
+                    source_mask = source_ids
+                elif turn_idx != 0 and template.efficient_eos:
                     source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
                 else:
                     source_mask = [IGNORE_INDEX] * len(source_ids)
@@ -229,5 +235,9 @@ def preprocess_dataset(
             **kwargs
         )
 
-        print_function(next(iter(dataset)))
+        try:
+            print_function(next(iter(dataset)))
+        except StopIteration:
+            raise ValueError("Empty dataset!")
+
         return dataset
diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index 44797ba1..9d432c56 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -31,7 +31,7 @@ class DataArguments:
         metadata={"help": "Which template to use for constructing prompts in training and inference."}
     )
     dataset: Optional[str] = field(
-        default="alpaca_en",
+        default=None,
         metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
     )
     dataset_dir: Optional[str] = field(
@@ -46,13 +46,17 @@ class DataArguments:
         default=1024,
         metadata={"help": "The maximum length of the model inputs after tokenization."}
    )
+    train_on_prompt: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to disable the mask on the prompt or not."}
+    )
     streaming: Optional[bool] = field(
         default=False,
-        metadata={"help": "Enable streaming mode."}
+        metadata={"help": "Enable dataset streaming."}
     )
     buffer_size: Optional[int] = field(
         default=16384,
-        metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
+        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
     )
     mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
         default="concat",
@@ -95,10 +99,20 @@ class DataArguments:
         metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
     )
 
+    def __post_init__(self):
+        if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
+            raise ValueError("Streaming mode should have an integer val size.")
+
+        if self.streaming and self.max_samples is not None:
+            raise ValueError("`max_samples` is incompatible with `streaming`.")
+
     def init_for_training(self): # support mixing multiple datasets
-        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
-        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
-            dataset_info = json.load(f)
+        dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
+        try:
+            with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
+                dataset_info = json.load(f)
+        except Exception:
+            dataset_info = None
 
         prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
         prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index a3d6d917..f3ebfd39 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -1,4 +1,3 @@
-import torch
 from typing import Literal, Optional
 from dataclasses import dataclass, field
 
@@ -19,6 +18,10 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
     )
+    split_special_tokens: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
+    )
     use_auth_token: Optional[bool] = field(
         default=False,
         metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
@@ -76,6 +79,9 @@ class ModelArguments:
         self.compute_dtype = None
         self.model_max_length = None
 
+        if self.split_special_tokens and self.use_fast_tokenizer:
+            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
+
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
 
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index 118138c2..56cf02eb 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -122,9 +122,6 @@ def get_train_args(
     if general_args.stage == "ppo" and model_args.reward_model is None:
         raise ValueError("Reward model is necessary for PPO training.")
 
-    if general_args.stage == "ppo" and training_args.deepspeed is not None:
-        raise ValueError("PPO training is incompatible with DeepSpeed, use Accelerate instead.")
-
     if general_args.stage == "ppo" and data_args.streaming:
         raise ValueError("Streaming mode does not suppport PPO training currently.")
 
@@ -134,9 +131,6 @@ def get_train_args(
     if training_args.max_steps == -1 and data_args.streaming:
         raise ValueError("Please specify `max_steps` in streaming mode.")
 
-    if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
-        raise ValueError("Streaming mode should have an integer val size.")
-
     if training_args.do_train and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True while training.")
 
@@ -166,11 +160,6 @@ def get_train_args(
     if (not training_args.do_train) and model_args.quantization_bit is not None:
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 
-    # postprocess data_args
-    if data_args.max_samples is not None and data_args.streaming:
-        logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
-        data_args.max_samples = None
-
     # postprocess training_args
     if (
         training_args.local_rank != -1