diff --git a/README.md b/README.md
index 0dd3d56e..513c4160 100644
--- a/README.md
+++ b/README.md
@@ -55,7 +55,7 @@
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
-| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
+| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
 | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
 | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
diff --git a/README_zh.md b/README_zh.md
index 90079ead..ad3d9bd4 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -55,7 +55,7 @@
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
-| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
+| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
 | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
 | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index 0d22bb5a..b6d71fcf 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -49,7 +49,7 @@ class ChatModel:
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
-            eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)),
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             pad_token_id=self.tokenizer.pad_token_id,
             logits_processor=get_logits_processor()
         ))
diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index 7ef4da7f..6c86a166 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -63,7 +63,9 @@ def preprocess_dataset(
         for query, response, history, system in construct_example(examples):
             input_ids, labels = [], []
 
-            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
+            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
+                tokenizer, query, response, history, system
+            )):
                 if len(source_ids) > data_args.max_source_length:
                     source_ids = source_ids[:data_args.max_source_length]
                 if len(target_ids) > data_args.max_target_length:
@@ -72,8 +74,17 @@
                 if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
                     break
 
+                if turn_idx != 0 and template.efficient_eos: # used in baichuan, qwen and gpt2 models
+                    source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+                else:
+                    source_mask = [IGNORE_INDEX] * len(source_ids)
+
                 input_ids += source_ids + target_ids
-                labels += [IGNORE_INDEX] * len(source_ids) + target_ids
+                labels += source_mask + target_ids
+
+            if template.efficient_eos:
+                input_ids += [tokenizer.eos_token_id]
+                labels += [tokenizer.eos_token_id]
 
             model_inputs["input_ids"].append(input_ids)
             model_inputs["attention_mask"].append([1] * len(input_ids))
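Note: the labeling scheme in the preprocess.py hunk above, restated as a minimal runnable sketch. The `build_labels` helper and the token ids are hypothetical; `IGNORE_INDEX = -100` matches the constant used in the code. With `efficient_eos`, the EOS of each assistant turn is supervised at the first position of the following turn's source span, and a single trailing EOS closes the final response.

IGNORE_INDEX = -100  # label value excluded from the cross-entropy loss
EOS_ID = 2           # hypothetical eos token id

def build_labels(turns, efficient_eos):
    # turns: list of (source_ids, target_ids) pairs, already truncated
    input_ids, labels = [], []
    for turn_idx, (source_ids, target_ids) in enumerate(turns):
        if turn_idx != 0 and efficient_eos:
            # supervise EOS at the turn boundary instead of appending it to each target
            source_mask = [EOS_ID] + [IGNORE_INDEX] * (len(source_ids) - 1)
        else:
            source_mask = [IGNORE_INDEX] * len(source_ids)
        input_ids += source_ids + target_ids
        labels += source_mask + target_ids
    if efficient_eos:  # close the final response with one trailing EOS
        input_ids += [EOS_ID]
        labels += [EOS_ID]
    return input_ids, labels

# two turns: sources [5, 6] / [7, 8], targets [11] / [12]
print(build_labels([([5, 6], [11]), ([7, 8], [12])], efficient_eos=True))
# -> ([5, 6, 11, 7, 8, 12, 2], [-100, -100, 11, 2, -100, 12, 2])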
diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index e479fa76..f58399da 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -20,6 +20,7 @@ class Template:
     sep: List[Union[str, Dict[str, str]]]
     stop_words: List[str]
     use_history: bool
+    efficient_eos: bool
 
     def encode_oneturn(
         self,
@@ -74,19 +75,19 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer"
     ) -> Tuple[List[int], List[int]]:
-        if (
-            tokenizer.bos_token_id is not None
-            and getattr(tokenizer, "add_bos_token", True)
-        ): # baichuan-13b has no bos token
+        if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
             bos_ids = [tokenizer.bos_token_id]
-        else:
-            bos_ids = [] # bos token is optional
+        else: # baichuan, qwen and gpt2 models have no bos token
+            bos_ids = []
 
-        if tokenizer.eos_token_id is not None:
-            eos_ids = [tokenizer.eos_token_id]
-        else:
+        if tokenizer.eos_token_id is None:
             raise ValueError("EOS token is required.")
 
+        if self.efficient_eos: # used in baichuan, qwen and gpt2 models
+            eos_ids = []
+        else:
+            eos_ids = [tokenizer.eos_token_id]
+
         return bos_ids, eos_ids
 
     def _encode(
@@ -186,7 +187,8 @@ def register_template(
     system: str,
     sep: List[Union[str, Dict[str, str]]],
    stop_words: Optional[List[str]] = [],
-    use_history: Optional[bool] = True
+    use_history: Optional[bool] = True,
+    efficient_eos: Optional[bool] = False
 ) -> None:
     template_class = Llama2Template if "llama2" in name else Template
     templates[name] = template_class(
@@ -195,7 +197,8 @@ def register_template(
         system=system,
         sep=sep,
         stop_words=stop_words,
-        use_history=use_history
+        use_history=use_history,
+        efficient_eos=efficient_eos
     )
 
 
@@ -206,15 +209,6 @@ def get_template_and_fix_tokenizer(
     template = templates.get(name, None)
     assert template is not None, "Template {} does not exist.".format(name)
 
-    additional_special_tokens = template.stop_words
-    if len(template.stop_words): # inplace method
-        if tokenizer.eos_token_id is not None:
-            additional_special_tokens.append(tokenizer.eos_token)
-
-        tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
-        additional_special_tokens.pop(0)
-        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
-
     if tokenizer.eos_token_id is None:
         tokenizer.eos_token = "<|endoftext|>"
         logger.info("Add eos token: {}".format(tokenizer.eos_token))
@@ -227,7 +221,7 @@ def get_template_and_fix_tokenizer(
         logger.info("Add pad token: {}".format(tokenizer.pad_token))
 
     tokenizer.add_special_tokens(
-        dict(additional_special_tokens=additional_special_tokens),
+        dict(additional_special_tokens=template.stop_words),
         replace_additional_special_tokens=False
    )
     return template
@@ -466,18 +460,18 @@ register_template(
     ],
     system="",
     sep=[
+        {"token": "<eoa>"},
         "\n"
     ],
     stop_words=[
-        "</s>", # internlm cannot replace eos token
         "<eoa>"
-    ]
+    ],
+    efficient_eos=True
 )
 
 
 r"""
 Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
-Used for training and inference of the fine-tuned models.
 """
 register_template(
     name="baichuan",
@@ -487,39 +481,17 @@ register_template(
     prompt=[
         {"token": "<reserved_102>"}, # user token
         "{{query}}",
-        {"token": "<reserved_103>"} # assistant token
-    ],
-    system="",
-    sep=[]
-)
-
-
-r"""
-Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
-Used for inference of the original model.
-"""
-register_template(
-    name="baichuan_eval",
-    prefix=[
-        "{{system}}",
-        {"token": "<reserved_102>"} # user token
-    ],
-    prompt=[
-        "{{query}}",
-        {"token": "<reserved_103>"} # assistant token
+        {"token": "<reserved_103>"} # assistant token
     ],
     system="",
     sep=[],
-    stop_words=[
-        "<reserved_102>" # user token
-    ]
+    efficient_eos=True
 )
 
 
 r"""
 Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
           https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
-Used for training and inference of the fine-tuned models.
 """
 register_template(
     name="baichuan2",
@@ -529,33 +501,11 @@ register_template(
     prompt=[
         {"token": "<reserved_106>"}, # user token
         "{{query}}",
-        {"token": "<reserved_107>"} # assistant token
-    ],
-    system="",
-    sep=[]
-)
-
-
-r"""
-Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
-          https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
-Used for inference of the original model.
-"""
-register_template(
-    name="baichuan2_eval",
-    prefix=[
-        "{{system}}",
-        {"token": "<reserved_106>"} # user token
-    ],
-    prompt=[
-        "{{query}}",
-        {"token": "<reserved_107>"} # assistant token
+        {"token": "<reserved_107>"} # assistant token
     ],
     system="",
     sep=[],
-    stop_words=[
-        "<reserved_106>" # user token
-    ]
+    efficient_eos=True
 )
 
 
@@ -568,7 +518,6 @@ register_template(
     prefix=[
         {"token": "<|system|>"},
         "\n{{system}}",
-        {"token": "<|end|>"}
     ],
     prompt=[
         {"token": "<|user|>"},
@@ -579,11 +528,13 @@ register_template(
     ],
     system="",
     sep=[
+        {"token": "<|end|>"},
         "\n"
     ],
     stop_words=[
         "<|end|>"
-    ]
+    ],
+    efficient_eos=True
 )
 
 
@@ -594,8 +545,7 @@ register_template(
     name="chatml",
     prefix=[
         {"token": "<|im_start|>"},
-        "system\n{{system}}",
-        {"token": "<|im_end|>"}
+        "system\n{{system}}"
     ],
     prompt=[
         {"token": "<|im_start|>"},
@@ -607,11 +557,13 @@ register_template(
     ],
     system="You are a helpful assistant.",
     sep=[
+        {"token": "<|im_end|>"},
         "\n"
     ],
     stop_words=[
         "<|im_end|>"
-    ]
+    ],
+    efficient_eos=True
 )
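Note: the tokenizer side of the template.py change, as a small sketch. Stop words are now appended as additional special tokens rather than replacing the EOS token in place. gpt2 and the `<|im_end|>` stop word are stand-ins here, not part of the patch.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model
# mirror get_template_and_fix_tokenizer: keep existing special tokens
# and append the template's stop words on top of them
tokenizer.add_special_tokens(
    dict(additional_special_tokens=["<|im_end|>"]),
    replace_additional_special_tokens=False
)
# generation can stop on the native EOS or on any registered stop word
print([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids)
# e.g. [50256, 50257]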
-""" -register_template( - name="baichuan_eval", - prefix=[ - "{{system}}", - {"token": ""} # user token - ], - prompt=[ - "{{query}}", - {"token": ""} # assistant token + {"token": ""} # assistant token ], system="", sep=[], - stop_words=[ - "" # user token - ] + efficient_eos=True ) r""" Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat -Used for training and inference of the fine-tuned models. """ register_template( name="baichuan2", @@ -529,33 +501,11 @@ register_template( prompt=[ {"token": ""}, # user token "{{query}}", - {"token": ""} # assistant token - ], - system="", - sep=[] -) - - -r""" -Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat - https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat -Used for inference of the original model. -""" -register_template( - name="baichuan2_eval", - prefix=[ - "{{system}}", - {"token": ""} # user token - ], - prompt=[ - "{{query}}", - {"token": ""} # assistant token + {"token": ""} # assistant token ], system="", sep=[], - stop_words=[ - "" # user token - ] + efficient_eos=True ) @@ -568,7 +518,6 @@ register_template( prefix=[ {"token": "<|system|>"}, "\n{{system}}", - {"token": "<|end|>"} ], prompt=[ {"token": "<|user|>"}, @@ -579,11 +528,13 @@ register_template( ], system="", sep=[ + {"token": "<|end|>"}, "\n" ], stop_words=[ "<|end|>" - ] + ], + efficient_eos=True ) @@ -594,8 +545,7 @@ register_template( name="chatml", prefix=[ {"token": "<|im_start|>"}, - "system\n{{system}}", - {"token": "<|im_end|>"} + "system\n{{system}}" ], prompt=[ {"token": "<|im_start|>"}, @@ -607,11 +557,13 @@ register_template( ], system="You are a helpful assistant.", sep=[ + {"token": "<|im_end|>"}, "\n" ], stop_words=[ "<|im_end|>" - ] + ], + efficient_eos=True ) diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index f85dbd8a..f3de6fbf 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -15,9 +15,13 @@ from transformers import ( ) from transformers.utils import check_min_version from transformers.utils.versions import require_version -from transformers.deepspeed import is_deepspeed_zero3_enabled from trl import AutoModelForCausalLMWithValueHead +try: + from transformers.deepspeed import is_deepspeed_zero3_enabled +except ImportError: + from transformers.integrations import is_deepspeed_zero3_enabled + from llmtuner.extras.logging import reset_logging, get_logger from llmtuner.extras.misc import count_parameters, prepare_model_for_training from llmtuner.extras.save_and_load import load_valuehead_params @@ -91,7 +95,7 @@ def load_model_and_tokenizer( setattr(config, "use_logn_attn", True) logger.info("Using dynamic NTK scaling.") - elif hasattr(config, "rope_scaling"): # for LLaMA models + elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0") if is_trainable: diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 59aede98..8e7204c3 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -76,7 +76,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): # Keyword arguments for `model.generate` gen_kwargs = self.generating_args.to_dict() - gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)) + gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + 
diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py
index 48cd703a..a6b9a1f0 100644
--- a/src/llmtuner/tuner/ppo/workflow.py
+++ b/src/llmtuner/tuner/ppo/workflow.py
@@ -6,7 +6,6 @@ from torch.optim import AdamW
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler
-from transformers.utils.versions import require_version
 
 from llmtuner.dsets import get_dataset, preprocess_dataset
 from llmtuner.extras.ploting import plot_loss
diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py
index 511db1ba..a89a7514 100644
--- a/src/llmtuner/tuner/sft/workflow.py
+++ b/src/llmtuner/tuner/sft/workflow.py
@@ -54,7 +54,7 @@ def run_sft(
 
     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict()
-    gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids))
+    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
     gen_kwargs["logits_processor"] = get_logits_processor()
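Note: the simplified `gen_kwargs` in stream_chat.py, ppo/trainer.py and sft/workflow.py rely on `model.generate` accepting a list for `eos_token_id`, so decoding stops on the model's own EOS or on any registered stop word. An end-to-end sketch with gpt2 as a stand-in (gpt2 has no additional special tokens, so the list holds just its native EOS):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=8,
    # stop on the native EOS or on any registered stop word
    eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
    pad_token_id=tokenizer.eos_token_id,  # gpt2 defines no pad token
)
print(tokenizer.decode(outputs[0]))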