From 297fb8ead3daf154152d9826b49bb4d769fbaaa9 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 24 Apr 2024 23:39:31 +0800
Subject: [PATCH] support new special token #3420

---
 src/llmtuner/hparams/data_args.py       |  4 ++--
 src/llmtuner/hparams/generating_args.py |  4 ++--
 src/llmtuner/hparams/model_args.py      |  7 +++++++
 src/llmtuner/hparams/parser.py          |  6 +++++-
 src/llmtuner/model/adapter.py           | 11 +++++++++++
 src/llmtuner/model/loader.py            | 12 ++++++++++++
 src/llmtuner/model/utils/embedding.py   |  6 ++++--
 src/llmtuner/model/utils/rope.py        |  4 ++++
 8 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index f5f75c77..1e0cd08c 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -26,11 +26,11 @@ class DataArguments:
     )
     cutoff_len: int = field(
         default=1024,
-        metadata={"help": "The cutoff length of the model inputs after tokenization."},
+        metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
     )
     reserved_label_len: int = field(
         default=1,
-        metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
+        metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
     )
     train_on_prompt: bool = field(
         default=False,
diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py
index 70dabb3e..e792c003 100644
--- a/src/llmtuner/hparams/generating_args.py
+++ b/src/llmtuner/hparams/generating_args.py
@@ -31,11 +31,11 @@ class GeneratingArguments:
         metadata={"help": "Number of beams for beam search. 1 means no beam search."},
     )
     max_length: int = field(
-        default=512,
+        default=1024,
         metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
     )
     max_new_tokens: int = field(
-        default=512,
+        default=1024,
         metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
     )
     repetition_penalty: float = field(
diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index eb6366d9..b60492a0 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -33,6 +33,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
+    new_special_tokens: Optional[str] = field(
+        default=None,
+        metadata={"help": "Special tokens to be added into the tokenizer."},
+    )
     model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
@@ -177,6 +181,9 @@ class ModelArguments:
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
 
+        if self.new_special_tokens is not None:  # support multiple special tokens
+            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
+
         assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
 
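The new `new_special_tokens` field follows the same convention as `adapter_name_or_path`: a single comma-separated string passed on the command line is split into a Python list in `__post_init__`. A minimal sketch of that behavior, assuming `ModelArguments` can be constructed directly like this and using made-up token names:

    from llmtuner.hparams.model_args import ModelArguments

    # The token strings below are placeholders for illustration, not tokens defined by this patch.
    args = ModelArguments(
        model_name_or_path="meta-llama/Llama-2-7b-hf",
        new_special_tokens="<|tool_call|>, <|observation|>",
    )

    # __post_init__ splits on commas and strips surrounding whitespace.
    assert args.new_special_tokens == ["<|tool_call|>", "<|observation|>"]
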
diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index 0d286819..a7d0a17f 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -67,6 +67,9 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
         if finetuning_args.finetuning_type != "lora":
             raise ValueError("Quantization is only compatible with the LoRA method.")
 
+        if model_args.resize_vocab:
+            raise ValueError("Cannot resize embedding layers of a quantized model.")
+
         if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
             raise ValueError("Cannot create new adapter upon a quantized model.")
 
@@ -199,10 +202,11 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if (
         training_args.do_train
         and finetuning_args.finetuning_type == "lora"
+        and model_args.quantization_bit is None
         and model_args.resize_vocab
         and finetuning_args.additional_target is None
     ):
-        logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
+        logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
 
     if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
         logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index af58b514..d43e00f0 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -157,6 +157,17 @@ def init_adapter(
         ):
             raise ValueError("DoRA is not compatible with PTQ-quantized models.")
 
+        if model_args.resize_vocab and finetuning_args.additional_target is None:
+            input_embeddings = model.get_input_embeddings()
+            output_embeddings = model.get_output_embeddings()
+            module_names = set()
+            for name, module in model.named_modules():
+                if module in [input_embeddings, output_embeddings]:
+                    module_names.add(name.split(".")[-1])
+
+            finetuning_args.additional_target = module_names
+            logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
+
         peft_kwargs = {
             "r": finetuning_args.lora_rank,
             "target_modules": target_modules,
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 06405219..54048cc5 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -39,6 +39,8 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
 def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
     r"""
     Loads pretrained tokenizer.
+
+    Note: including inplace operation of model_args.
""" init_kwargs = _get_init_kwargs(model_args) try: @@ -57,6 +59,16 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer": **init_kwargs, ) + if model_args.new_special_tokens is not None: + num_added_tokens = tokenizer.add_special_tokens( + dict(additional_special_tokens=model_args.new_special_tokens), + replace_additional_special_tokens=False, + ) + logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens))) + if num_added_tokens > 0 and not model_args.resize_vocab: + model_args.resize_vocab = True + logger.warning("New tokens have been added, changed `resize_vocab` to True.") + patch_tokenizer(tokenizer) return tokenizer diff --git a/src/llmtuner/model/utils/embedding.py b/src/llmtuner/model/utils/embedding.py index 7759fc0f..357c9cc0 100644 --- a/src/llmtuner/model/utils/embedding.py +++ b/src/llmtuner/model/utils/embedding.py @@ -42,9 +42,11 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken current_embedding_size = model.get_input_embeddings().weight.size(0) if len(tokenizer) > current_embedding_size: + if getattr(model, "quantization_method", None): + raise ValueError("Cannot resize embedding layers of a quantized model.") + if not isinstance(model.get_output_embeddings(), torch.nn.Linear): - logger.warning("Current model does not support resizing token embeddings.") - return + raise ValueError("Current model does not support resizing embedding layers.") model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) with context_maybe_zero3: diff --git a/src/llmtuner/model/utils/rope.py b/src/llmtuner/model/utils/rope.py index 2a4cce7a..9163253b 100644 --- a/src/llmtuner/model/utils/rope.py +++ b/src/llmtuner/model/utils/rope.py @@ -30,6 +30,10 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_ current_max_length = getattr(config, "max_position_embeddings", None) if current_max_length and model_args.model_max_length > current_max_length: + logger.warning( + "Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length) + ) + setattr(config, "max_position_embeddings", model_args.model_max_length) scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) else: logger.warning("Input length is smaller than max length. Consider increase input length.")