fix #5048

parent c2921b9960
commit b7ca6c8dc1

README.md (16 changed lines)
@@ -300,20 +300,20 @@ huggingface-cli login

 | Mandatory    | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.11      |
-| torch        | 1.13.1  | 2.3.0     |
+| torch        | 1.13.1  | 2.4.0     |
-| transformers | 4.41.2  | 4.41.2    |
+| transformers | 4.41.2  | 4.43.4    |
-| datasets     | 2.16.0  | 2.19.2    |
+| datasets     | 2.16.0  | 2.20.0    |
-| accelerate   | 0.30.1  | 0.30.1    |
+| accelerate   | 0.30.1  | 0.32.0    |
-| peft         | 0.11.1  | 0.11.1    |
+| peft         | 0.11.1  | 0.12.0    |
-| trl          | 0.8.6   | 0.9.4     |
+| trl          | 0.8.6   | 0.9.6     |

 | Optional     | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | CUDA         | 11.6    | 12.2      |
 | deepspeed    | 0.10.0  | 0.14.0    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
-| vllm         | 0.4.3   | 0.4.3     |
+| vllm         | 0.4.3   | 0.5.0     |
-| flash-attn   | 2.3.0   | 2.5.9     |
+| flash-attn   | 2.3.0   | 2.6.3     |

 ### Hardware Requirement
README_zh.md (18 changed lines)

@@ -166,8 +166,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

 | [Llama 2](https://huggingface.co/meta-llama)           | 7B/13B/70B    | llama2  |
 | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B        | llama3  |
 | [LLaVA-1.5](https://huggingface.co/llava-hf)           | 7B/13B        | vicuna  |
-| [Mistral/Mixtral](https://huggingface.co/mistralai)    | 7B/8x7B/8x22B | mistral |
 | [MiniCPM](https://huggingface.co/openbmb)              | 1B/2B         | cpm     |
+| [Mistral/Mixtral](https://huggingface.co/mistralai)    | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai)                 | 1B/7B         | -       |
 | [PaliGemma](https://huggingface.co/google)             | 3B            | gemma   |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft)      | 1.3B/2.7B     | -       |

@@ -300,20 +300,20 @@ huggingface-cli login

 | 必需项       | 至少    | 推荐      |
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.11      |
-| torch        | 1.13.1  | 2.3.0     |
+| torch        | 1.13.1  | 2.4.0     |
-| transformers | 4.41.2  | 4.41.2    |
+| transformers | 4.41.2  | 4.43.4    |
-| datasets     | 2.16.0  | 2.19.2    |
+| datasets     | 2.16.0  | 2.20.0    |
-| accelerate   | 0.30.1  | 0.30.1    |
+| accelerate   | 0.30.1  | 0.32.0    |
-| peft         | 0.11.1  | 0.11.1    |
+| peft         | 0.11.1  | 0.12.0    |
-| trl          | 0.8.6   | 0.9.4     |
+| trl          | 0.8.6   | 0.9.6     |

 | 可选项       | 至少    | 推荐      |
 | ------------ | ------- | --------- |
 | CUDA         | 11.6    | 12.2      |
 | deepspeed    | 0.10.0  | 0.14.0    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
-| vllm         | 0.4.3   | 0.4.3     |
+| vllm         | 0.4.3   | 0.5.0     |
-| flash-attn   | 2.3.0   | 2.5.9     |
+| flash-attn   | 2.3.0   | 2.6.3     |

 ### 硬件依赖
@@ -1,8 +1,8 @@
-transformers>=4.41.2
+transformers>=4.41.2,<=4.43.4
-datasets>=2.16.0
+datasets>=2.16.0,<=2.20.0
-accelerate>=0.30.1
+accelerate>=0.30.1,<=0.32.0
-peft>=0.11.1
+peft>=0.11.1,<=0.12.0
-trl>=0.8.6
+trl>=0.8.6,<=0.9.6
 gradio>=4.0.0
 pandas>=2.0.0
 scipy
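
For reference, a quick way to check an existing environment against the new pins above is to compare the installed versions to the bounds. This is an illustrative, standalone sketch (bounds copied from the diff; the script itself is not part of the commit):

    from importlib.metadata import PackageNotFoundError, version
    from packaging.version import parse

    # Lower/upper bounds introduced by this commit.
    BOUNDS = {
        "transformers": ("4.41.2", "4.43.4"),
        "datasets": ("2.16.0", "2.20.0"),
        "accelerate": ("0.30.1", "0.32.0"),
        "peft": ("0.11.1", "0.12.0"),
        "trl": ("0.8.6", "0.9.6"),
    }

    for name, (low, high) in BOUNDS.items():
        try:
            installed = parse(version(name))
        except PackageNotFoundError:
            print(f"{name}: not installed")
            continue
        ok = parse(low) <= installed <= parse(high)
        print(f"{name} {installed}: {'ok' if ok else f'outside [{low}, {high}]'}")
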
@@ -20,19 +20,17 @@ Level:

 Dependency graph:
 main:
-  transformers>=4.41.2
+  transformers>=4.41.2,<=4.43.4
-  datasets>=2.16.0
+  datasets>=2.16.0,<=2.20.0
-  accelerate>=0.30.1
+  accelerate>=0.30.1,<=0.32.0
-  peft>=0.11.1
+  peft>=0.11.1,<=0.12.0
-  trl>=0.8.6
+  trl>=0.8.6,<=0.9.6
 attention:
   transformers>=4.42.4 (gemma+fa2)
 longlora:
-  transformers>=4.41.2,<=4.42.4
+  transformers>=4.41.2,<=4.43.4
 packing:
-  transformers>=4.41.2,<=4.42.4
+  transformers>=4.41.2,<=4.43.4
-patcher:
-  transformers==4.41.2 (chatglm)
 """

 from .cli import VERSION
@@ -535,10 +535,6 @@ register_model_group(
         DownloadSource.DEFAULT: "google/gemma-2-2b",
         DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b",
     },
-    "Gemma-2-2B-Chat": {
-        DownloadSource.DEFAULT: "google/gemma-2-2b-it",
-        DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
-    },
     "Gemma-2-9B": {
         DownloadSource.DEFAULT: "google/gemma-2-9b",
         DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",

@@ -547,6 +543,10 @@ register_model_group(
         DownloadSource.DEFAULT: "google/gemma-2-27b",
         DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
     },
+    "Gemma-2-2B-Chat": {
+        DownloadSource.DEFAULT: "google/gemma-2-2b-it",
+        DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
+    },
     "Gemma-2-9B-Chat": {
         DownloadSource.DEFAULT: "google/gemma-2-9b-it",
         DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
@@ -79,11 +79,11 @@ def check_dependencies() -> None:
     if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
+        require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
-        require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
+        require_version("datasets>=2.16.0,<=2.20.0", "To fix: pip install datasets>=2.16.0,<=2.20.0")
-        require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
+        require_version("accelerate>=0.30.1,<=0.32.0", "To fix: pip install accelerate>=0.30.1,<=0.32.0")
-        require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
+        require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
-        require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
+        require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")


 def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
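
As a usage note (not part of the diff): `require_version` from transformers.utils.versions accepts a comma-separated specifier and raises with the supplied hint appended when the installed package falls outside the range, so each pinned call above fails fast with a copy-pastable fix. A minimal sketch:

    from transformers.utils.versions import require_version

    # Raises (with the hint included in the message) if the installed transformers is outside the range.
    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
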
@@ -70,6 +70,11 @@ def is_starlette_available():
     return _is_package_available("sse_starlette")


+@lru_cache
+def is_transformers_version_greater_than_4_43():
+    return _get_package_version("transformers") >= version.parse("4.43.0")
+
+
 def is_uvicorn_available():
     return _is_package_available("uvicorn")
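
The new helper relies on the module's existing `_get_package_version` and `version` imports. A self-contained approximation of the same check (assumed equivalents, not the repository's exact internals) looks like this:

    from functools import lru_cache
    from importlib.metadata import PackageNotFoundError
    from importlib.metadata import version as get_installed_version

    from packaging import version


    @lru_cache
    def is_transformers_version_greater_than_4_43() -> bool:
        # Cached so repeated callers do not re-read package metadata.
        try:
            return version.parse(get_installed_version("transformers")) >= version.parse("4.43.0")
        except PackageNotFoundError:
            return False
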
@@ -36,7 +36,7 @@ def configure_attn_implementation(
         if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
             if is_flash_attn_2_available():
                 require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
-                require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
+                require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
                 logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                 model_args.flash_attn = "fa2"
             else:
@@ -35,6 +35,7 @@ from transformers.utils.versions import require_version

 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_greater_than_4_43


 if TYPE_CHECKING:
@@ -50,14 +51,15 @@ transformers_logger = logging.get_logger(__name__)
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_attention_forward(
     self: "LlamaAttention",
-    hidden_states: torch.Tensor,
+    hidden_states: "torch.Tensor",
-    attention_mask: Optional[torch.Tensor] = None,
+    attention_mask: Optional["torch.Tensor"] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     bsz, q_len, _ = hidden_states.size()

     query_states: "torch.Tensor" = self.q_proj(hidden_states)
@@ -68,7 +70,11 @@ def llama_attention_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

     if past_key_value is not None:
@@ -130,14 +136,15 @@ def llama_attention_forward(
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_flash_attention_2_forward(
     self: "LlamaFlashAttention2",
-    hidden_states: torch.Tensor,
+    hidden_states: "torch.Tensor",
-    attention_mask: Optional[torch.Tensor] = None,
+    attention_mask: Optional["torch.Tensor"] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     # LlamaFlashAttention2 attention does not support output_attentions
     output_attentions = False

@@ -151,7 +158,11 @@ def llama_flash_attention_2_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

     if past_key_value is not None:
@@ -198,9 +209,24 @@ def llama_flash_attention_2_forward(
     if attention_mask is not None:
         attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    attn_output: "torch.Tensor" = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
-    )
+    if is_transformers_version_greater_than_4_43():
+        from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+        attn_output: "torch.Tensor" = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            query_states.size(1),
+            dropout=dropout_rate,
+            sliding_window=getattr(self, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+    else:
+        attn_output: "torch.Tensor" = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
+        )

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
         attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
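
The branch above is needed because transformers >= 4.43 exposes `_flash_attention_forward` as a module-level function in transformers.modeling_flash_attention_utils, while older releases provide it as a method on the attention class (the `self._flash_attention_forward` call kept in the else branch). A quick illustrative probe to see which path a given environment takes:

    import transformers
    from packaging import version

    # Mirrors the check added in this commit: >= 4.43.0 selects the module-level function.
    if version.parse(transformers.__version__) >= version.parse("4.43.0"):
        print(transformers.__version__, "-> transformers.modeling_flash_attention_utils._flash_attention_forward")
    else:
        print(transformers.__version__, "-> self._flash_attention_forward")
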
@@ -225,14 +251,15 @@ def llama_flash_attention_2_forward(
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_sdpa_attention_forward(
     self: "LlamaSdpaAttention",
-    hidden_states: torch.Tensor,
+    hidden_states: "torch.Tensor",
-    attention_mask: Optional[torch.Tensor] = None,
+    attention_mask: Optional["torch.Tensor"] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     if output_attentions:
         transformers_logger.warning_once(
             "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@@ -258,7 +285,11 @@ def llama_sdpa_attention_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

     if past_key_value is not None:
@@ -322,7 +353,7 @@ def llama_sdpa_attention_forward(


 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
+    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@@ -41,11 +41,11 @@ from typing import TYPE_CHECKING, Tuple

 import torch
 import torch.nn.functional as F
-import transformers.models
 from transformers.utils.versions import require_version

 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_greater_than_4_43


 if TYPE_CHECKING:
@@ -114,7 +114,15 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:


 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
+    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+    if is_transformers_version_greater_than_4_43():
+        import transformers.modeling_flash_attention_utils
+
+        transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
+        return
+
+    import transformers.models
+
     if model_type == "cohere":
         transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
     elif model_type == "falcon":
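
The patch works by monkey-patching the `_get_unpad_data` hook that transformers' flash-attention path calls, so packed inputs are handled with block-diagonal attention; on transformers >= 4.43 the hook lives in one shared module, which is why a single assignment plus an early return replaces the per-model patching. A minimal sketch of the same monkey-patching idea, with a hypothetical wrapper function and assuming transformers >= 4.43 is installed:

    import transformers.modeling_flash_attention_utils as fa_utils

    _original_get_unpad_data = fa_utils._get_unpad_data  # keep a handle to the library hook


    def logged_get_unpad_data(attention_mask):
        # Hypothetical stand-in: log the call, then defer to the original implementation.
        print("unpadding mask of shape", tuple(attention_mask.shape))
        return _original_get_unpad_data(attention_mask)


    fa_utils._get_unpad_data = logged_get_unpad_data  # flash-attn models now route through the wrapper
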
@@ -162,11 +162,12 @@ class PissaConvertCallback(TrainerCallback):
             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
             model.save_pretrained(
                 pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
-            )
+            )  # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
             model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
             model.set_adapter("default")
-            if "pissa_init" in model.peft_config.keys():
+            if "pissa_init" in model.peft_config.keys():  # backward compatibility (peft<0.12.0)
                 model.delete_adapter("pissa_init")

             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
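
The TODO marks the forward path: with peft >= 0.12.0, the `convert_pissa_to_lora` argument of `save_pretrained` is superseded by `path_initial_model_for_weight_conversion`, so the call would become roughly the following sketch (not part of this commit; same variables as in the hunk above, argument name taken from the TODO):

    model.save_pretrained(
        pissa_convert_dir,
        safe_serialization=args.save_safetensors,
        path_initial_model_for_weight_conversion=pissa_init_dir,  # peft>=0.12.0 replacement for convert_pissa_to_lora
    )
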
@@ -71,7 +71,7 @@ def create_web_demo() -> "gr.Blocks":
     engine = Engine(pure_chat=True)

     with gr.Blocks(title="Web Demo", css=CSS) as demo:
-        lang = gr.Dropdown(choices=["en", "zh"])
+        lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1)
         engine.manager.add_elems("top", dict(lang=lang))

         _, _, chat_elems = create_chat_box(engine, visible=True)
@@ -362,7 +362,6 @@ LOCALES = {
             "label": "학습률",
             "info": "AdamW의 초기 학습률.",
         },
-
     },
     "num_train_epochs": {
         "en": {