forked from p04798526/LLaMA-Factory-Mirror
Merge pull request #4224 from chuan298/main
Implement efficient packing without cross-contamination attention
This commit is contained in:
commit 87d9b2d005
@@ -1,4 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
 #
+# This code is inspired by the OpenAccess AI Collective's axolotl library.
+# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +22,46 @@ import torch
 from transformers import DataCollatorForSeq2Seq
 
 
+def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
+    r"""
+    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
+    while handling packed sequences and transforming the mask to lower triangular form to prevent future peeking.
+
+    e.g.
+    ```
+    [[1, 1, 2, 2, 2, 0]]
+    ```
+    ->
+    ```
+    [
+        [
+            [
+                [o, x, x, x, x, x],
+                [o, o, x, x, x, x],
+                [x, x, o, x, x, x],
+                [x, x, o, o, x, x],
+                [x, x, o, o, o, x],
+                [x, x, x, x, x, x],
+            ]
+        ]
+    ]
+    ```
+    where `o` equals `0.0` and `x` equals `min_dtype`.
+    """
+    bsz, seq_len = attention_mask_with_indices.size()
+    min_dtype = torch.finfo(dtype).min
+    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
+    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+    padding_mask = torch.where(expanded_mask != 0, 1, 0)
+    # Create a block-diagonal mask: a position may only attend to tokens carrying the same sequence index
+    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
+    # Use the lower triangular mask to zero out the upper triangular part
+    attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
+    # Invert the attention mask: allowed positions become 0.0, masked positions become min_dtype
+    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
+    return attention_mask_4d
+
+
 @dataclass
 class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
     r"""
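This collator change is the core of the feature: the attention mask now carries a per-token sequence index, and the 4D mask only lets a token attend to earlier tokens with the same index. Below is a minimal sketch of calling the new function, not part of the diff; the import path is taken from the test file added at the end of this pull request.

```python
import torch

from llamafactory.data.collator import prepare_4d_attention_mask

# One packed window holding two sequences (indices 1 and 2) plus a pad token (0).
attention_mask_with_indices = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(attention_mask_with_indices, torch.float16)

print(mask_4d.shape)         # torch.Size([1, 1, 6, 6])
print((mask_4d == 0).int())  # 1 where attention is allowed: two causal blocks, nothing for the pad token
```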
@@ -160,22 +160,30 @@ def preprocess_packed_supervised_dataset(
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
     knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
     for knapsack in knapsacks:
-        packed_input_ids, packed_labels = [], []
-        for length in knapsack:
+        packed_input_ids, packed_attention_masks, packed_labels = [], [], []
+        for i, length in enumerate(knapsack):
             index = length2indexes[length].pop()
             packed_input_ids += batch_input_ids[index]
             packed_labels += batch_labels[index]
+            if data_args.neat_packing:
+                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
+            else:
+                packed_attention_masks += [1] * len(batch_input_ids[index])
 
         if len(packed_input_ids) < data_args.cutoff_len:
             pad_length = data_args.cutoff_len - len(packed_input_ids)
             packed_input_ids += [tokenizer.pad_token_id] * pad_length
             packed_labels += [IGNORE_INDEX] * pad_length
+            if data_args.neat_packing:
+                packed_attention_masks += [0] * pad_length
+            else:
+                packed_attention_masks += [1] * pad_length  # more efficient flash_attn
 
         if len(packed_input_ids) != data_args.cutoff_len:
             raise ValueError("The length of packed example should be identical to the cutoff length.")
 
         model_inputs["input_ids"].append(packed_input_ids)
-        model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
+        model_inputs["attention_mask"].append(packed_attention_masks)
         model_inputs["labels"].append(packed_labels)
 
     return model_inputs
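The preprocessing change is easiest to see on a toy knapsack. The sketch below uses ad-hoc variable names rather than repository code and mirrors the two branches above: with `neat_packing` the attention mask stores a 1-based sequence index per token and 0 for padding, while the default path keeps the old all-ones mask.

```python
cutoff_len = 8
pad_token_id = 0
sequences = [[11, 12, 13], [21, 22, 23, 24]]  # token ids of two packed examples

packed_input_ids, neat_mask, plain_mask = [], [], []
for i, seq in enumerate(sequences):
    packed_input_ids += seq
    neat_mask += [i + 1] * len(seq)   # neat_packing: per-sequence indices, starting from 1
    plain_mask += [1] * len(seq)      # default packing: one flat segment

pad_length = cutoff_len - len(packed_input_ids)
packed_input_ids += [pad_token_id] * pad_length
neat_mask += [0] * pad_length         # pads are excluded from attention
plain_mask += [1] * pad_length        # kept as 1 for the flash-attn fast path

print(packed_input_ids)  # [11, 12, 13, 21, 22, 23, 24, 0]
print(neat_mask)         # [1, 1, 1, 2, 2, 2, 2, 0]
print(plain_mask)        # [1, 1, 1, 1, 1, 1, 1, 1]
```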
@@ -78,6 +78,18 @@ TRAINING_STAGES = {
 
 STAGES_USE_PAIR_DATA = {"rm", "dpo"}
 
+SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
+    "falcon",
+    "gemma",
+    "gemma2",
+    "llama",
+    "mistral",
+    "phi",
+    "phi3",
+    "qwen2",
+    "starcoder2",
+}
+
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 
 V_HEAD_WEIGHTS_NAME = "value_head.bin"
@@ -83,9 +83,7 @@ class DataArguments:
     )
     ignore_pad_token_for_loss: bool = field(
         default=True,
-        metadata={
-            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
-        },
+        metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."},
     )
     val_size: float = field(
         default=0.0,
@@ -93,9 +91,11 @@ class DataArguments:
     )
     packing: Optional[bool] = field(
         default=None,
-        metadata={
-            "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
-        },
+        metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
+    )
+    neat_packing: bool = field(
+        default=False,
+        metadata={"help": "Enable sequence packing without cross-attention."},
     )
     tool_format: Optional[str] = field(
         default=None,
@@ -112,3 +112,6 @@ class DataArguments:
 
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
+
+        if self.neat_packing and not self.packing:
+            raise ValueError("`neat_packing` requires `packing` is True.")
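A quick sketch of the new flag validation, assuming `DataArguments` is exported from `llamafactory.hparams` as elsewhere in the codebase; the error message is the one added in this hunk.

```python
from llamafactory.hparams import DataArguments

# Valid: neat packing implies packing.
data_args = DataArguments(packing=True, neat_packing=True)

# Invalid: __post_init__ rejects neat_packing without packing.
try:
    DataArguments(packing=False, neat_packing=True)
except ValueError as err:
    print(err)  # `neat_packing` requires `packing` is True.
```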
@@ -226,6 +226,7 @@ class ModelArguments:
         self.compute_dtype: Optional["torch.dtype"] = None
         self.device_map: Optional[Union[str, Dict[str, Any]]] = None
         self.model_max_length: Optional[int] = None
+        self.block_diag_attn: bool = False
 
         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
@@ -253,4 +254,5 @@ class ModelArguments:
         new_arg.compute_dtype = old_arg.compute_dtype
         new_arg.device_map = old_arg.device_map
         new_arg.model_max_length = old_arg.model_max_length
+        new_arg.block_diag_attn = old_arg.block_diag_attn
         return new_arg
@@ -158,6 +158,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
 
+    if finetuning_args.stage != "sft" and data_args.neat_packing:
+        raise ValueError("`neat_packing` cannot be set as True except SFT.")
+
     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
@@ -311,6 +314,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
     model_args.device_map = {"": get_current_device()}
     model_args.model_max_length = data_args.cutoff_len
+    model_args.block_diag_attn = data_args.neat_packing
     data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
 
     # Log on each process the small summary
@@ -0,0 +1,147 @@
+# Copyright 2024 Musab Gultekin and the LlamaFactory team.
+#
+# This code is based on Musab Gultekin's functionary library.
+# https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License
+#
+# Copyright (c) 2023 Musab Gultekin
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from typing import TYPE_CHECKING, Tuple
+
+import torch
+import torch.nn.functional as F
+import transformers.models
+
+from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+    from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
+    r"""
+    Gets the sequence lengths in the current batch.
+
+    e.g.
+    ```
+    [
+        [1, 1, 2, 2, 2, 0],
+        [1, 2, 2, 3, 3, 3],
+    ]
+    ```
+    ->
+    ```
+    [2, 3, 1, 2, 3]
+    ```
+    """
+    bsz = attention_mask.size(0)
+    dtype, device = attention_mask.dtype, attention_mask.device
+    max_num = torch.max(attention_mask)
+    counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
+    for i in range(max_num):
+        counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
+
+    counts = counts.flatten()
+    seqlens = counts[counts.nonzero().squeeze()]
+    return seqlens
+
+
+def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
+    r"""
+    Prepares the indices and seqlens for the flash attn varlen function.
+
+    Returns:
+        indices: indices of non-masked tokens from the flattened sequence.
+        cu_seqlens: the cumulative sequence lengths in the current batch, always starts from 0.
+        max_seqlen_in_batch: the largest seqlen in the current batch.
+
+    e.g.
+    ```
+    [
+        [1, 1, 2, 2, 2, 0],
+        [1, 2, 2, 3, 3, 3],
+    ]
+    ```
+    ->
+    ```
+    [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
+    [0, 2, 5, 6, 8, 11]
+    3
+    ```
+    """
+    seqlens_in_batch = get_seqlens_in_batch(attention_mask)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return indices, cu_seqlens, max_seqlen_in_batch
+
+
+def patch_for_block_diag_attn(model_type: str) -> None:
+    if model_type == "falcon":
+        transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
+    elif model_type == "gemma":
+        transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
+    elif model_type == "gemma2":
+        transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
+    elif model_type == "llama":
+        transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
+    elif model_type == "mistral":
+        transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
+    elif model_type == "phi":
+        transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
+    elif model_type == "phi3":
+        transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
+    elif model_type == "qwen2":
+        transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
+    elif model_type == "starcoder2":
+        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
+
+
+def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.block_diag_attn:
+        return
+
+    model_type = getattr(config, "model_type", None)
+    if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
+        patch_for_block_diag_attn(model_type)
+        logger.info("Using block diagonal attention for sequence packing without cross-attention.")
+    else:
+        raise ValueError("Current model does not support block diagonal attention.")
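With the patch applied, the supported models derive their flash-attn varlen boundaries from the index-valued attention mask instead of a plain 0/1 mask, which is why the supported architectures are enumerated in `SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN`. A small sketch of the unpadding metadata, not part of the diff, reusing the example from the docstrings and the tests added below:

```python
import torch

from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data

# Two packed windows: sequences of lengths (2, 3) plus a pad in row 0, (1, 2, 3) in row 1.
attention_mask = torch.tensor(
    [
        [1, 1, 2, 2, 2, 0],
        [1, 2, 2, 3, 3, 3],
    ]
)
print(get_seqlens_in_batch(attention_mask))  # tensor([2, 3, 1, 2, 3])

indices, cu_seqlens, max_seqlen = get_unpad_data(attention_mask)
print(indices)     # tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]) -- pad position 5 is dropped
print(cu_seqlens)  # tensor([0, 2, 5, 6, 8, 11], dtype=torch.int32) -- segment boundaries for the kernel
print(max_seqlen)  # 3
```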
@@ -29,6 +29,7 @@ from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
+from .model_utils.packing import configure_packing
 from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
@@ -73,6 +74,7 @@ def patch_config(
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
+    configure_packing(config, model_args, is_trainable)
 
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
@@ -95,11 +95,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
 
         with gr.Row():
            with gr.Column():
-                resize_vocab = gr.Checkbox()
                 packing = gr.Checkbox()
+                neat_packing = gr.Checkbox()
 
             with gr.Column():
-                upcast_layernorm = gr.Checkbox()
+                resize_vocab = gr.Checkbox()
                 use_llama_pro = gr.Checkbox()
 
             with gr.Column():
@@ -113,9 +113,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             warmup_steps,
             neftune_alpha,
             optim,
-            resize_vocab,
             packing,
-            upcast_layernorm,
+            neat_packing,
+            resize_vocab,
             use_llama_pro,
             shift_attn,
             report_to,
@@ -129,9 +129,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             warmup_steps=warmup_steps,
             neftune_alpha=neftune_alpha,
             optim=optim,
-            resize_vocab=resize_vocab,
             packing=packing,
-            upcast_layernorm=upcast_layernorm,
+            neat_packing=neat_packing,
+            resize_vocab=resize_vocab,
             use_llama_pro=use_llama_pro,
             shift_attn=shift_attn,
             report_to=report_to,
@@ -494,20 +494,6 @@ LOCALES = {
             "info": "使用的优化器:adamw_torch、adamw_8bit 或 adafactor。",
         },
     },
-    "resize_vocab": {
-        "en": {
-            "label": "Resize token embeddings",
-            "info": "Resize the tokenizer vocab and the embedding layers.",
-        },
-        "ru": {
-            "label": "Изменение размера токенных эмбеддингов",
-            "info": "Изменить размер словаря токенизатора и слоев эмбеддинга.",
-        },
-        "zh": {
-            "label": "更改词表大小",
-            "info": "更改分词器词表和嵌入层的大小。",
-        },
-    },
     "packing": {
         "en": {
             "label": "Pack sequences",
@@ -522,18 +508,32 @@ LOCALES = {
             "info": "将序列打包为等长样本。",
         },
     },
-    "upcast_layernorm": {
-        "en": {
-            "label": "Upcast LayerNorm",
-            "info": "Upcast weights of layernorm in float32.",
-        },
-        "ru": {
-            "label": "Приведение весов LayerNorm",
-            "info": "Приведение весов LayerNorm к float32.",
-        },
-        "zh": {
-            "label": "缩放归一化层",
-            "info": "将归一化层权重缩放至 32 位精度。",
-        },
-    },
+    "neat_packing": {
+        "en": {
+            "label": "Use neat packing",
+            "info": "Avoid cross-attention between packed sequences.",
+        },
+        "ru": {
+            "label": "Используйте аккуратную упаковку",
+            "info": "Избегайте перекрестного внимания между упакованными последовательностями.",
+        },
+        "zh": {
+            "label": "使用无污染打包",
+            "info": "避免打包后的序列产生交叉注意力。",
+        },
+    },
+    "resize_vocab": {
+        "en": {
+            "label": "Resize token embeddings",
+            "info": "Resize the tokenizer vocab and the embedding layers.",
+        },
+        "ru": {
+            "label": "Изменение размера токенных эмбеддингов",
+            "info": "Изменить размер словаря токенизатора и слоев эмбеддинга.",
+        },
+        "zh": {
+            "label": "更改词表大小",
+            "info": "更改分词器词表和嵌入层的大小。",
+        },
+    },
     "use_llama_pro": {
@@ -138,9 +138,9 @@ class Runner:
             warmup_steps=get("train.warmup_steps"),
             neftune_noise_alpha=get("train.neftune_alpha") or None,
             optim=get("train.optim"),
-            resize_vocab=get("train.resize_vocab"),
-            packing=get("train.packing"),
-            upcast_layernorm=get("train.upcast_layernorm"),
+            packing=get("train.packing") or get("train.neat_packing"),
+            neat_packing=get("train.neat_packing"),
+            resize_vocab=get("train.resize_vocab"),
             use_llama_pro=get("train.use_llama_pro"),
             shift_attn=get("train.shift_attn"),
             report_to="all" if get("train.report_to") else "none",
@@ -0,0 +1,56 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from llamafactory.data.collator import prepare_4d_attention_mask
+
+
+def test_4d_attention_mask():
+    o = 0.0
+    x = torch.finfo(torch.float16).min
+    attention_mask_with_indices = torch.tensor(
+        [
+            [1, 1, 2, 2, 2, 0],
+            [1, 2, 2, 3, 3, 3],
+        ]
+    )
+    attention_mask_computed = prepare_4d_attention_mask(attention_mask_with_indices, torch.float16)
+    attention_mask_expected = torch.tensor(
+        [
+            [
+                [
+                    [o, x, x, x, x, x],
+                    [o, o, x, x, x, x],
+                    [x, x, o, x, x, x],
+                    [x, x, o, o, x, x],
+                    [x, x, o, o, o, x],
+                    [x, x, x, x, x, x],
+                ]
+            ],
+            [
+                [
+                    [o, x, x, x, x, x],
+                    [x, o, x, x, x, x],
+                    [x, o, o, x, x, x],
+                    [x, x, x, o, x, x],
+                    [x, x, x, o, o, x],
+                    [x, x, x, o, o, o],
+                ]
+            ],
+        ],
+        dtype=torch.float16,
+    )
+    assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
+    assert torch.all(attention_mask_computed == attention_mask_expected)
@@ -0,0 +1,42 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data
+
+
+def test_get_seqlens_in_batch():
+    attention_mask_with_indices = torch.tensor(
+        [
+            [1, 1, 2, 2, 2, 0],
+            [1, 2, 2, 3, 3, 3],
+        ]
+    )
+    seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices)
+    assert list(seqlens_in_batch.size()) == [5]
+    assert torch.all(seqlens_in_batch == torch.tensor([2, 3, 1, 2, 3]))
+
+
+def test_get_unpad_data():
+    attention_mask_with_indices = torch.tensor(
+        [
+            [1, 1, 2, 2, 2, 0],
+            [1, 2, 2, 3, 3, 3],
+        ]
+    )
+    indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices)
+    assert torch.all(indices == torch.tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]))
+    assert torch.all(cu_seqlens == torch.tensor([0, 2, 5, 6, 8, 11], dtype=torch.int32))
+    assert max_seqlen_in_batch == 3