fix gemma2 attention
parent 7b19e99ed7
commit 2f6af73da2
@@ -12,7 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Level: api, webui > chat, eval, train > data, model > hparams > extras
+r"""
+Efficient fine-tuning of large language models.
+
+Level:
+  api, webui > chat, eval, train > data, model > hparams > extras
+
+Dependency graph:
+  main:
+    transformers>=4.41.2
+    datasets>=2.16.0
+    accelerate>=0.30.1
+    peft>=0.11.1
+    trl>=0.8.6
+  attention:
+    transformers>=4.42.4 (gemma+fa2)
+  longlora:
+    transformers>=4.41.2,<=4.42.4
+  packing:
+    transformers>=4.41.2,<=4.42.4
+  patcher:
+    transformers==4.41.2 (chatglm)
+"""
+
 
 from .cli import VERSION
 
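The new module docstring doubles as a dependency contract: each subsystem names the transformers range it needs, and the hunks below enforce those ranges at runtime with `transformers.utils.versions.require_version`. As a rough illustration of that mechanism (the mapping and helper below are copied from the docstring for this sketch only, they are not project constants):

```python
from transformers.utils.versions import require_version

# Illustrative mapping taken from the docstring above; not a real project constant.
FEATURE_REQUIREMENTS = {
    "attention (gemma2 + fa2)": "transformers>=4.42.4",
    "longlora": "transformers>=4.41.2,<=4.42.4",
    "packing": "transformers>=4.41.2,<=4.42.4",
    "patcher (chatglm)": "transformers==4.41.2",
}


def check_feature(feature: str) -> None:
    """Raise a helpful error if the installed transformers falls outside the pinned range."""
    spec = FEATURE_REQUIREMENTS[feature]
    require_version(spec, "To fix: pip install {}".format(spec))


check_feature("packing")  # no-op when the installed version satisfies the range
```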
@@ -28,11 +28,10 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
 
     e.g.
-    ```
+    ```python
+    # input
     [[1, 1, 2, 2, 2, 0]]
-    ```
-    ->
-    ```
+    # output
     [
       [
         [
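The docstring change above only touches formatting, but the example is the heart of `prepare_4d_attention_mask`: mask values are 1-based packed-sequence ids, 0 marks padding, and the 4D result is block-diagonal over the packed sequences and lower triangular inside each block. A minimal sketch of that idea (the helper name and exact value layout are illustrative, not the repository's implementation):

```python
import torch


def build_4d_packed_mask(attention_mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Toy sketch: block-diagonal, lower-triangular 4D attention mask for packed sequences."""
    bsz, seq_len = attention_mask_with_indices.size()
    min_dtype = torch.finfo(dtype).min
    q_idx = attention_mask_with_indices[:, None, :, None]  # (bsz, 1, seq_len, 1)
    k_idx = attention_mask_with_indices[:, None, None, :]  # (bsz, 1, 1, seq_len)
    same_seq = q_idx.eq(k_idx) & q_idx.ne(0)                # same packed sequence, query is not padding
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # no future peeking
    allowed = same_seq & causal                             # (bsz, 1, seq_len, seq_len)
    zero = torch.zeros((), dtype=dtype)
    neg = torch.full((), min_dtype, dtype=dtype)
    return torch.where(allowed, zero, neg)                  # 0 where attending is allowed, -inf-like elsewhere


mask = torch.tensor([[1, 1, 2, 2, 2, 0]])  # the docstring's input
print(build_4d_packed_mask(mask, torch.float32)[0, 0])
```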
@@ -15,6 +15,7 @@
 from typing import TYPE_CHECKING
 
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+from transformers.utils.versions import require_version
 
 from ...extras.logging import get_logger
 
@@ -31,15 +32,17 @@ logger = get_logger(__name__)
 def configure_attn_implementation(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
-        if model_args.flash_attn == "auto":
-            logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
-            model_args.flash_attn = "disabled"
-        elif model_args.flash_attn != "disabled":
-            logger.warning(
-                "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
-                "Will proceed at your own risk.".format(model_args.flash_attn)
-            )
+    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+        if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
+            if is_flash_attn_2_available():
+                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
+                logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+                model_args.flash_attn = "fa2"
+            else:
+                logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
+                model_args.flash_attn = "disabled"
+        elif model_args.flash_attn == "sdpa":
+            raise ValueError("Gemma-2 should use soft-capping attention, while the SDPA attention is not compatible.")
 
     if model_args.flash_attn == "auto":
         return
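The new branch inverts the old advice: instead of forcing eager attention, Gemma-2 training now prefers FlashAttention-2 when it is available (pinning transformers>=4.42.4 for it), falls back to disabled/eager otherwise, and rejects SDPA outright because it lacks the soft-capping path. A hedged restatement of that decision as a standalone helper, not a call into the real `configure_attn_implementation`:

```python
from transformers.utils import is_flash_attn_2_available


def pick_gemma2_attn(flash_attn: str) -> str:
    """Illustrative restatement of the new Gemma-2 branch (not the project's API)."""
    if flash_attn in ("auto", "fa2"):
        # the commit additionally requires transformers>=4.42.4 before enabling fa2
        return "fa2" if is_flash_attn_2_available() else "disabled"
    if flash_attn == "sdpa":
        raise ValueError("Gemma-2 soft-capping attention is not compatible with SDPA.")
    return flash_attn  # other settings pass through unchanged


print(pick_gemma2_attn("auto"))  # "fa2" if flash-attn 2 is installed, otherwise "disabled"
```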
@@ -326,7 +326,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")
+    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@@ -42,6 +42,7 @@ from typing import TYPE_CHECKING, Tuple
 import torch
 import torch.nn.functional as F
 import transformers.models
+from transformers.utils.versions import require_version
 
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
 from ...extras.logging import get_logger
@@ -61,14 +62,13 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     Gets the sequnce lengths in the current batch.
 
     e.g.
-    ```
+    ```python
+    # input
     [
         [1, 1, 2, 2, 2, 0],
         [1, 2, 2, 3, 3, 3],
     ]
-    ```
-    ->
-    ```
+    # output
     [2, 3, 1, 2, 3]
     ```
     """
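Only the docstring formatting changes here, but the example is worth unpacking: mask values are 1-based packed-sequence ids (0 = padding), and the function returns each packed sequence's length, row by row. A minimal sketch that reproduces the documented input/output pair (names and the explicit loop are illustrative, not the repository's implementation):

```python
import torch


def seqlens_in_batch_sketch(attention_mask: torch.Tensor) -> torch.Tensor:
    """Per-sequence lengths of a packed batch; a sketch, not the repository's code."""
    bsz = attention_mask.size(0)
    max_num = int(attention_mask.max().item())               # largest packed-sequence index in the batch
    counts = torch.zeros((bsz, max_num), dtype=torch.long)
    for i in range(max_num):
        counts[:, i] = (attention_mask == (i + 1)).sum(dim=-1)  # length of sequence i+1 in each row
    counts = counts.flatten()
    return counts[counts.nonzero().squeeze(dim=-1)]          # drop sequence ids absent from a row


mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])
print(seqlens_in_batch_sketch(mask))  # tensor([2, 3, 1, 2, 3])
```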
@@ -94,14 +94,13 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
     max_seqlen_in_batch: the largest seqlen in the current batch.
 
     e.g.
-    ```
+    ```python
+    # input
     [
         [1, 1, 2, 2, 2, 0],
         [1, 2, 2, 3, 3, 3],
     ]
-    ```
-    ->
-    ```
+    # output
     [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
     [0, 2, 5, 6, 8, 11]
     3
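The formatting fix keeps the worked example: for the same packed mask, `get_unpad_data` returns the flat indices of non-padding tokens, the cumulative sequence lengths FlashAttention expects, and the longest sequence length. A self-contained sketch that reproduces those three outputs (the helper below is illustrative; the real code builds on `get_seqlens_in_batch`):

```python
import torch
import torch.nn.functional as F


def unpad_data_sketch(attention_mask: torch.Tensor):
    """Reproduces the documented example; a sketch, not the repository's code."""
    # per-sequence lengths, row by row: [2, 3, 1, 2, 3] for the docstring's input
    seqlens = torch.cat([counts[counts > 0]
                         for counts in (torch.bincount(row[row > 0]) for row in attention_mask)])
    indices = torch.nonzero(attention_mask.flatten()).flatten()  # flat positions of non-padding tokens
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # [0, 2, 5, 6, 8, 11]
    max_seqlen = int(seqlens.max().item())                       # 3
    return indices, cu_seqlens, max_seqlen


mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])
print(unpad_data_sketch(mask))
```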
@@ -114,7 +113,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
     return indices, cu_seqlens, max_seqlen_in_batch
 
 
-def patch_for_block_diag_attn(model_type: str) -> None:
+def _patch_for_block_diag_attn(model_type: str) -> None:
+    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
     if model_type == "cohere":
         transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
     elif model_type == "falcon":
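Renaming the helper to `_patch_for_block_diag_attn` and pinning transformers>=4.41.2,<=4.42.4 does not change the mechanism: the function overrides transformers' private module-level `_get_unpad_data` so that FlashAttention unpads packed batches along the block-diagonal boundaries computed above, and the pin is what keeps that private symbol in place. A hedged sketch of the same pattern, assuming the import path `llamafactory.model.model_utils.packing` and using llama as one of the supported model types:

```python
import transformers.models
from llamafactory.model.model_utils.packing import get_unpad_data  # import path assumed for illustration

# Same monkey-patching pattern as _patch_for_block_diag_attn, shown for a single model type.
# The pinned transformers range is what makes overriding this private helper reasonably safe.
transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
```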
@@ -143,7 +143,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
 
     model_type = getattr(config, "model_type", None)
     if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
-        patch_for_block_diag_attn(model_type)
+        _patch_for_block_diag_attn(model_type)
         logger.info("Using block diagonal attention for sequence packing without cross-attention.")
     else:
         raise ValueError("Current model does not support block diagonal attention.")
@@ -126,7 +126,6 @@ def configure_quantization(
             require_version("autoawq", "To fix: pip install autoawq")
 
         if quant_method == QuantizationMethod.AQLM:
-            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
             require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
             quantization_config["bits"] = 2
 
@@ -21,6 +21,7 @@ from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
+from transformers.utils.versions import require_version
 
 from ..extras.logging import get_logger
 from ..extras.misc import infer_optim_dtype
@@ -88,6 +89,9 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
 
+    if getattr(config, "model_type", None) == "chatglm":
+        require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
+
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())