fix gemma2 attention

commit 2f6af73da2
parent 7b19e99ed7
Author: hiyouga
Date: 2024-07-13 23:33:45 +08:00
7 changed files with 53 additions and 26 deletions

View File

@@ -12,7 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Level: api, webui > chat, eval, train > data, model > hparams > extras
r"""
Efficient fine-tuning of large language models.
Level:
api, webui > chat, eval, train > data, model > hparams > extras
Dependency graph:
main:
transformers>=4.41.2
datasets>=2.16.0
accelerate>=0.30.1
peft>=0.11.1
trl>=0.8.6
attention:
transformers>=4.42.4 (gemma+fa2)
longlora:
transformers>=4.41.2,<=4.42.4
packing:
transformers>=4.41.2,<=4.42.4
patcher:
transformers==4.41.2 (chatglm)
"""
from .cli import VERSION
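
Note: the pins in the new docstring mirror the `require_version` checks used throughout this commit. A minimal sketch of how such a pin is enforced at runtime (the constraint strings come from the docstring above; everything else is illustrative):

```python
# minimal sketch: enforce pins from the docstring at import time
from transformers.utils.versions import require_version

# raises an error carrying the hint message if the installed version falls outside the range
require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
```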

View File

@@ -28,11 +28,10 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
while handling packed sequences and transforming the mask to lower triangular form to prevent future peeking.
e.g.
```
```python
# input
[[1, 1, 2, 2, 2, 0]]
```
->
```
# output
[
[
[

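Since the docstring example is cut off by the hunk boundary, here is a self-contained sketch of the transformation it describes; the function body below is illustrative rather than the repository's verbatim implementation:

```python
# sketch: expand a packed "mask with indices" into a 4D block-diagonal,
# lower-triangular attention mask (illustrative implementation)
import torch

def prepare_4d_attention_mask_sketch(mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    bsz, seq_len = mask_with_indices.size()
    min_dtype = torch.finfo(dtype).min
    # broadcast so that entry [b, 0, i, j] holds the segment index of key position j
    expanded = mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    same_segment = torch.eq(expanded, expanded.transpose(-1, -2))  # block-diagonal part
    non_padding = expanded != 0                                    # drop padded keys
    causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))  # no future peeking
    allowed = same_segment & non_padding & causal
    # allowed positions become 0, everything else the dtype minimum (an additive -inf)
    return torch.where(allowed, torch.tensor(0, dtype=dtype), min_dtype)

print(prepare_4d_attention_mask_sketch(torch.tensor([[1, 1, 2, 2, 2, 0]]), torch.float32))
```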
View File

@@ -15,6 +15,7 @@
from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version
from ...extras.logging import get_logger
@@ -31,15 +32,17 @@ logger = get_logger(__name__)
def configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
if getattr(config, "model_type", None) == "gemma2" and is_trainable: # gemma2 adopts soft-cap attention
if model_args.flash_attn == "auto":
logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
model_args.flash_attn = "disabled"
elif model_args.flash_attn != "disabled":
logger.warning(
"Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
"Will proceed at your own risk.".format(model_args.flash_attn)
)
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
else:
logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
model_args.flash_attn = "disabled"
elif model_args.flash_attn == "sdpa":
raise ValueError("Gemma-2 should use soft-capping attention, while the SDPA attention is not compatible.")
if model_args.flash_attn == "auto":
return
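
For reference, the resolved `flash_attn` value maps onto the standard `attn_implementation` argument of `from_pretrained`. The mapping dictionary and loader wrapper below are a sketch; only the argument name and its values `"eager"`, `"sdpa"`, `"flash_attention_2"` come from the transformers API:

```python
# sketch: how a resolved `flash_attn` setting could translate into the loader call;
# `attn_implementation` and its values are transformers API, the rest is illustrative
from transformers import AutoModelForCausalLM

ATTN_IMPLEMENTATIONS = {"disabled": "eager", "sdpa": "sdpa", "fa2": "flash_attention_2"}

def load_with_attn(model_id: str, flash_attn: str = "auto"):
    kwargs = {}
    if flash_attn != "auto":  # "auto" keeps the transformers default
        kwargs["attn_implementation"] = ATTN_IMPLEMENTATIONS[flash_attn]
    return AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
```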

View File

@@ -326,7 +326,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
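A quick sanity check of this monkey patch (illustrative only, not part of the commit, and assumed to run inside the same module that defines these forwards):

```python
# illustrative: verify the patched forwards are in place by identity comparison
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaFlashAttention2,
    LlamaSdpaAttention,
)

_apply_llama_patch()
assert LlamaAttention.forward is llama_attention_forward
assert LlamaFlashAttention2.forward is llama_flash_attention_2_forward
assert LlamaSdpaAttention.forward is llama_sdpa_attention_forward
```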

View File

@@ -42,6 +42,7 @@ from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
import transformers.models
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger
@@ -61,14 +62,13 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
Gets the sequence lengths in the current batch.
e.g.
```
```python
# input
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
```
->
```
# output
[2, 3, 1, 2, 3]
```
"""
@@ -94,14 +94,13 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
max_seqlen_in_batch: the largest seqlen in the current batch.
e.g.
```
```python
# input
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
```
->
```
# output
[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
[0, 2, 5, 6, 8, 11]
3
@@ -114,7 +113,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
return indices, cu_seqlens, max_seqlen_in_batch
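Likewise, the indices / cu_seqlens / max_seqlen triple from the docstring above can be reconstructed as follows; the names and internals are illustrative, only the expected outputs come from the docstring:

```python
# sketch: derive unpadding indices, cumulative sequence lengths and the max length
import torch
import torch.nn.functional as F

def get_unpad_data_sketch(attention_mask: torch.Tensor):
    # per-sequence lengths in batch order (same idea as the sketch above)
    counts = torch.stack(
        [torch.sum(attention_mask == i + 1, dim=-1) for i in range(int(attention_mask.max()))],
        dim=-1,
    ).flatten()
    seqlens_in_batch = counts[counts.nonzero().squeeze(dim=-1)]
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])
indices, cu_seqlens, max_len = get_unpad_data_sketch(mask)
print(indices.tolist())     # [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
print(cu_seqlens.tolist())  # [0, 2, 5, 6, 8, 11]
print(max_len)              # 3
```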
def patch_for_block_diag_attn(model_type: str) -> None:
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
if model_type == "cohere":
transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
elif model_type == "falcon":
@@ -143,7 +143,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
model_type = getattr(config, "model_type", None)
if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
patch_for_block_diag_attn(model_type)
_patch_for_block_diag_attn(model_type)
logger.info("Using block diagonal attention for sequence packing without cross-attention.")
else:
raise ValueError("Current model does not support block diagonal attention.")

View File

@@ -126,7 +126,6 @@ def configure_quantization(
require_version("autoawq", "To fix: pip install autoawq")
if quant_method == QuantizationMethod.AQLM:
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
quantization_config["bits"] = 2

View File

@@ -21,6 +21,7 @@ from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.misc import infer_optim_dtype
@@ -88,6 +89,9 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
if getattr(config, "model_type", None) == "chatglm":
require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
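
As a closing note, `low_cpu_mem_usage` initializes weights on the meta device, which conflicts with DeepSpeed ZeRO-3's own parameter partitioning, hence the guard above. A sketch of how the prepared kwargs reach the loader (the model path is a placeholder and `init_kwargs` stands in for the dictionary built by `patch_config`):

```python
# sketch: the guarded flag is eventually forwarded to from_pretrained
from transformers import AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled

low_cpu_mem_usage = True  # e.g. the user's setting
init_kwargs = {
    # disabled automatically when ZeRO-3 manages parameter placement itself
    "low_cpu_mem_usage": low_cpu_mem_usage and not is_deepspeed_zero3_enabled(),
}
model = AutoModelForCausalLM.from_pretrained("path/to/model", **init_kwargs)  # placeholder path
```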