commit b42a145253
@@ -22,6 +22,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
## Changelog

+[23/10/21] We supported [NEFTune](https://arxiv.org/abs/2310.05914) optimization. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
+
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.

[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.

@@ -22,6 +22,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
## Changelog

+[23/10/21] We supported [NEFTune](https://arxiv.org/abs/2310.05914) optimization. Try the `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
+
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Use the `--shift_attn` argument to enable this feature.

[23/09/23] We integrated the MMLU, C-Eval and CMMLU benchmarks into this project. See [this example](#模型评估) for usage.

@@ -75,6 +75,10 @@ class FinetuningArguments:
        default=0.1,
        metadata={"help": "The beta parameter for the DPO loss."}
    )
+    neftune_noise_alpha: Optional[float] = field(
+        default=None,
+        metadata={"help": "The alpha parameter for the NEFTune noise. By setting this the NEFTune optimization will be activated."}
+    )

    def __post_init__(self):
        if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA

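For context, here is a minimal sketch of how the new field surfaces as the `--neftune_noise_alpha` flag, assuming the dataclass is parsed with `transformers.HfArgumentParser` in the usual Hugging Face style; the stripped-down dataclass below is only a stand-in for the real `FinetuningArguments` shown above:

```python
# Illustrative stand-in (not the project's code): only the new field is kept,
# to show how `--neftune_noise_alpha 5` ends up on the parsed arguments.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class FinetuningArgumentsSketch:
    neftune_noise_alpha: Optional[float] = field(
        default=None,
        metadata={"help": "The alpha parameter for the NEFTune noise."}
    )


parser = HfArgumentParser(FinetuningArgumentsSketch)
(finetuning_args,) = parser.parse_args_into_dataclasses(args=["--neftune_noise_alpha", "5"])
print(finetuning_args.neftune_noise_alpha)  # 5.0 -> the trainer will activate NEFTune
```

A value of `None` (the default) leaves training untouched, so existing configurations keep working.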
@@ -3,8 +3,10 @@ import json
import torch
import numpy as np
import torch.nn as nn
+from functools import wraps
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
-from transformers import Seq2SeqTrainer
+from transformers import Seq2SeqTrainer, PreTrainedModel, Trainer
+from peft import PeftModel

from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger

@@ -21,6 +23,14 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
    """

+    def __init__(self, model: Union["PreTrainedModel", nn.Module] = None, neftune_noise_alpha: Optional[float] = 0, **kwargs):
+        super().__init__(model, **kwargs)
+        self.neftune_noise_alpha = neftune_noise_alpha
+        self._neftune_activated = False
+
+        if self.neftune_noise_alpha:
+            self._activate_neftune(model)
+
    def prediction_step(
        self,
        model: nn.Module,

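The `_activate_neftune` call above is defined in the next hunk: it swaps the embedding layer's `forward` for a noisy version by binding a plain module-level function as a method through the descriptor protocol (`__get__`). A tiny standalone illustration of that binding mechanism, with made-up names:

```python
# Function-to-method binding via __get__, the same trick _toggle_neftune uses
# to attach _neftune_forward_function to the embedding module. Names are made up.
class Greeter:
    def __init__(self, name):
        self.name = name


def shout(self):
    return self.name.upper()


g = Greeter("neftune")
g.hello = shout.__get__(g, Greeter)  # bind `shout` so it receives `g` as `self`
print(g.hello())  # -> NEFTUNE
```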
@@ -99,3 +109,71 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
            for pred, label in zip(decoded_preds, decoded_labels):
                res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
            writer.write("\n".join(res))
+
+    @wraps(Trainer.train)
+    def train(self, *args, **kwargs):
+        output = super().train(*args, **kwargs)
+
+        # After training, make sure to restore the original forward pass method
+        # for the embedding layer.
+        if self.neftune_noise_alpha is not None:
+            self._deactivate_neftune(self.model)
+
+        return output
+
+    def _toggle_neftune(self, model, activate=True):
+        """Toggle NEFTune optimization for a model (i.e. activate or deactivate).
+
+        This optimization is based on this paper: https://arxiv.org/abs/2310.05914
+
+        Parameters:
+            model : PreTrainedModel or PeftModel
+                The model to toggle the noise for.
+            activate : bool, optional (default=True)
+                Whether to activate the noise or not.
+        """
+        if activate == self._neftune_activated:
+            return
+
+        self._neftune_activated = activate
+
+        embeddings = (model.get_input_embeddings() if isinstance(model, PreTrainedModel)
+                      else model.base_model.get_input_embeddings() if isinstance(model, PeftModel)
+                      else None)
+
+        if embeddings:
+            if activate:
+                embeddings.neftune_noise_alpha = self.neftune_noise_alpha
+                embeddings._trl_old_forward = embeddings.forward
+                neftune_method = _neftune_forward_function.__get__(embeddings, embeddings.__class__)
+                setattr(embeddings, "forward", neftune_method)
+                logger.info("NEFTune activated with alpha: {}".format(self.neftune_noise_alpha))
+            elif hasattr(embeddings, "_trl_old_forward"):
+                embeddings.forward = embeddings._trl_old_forward
+                del embeddings._trl_old_forward
+                del embeddings.neftune_noise_alpha
+                logger.info("NEFTune deactivated")
+
+    _activate_neftune = lambda self, model: self._toggle_neftune(model, activate=True)
+    _deactivate_neftune = lambda self, model: self._toggle_neftune(model, activate=False)
+
+
+def _neftune_forward_function(self, input: torch.Tensor) -> torch.Tensor:
+    """
+    This code is adapted from the original source code that can be found here: https://github.com/neelsjain/NEFTune
+    """
+    embeddings = torch.nn.functional.embedding(
+        input,
+        self.weight,
+        self.padding_idx,
+        self.max_norm,
+        self.norm_type,
+        self.scale_grad_by_freq,
+        self.sparse)
+
+    if self.training:
+        dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
+        mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)
+        embeddings += torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
+
+    return embeddings

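The `_neftune_forward_function` above adds uniform noise in `[-alpha / sqrt(L * d), +alpha / sqrt(L * d)]` to the token embeddings, where `L` is the sequence length and `d` the embedding dimension, and only while the module is in training mode. A standalone numeric sanity check of that bound, mirroring the math above (illustrative only, not part of the patch):

```python
# Recompute the NEFTune noise bound for a toy embedding and verify the added
# noise never exceeds alpha / sqrt(L * d). Mirrors _neftune_forward_function.
import torch
import torch.nn as nn

alpha = 5.0
embed = nn.Embedding(num_embeddings=100, embedding_dim=16)
input_ids = torch.randint(0, 100, (2, 8))  # (batch, seq_len)

clean = embed(input_ids)                   # (batch, L, d) = (2, 8, 16)
dims = clean.size(1) * clean.size(2)       # L * d = 128
mag_norm = alpha / dims ** 0.5             # ~0.442
noisy = clean + torch.zeros_like(clean).uniform_(-mag_norm, mag_norm)

assert (noisy - clean).abs().max() <= mag_norm
print(f"max |noise| = {(noisy - clean).abs().max():.4f} <= bound {mag_norm:.4f}")
```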
@@ -53,6 +53,7 @@ def run_sft(
        data_collator=data_collator,
        callbacks=callbacks,
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
+        neftune_noise_alpha=finetuning_args.neftune_noise_alpha,
        **split_dataset(dataset, data_args, training_args)
    )

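Taken together, NEFTune is active only between trainer construction and the end of `train()`: `__init__` patches the embedding's `forward`, and the `train()` override restores it once training finishes. A toy, self-contained sketch of that patch/restore lifecycle on a bare `nn.Embedding` (the simplified `noisy_forward` and all names here are illustrative, not the project's code):

```python
# Patch a bare nn.Embedding the way _activate_neftune does, confirm the output
# is noisy in training mode, then restore the original forward the way the
# train() override does after training. Simplified, illustrative code only.
import torch
import torch.nn as nn


def noisy_forward(self, input_ids):
    out = self._old_forward(input_ids)
    if self.training:
        mag = self.neftune_noise_alpha / (out.size(1) * out.size(2)) ** 0.5
        out = out + torch.zeros_like(out).uniform_(-mag, mag)
    return out


emb = nn.Embedding(10, 4)
ids = torch.tensor([[1, 2, 3]])

emb.neftune_noise_alpha = 5.0
emb._old_forward = emb.forward
emb.forward = noisy_forward.__get__(emb, type(emb))  # activate

assert not torch.equal(emb(ids), emb(ids))           # fresh noise on every call

emb.forward = emb._old_forward                       # deactivate: restore forward
del emb._old_forward, emb.neftune_noise_alpha
assert torch.equal(emb(ids), emb(ids))               # deterministic again
```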