From dc09d454f285b8584d9017349a9cee3b44eadb72 Mon Sep 17 00:00:00 2001 From: codingma Date: Thu, 1 Aug 2024 13:45:48 +0800 Subject: [PATCH 1/2] support gemma-2-2b --- src/llamafactory/extras/constants.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 45145886..c413c51d 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -531,6 +531,14 @@ register_model_group( "Gemma-1.1-7B-Chat": { DownloadSource.DEFAULT: "google/gemma-1.1-7b-it", }, + "Gemma-2-2B": { + DownloadSource.DEFAULT: "google/gemma-2-2b", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b", + }, + "Gemma-2-2B-Chat": { + DownloadSource.DEFAULT: "google/gemma-2-2b-it", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it", + }, "Gemma-2-9B": { DownloadSource.DEFAULT: "google/gemma-2-9b", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b", From b7ca6c8dc14f689d0df16684a6121cc0ec24f8ba Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Mon, 5 Aug 2024 23:48:19 +0800 Subject: [PATCH 2/2] fix #5048 --- README.md | 16 ++-- README_zh.md | 18 ++--- requirements.txt | 10 +-- src/llamafactory/__init__.py | 16 ++-- src/llamafactory/extras/constants.py | 8 +- src/llamafactory/extras/misc.py | 10 +-- src/llamafactory/extras/packages.py | 5 ++ .../model/model_utils/attention.py | 2 +- .../model/model_utils/longlora.py | 75 +++++++++++++------ src/llamafactory/model/model_utils/packing.py | 12 ++- src/llamafactory/train/callbacks.py | 5 +- src/llamafactory/webui/interface.py | 2 +- src/llamafactory/webui/locales.py | 1 - 13 files changed, 111 insertions(+), 69 deletions(-) diff --git a/README.md b/README.md index 87b0af7c..386177bb 100644 --- a/README.md +++ b/README.md @@ -300,20 +300,20 @@ huggingface-cli login | Mandatory | Minimum | Recommend | | ------------ | ------- | --------- | | python | 3.8 | 3.11 | -| torch | 1.13.1 | 2.3.0 | -| transformers | 4.41.2 | 4.41.2 | -| datasets | 2.16.0 | 2.19.2 | -| accelerate | 0.30.1 | 0.30.1 | -| peft | 0.11.1 | 0.11.1 | -| trl | 0.8.6 | 0.9.4 | +| torch | 1.13.1 | 2.4.0 | +| transformers | 4.41.2 | 4.43.4 | +| datasets | 2.16.0 | 2.20.0 | +| accelerate | 0.30.1 | 0.32.0 | +| peft | 0.11.1 | 0.12.0 | +| trl | 0.8.6 | 0.9.6 | | Optional | Minimum | Recommend | | ------------ | ------- | --------- | | CUDA | 11.6 | 12.2 | | deepspeed | 0.10.0 | 0.14.0 | | bitsandbytes | 0.39.0 | 0.43.1 | -| vllm | 0.4.3 | 0.4.3 | -| flash-attn | 2.3.0 | 2.5.9 | +| vllm | 0.4.3 | 0.5.0 | +| flash-attn | 2.3.0 | 2.6.3 | ### Hardware Requirement diff --git a/README_zh.md b/README_zh.md index 3a7724d1..812b7b28 100644 --- a/README_zh.md +++ b/README_zh.md @@ -166,8 +166,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 | | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | -| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B | cpm | +| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [PaliGemma](https://huggingface.co/google) | 3B | gemma | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | @@ -300,20 +300,20 @@ huggingface-cli login | 必需项 | 至少 | 推荐 | | ------------ | ------- | --------- | | python | 3.8 | 3.11 | -| torch | 1.13.1 | 2.3.0 | -| transformers | 4.41.2 | 4.41.2 | -| datasets | 2.16.0 | 2.19.2 | -| accelerate | 0.30.1 | 0.30.1 | -| peft | 0.11.1 | 0.11.1 | -| trl | 0.8.6 | 0.9.4 | +| torch | 1.13.1 | 2.4.0 | +| transformers | 4.41.2 | 4.43.4 | +| datasets | 2.16.0 | 2.20.0 | +| accelerate | 0.30.1 | 0.32.0 | +| peft | 0.11.1 | 0.12.0 | +| trl | 0.8.6 | 0.9.6 | | 可选项 | 至少 | 推荐 | | ------------ | ------- | --------- | | CUDA | 11.6 | 12.2 | | deepspeed | 0.10.0 | 0.14.0 | | bitsandbytes | 0.39.0 | 0.43.1 | -| vllm | 0.4.3 | 0.4.3 | -| flash-attn | 2.3.0 | 2.5.9 | +| vllm | 0.4.3 | 0.5.0 | +| flash-attn | 2.3.0 | 2.6.3 | ### 硬件依赖 diff --git a/requirements.txt b/requirements.txt index 7380add4..93e83530 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -transformers>=4.41.2 -datasets>=2.16.0 -accelerate>=0.30.1 -peft>=0.11.1 -trl>=0.8.6 +transformers>=4.41.2,<=4.43.4 +datasets>=2.16.0,<=2.20.0 +accelerate>=0.30.1,<=0.32.0 +peft>=0.11.1,<=0.12.0 +trl>=0.8.6,<=0.9.6 gradio>=4.0.0 pandas>=2.0.0 scipy diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py index 28f5144a..7b602a92 100644 --- a/src/llamafactory/__init__.py +++ b/src/llamafactory/__init__.py @@ -20,19 +20,17 @@ Level: Dependency graph: main: - transformers>=4.41.2 - datasets>=2.16.0 - accelerate>=0.30.1 - peft>=0.11.1 - trl>=0.8.6 + transformers>=4.41.2,<=4.43.4 + datasets>=2.16.0,<=2.20.0 + accelerate>=0.30.1,<=0.32.0 + peft>=0.11.1,<=0.12.0 + trl>=0.8.6,<=0.9.6 attention: transformers>=4.42.4 (gemma+fa2) longlora: - transformers>=4.41.2,<=4.42.4 + transformers>=4.41.2,<=4.43.4 packing: - transformers>=4.41.2,<=4.42.4 - patcher: - transformers==4.41.2 (chatglm) + transformers>=4.41.2,<=4.43.4 """ from .cli import VERSION diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index c413c51d..4531db4a 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -535,10 +535,6 @@ register_model_group( DownloadSource.DEFAULT: "google/gemma-2-2b", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b", }, - "Gemma-2-2B-Chat": { - DownloadSource.DEFAULT: "google/gemma-2-2b-it", - DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it", - }, "Gemma-2-9B": { DownloadSource.DEFAULT: "google/gemma-2-9b", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b", @@ -547,6 +543,10 @@ register_model_group( DownloadSource.DEFAULT: "google/gemma-2-27b", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b", }, + "Gemma-2-2B-Chat": { + DownloadSource.DEFAULT: "google/gemma-2-2b-it", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it", + }, "Gemma-2-9B-Chat": { DownloadSource.DEFAULT: "google/gemma-2-9b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it", diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index d7329b06..c1395552 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -79,11 +79,11 @@ def check_dependencies() -> None: if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") else: - require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2") - require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0") - require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1") - require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1") - require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6") + require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4") + require_version("datasets>=2.16.0,<=2.20.0", "To fix: pip install datasets>=2.16.0,<=2.20.0") + require_version("accelerate>=0.30.1,<=0.32.0", "To fix: pip install accelerate>=0.30.1,<=0.32.0") + require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0") + require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6") def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]: diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index a9072103..ae270d1b 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -70,6 +70,11 @@ def is_starlette_available(): return _is_package_available("sse_starlette") +@lru_cache +def is_transformers_version_greater_than_4_43(): + return _get_package_version("transformers") >= version.parse("4.43.0") + + def is_uvicorn_available(): return _is_package_available("uvicorn") diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py index da53baa2..96e2c8a9 100644 --- a/src/llamafactory/model/model_utils/attention.py +++ b/src/llamafactory/model/model_utils/attention.py @@ -36,7 +36,7 @@ def configure_attn_implementation( if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2": if is_flash_attn_2_available(): require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4") - require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0") + require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3") logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.") model_args.flash_attn = "fa2" else: diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py index 53570a16..e518aefb 100644 --- a/src/llamafactory/model/model_utils/longlora.py +++ b/src/llamafactory/model/model_utils/longlora.py @@ -35,6 +35,7 @@ from transformers.utils.versions import require_version from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN from ...extras.logging import get_logger +from ...extras.packages import is_transformers_version_greater_than_4_43 if TYPE_CHECKING: @@ -50,14 +51,15 @@ transformers_logger = logging.get_logger(__name__) # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py def llama_attention_forward( self: "LlamaAttention", - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + hidden_states: "torch.Tensor", + attention_mask: Optional["torch.Tensor"] = None, + position_ids: Optional["torch.LongTensor"] = None, past_key_value: Optional["Cache"] = None, output_attentions: bool = False, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional["torch.LongTensor"] = None, + position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None, **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]: bsz, q_len, _ = hidden_states.size() query_states: "torch.Tensor" = self.q_proj(hidden_states) @@ -68,7 +70,11 @@ def llama_attention_forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -130,14 +136,15 @@ def llama_attention_forward( # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py def llama_flash_attention_2_forward( self: "LlamaFlashAttention2", - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + hidden_states: "torch.Tensor", + attention_mask: Optional["torch.Tensor"] = None, + position_ids: Optional["torch.LongTensor"] = None, past_key_value: Optional["Cache"] = None, output_attentions: bool = False, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional["torch.LongTensor"] = None, + position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None, **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]: # LlamaFlashAttention2 attention does not support output_attentions output_attentions = False @@ -151,7 +158,11 @@ def llama_flash_attention_2_forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -198,9 +209,24 @@ def llama_flash_attention_2_forward( if attention_mask is not None: attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1) - attn_output: "torch.Tensor" = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate - ) + if is_transformers_version_greater_than_4_43(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + + attn_output: "torch.Tensor" = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_states.size(1), + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + else: + attn_output: "torch.Tensor" = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate + ) if getattr(self.config, "group_size_ratio", None) and self.training: # shift back attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) @@ -225,14 +251,15 @@ def llama_flash_attention_2_forward( # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py def llama_sdpa_attention_forward( self: "LlamaSdpaAttention", - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + hidden_states: "torch.Tensor", + attention_mask: Optional["torch.Tensor"] = None, + position_ids: Optional["torch.LongTensor"] = None, past_key_value: Optional["Cache"] = None, output_attentions: bool = False, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional["torch.LongTensor"] = None, + position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None, **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]: if output_attentions: transformers_logger.warning_once( "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention" @@ -258,7 +285,11 @@ def llama_sdpa_attention_forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -322,7 +353,7 @@ def llama_sdpa_attention_forward( def _apply_llama_patch() -> None: - require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4") + require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4") LlamaAttention.forward = llama_attention_forward LlamaFlashAttention2.forward = llama_flash_attention_2_forward LlamaSdpaAttention.forward = llama_sdpa_attention_forward diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index 674e0b4a..ded7f295 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -41,11 +41,11 @@ from typing import TYPE_CHECKING, Tuple import torch import torch.nn.functional as F -import transformers.models from transformers.utils.versions import require_version from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN from ...extras.logging import get_logger +from ...extras.packages import is_transformers_version_greater_than_4_43 if TYPE_CHECKING: @@ -114,7 +114,15 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor def _patch_for_block_diag_attn(model_type: str) -> None: - require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4") + require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4") + if is_transformers_version_greater_than_4_43(): + import transformers.modeling_flash_attention_utils + + transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data + return + + import transformers.models + if model_type == "cohere": transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data elif model_type == "falcon": diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 657dd8f3..3b05317d 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -162,11 +162,12 @@ class PissaConvertCallback(TrainerCallback): setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) model.save_pretrained( pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir - ) + ) # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0) model.load_adapter(pissa_backup_dir, "default", is_trainable=True) model.set_adapter("default") - if "pissa_init" in model.peft_config.keys(): + if "pissa_init" in model.peft_config.keys(): # backward compatibility (peft<0.12.0) model.delete_adapter("pissa_init") + setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) diff --git a/src/llamafactory/webui/interface.py b/src/llamafactory/webui/interface.py index 1ca152c4..0ea37787 100644 --- a/src/llamafactory/webui/interface.py +++ b/src/llamafactory/webui/interface.py @@ -71,7 +71,7 @@ def create_web_demo() -> "gr.Blocks": engine = Engine(pure_chat=True) with gr.Blocks(title="Web Demo", css=CSS) as demo: - lang = gr.Dropdown(choices=["en", "zh"]) + lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1) engine.manager.add_elems("top", dict(lang=lang)) _, _, chat_elems = create_chat_box(engine, visible=True) diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 01d6fe29..0a8ca68a 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -362,7 +362,6 @@ LOCALES = { "label": "학습률", "info": "AdamW의 초기 학습률.", }, - }, "num_train_epochs": { "en": {