fix version checking
parent d1587c80de
commit 3016e65657
Makefile (6 changed lines)
@@ -1,11 +1,11 @@
 .PHONY: quality style
 
-check_dirs := src tests
+check_dirs := scripts src
 
 quality:
-	ruff $(check_dirs)
+	ruff check $(check_dirs)
 	ruff format --check $(check_dirs)
 
 style:
-	ruff $(check_dirs) --fix
+	ruff check $(check_dirs) --fix
 	ruff format $(check_dirs)
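Context for the change: newer ruff releases split linting and formatting into explicit subcommands, and the bare `ruff <paths>` spelling is a deprecated alias of `ruff check <paths>`. A rough Python sketch of what the two targets now run (illustrative only, not part of the commit; assumes `ruff` is on PATH):

```python
# Sketch of the Makefile targets as subprocess calls (not part of the commit).
import subprocess

CHECK_DIRS = ["scripts", "src"]

def quality() -> None:
    # Lint, then verify formatting without rewriting any files.
    subprocess.run(["ruff", "check", *CHECK_DIRS], check=True)
    subprocess.run(["ruff", "format", "--check", *CHECK_DIRS], check=True)

def style() -> None:
    # Apply lint auto-fixes, then rewrite files with the formatter.
    subprocess.run(["ruff", "check", *CHECK_DIRS, "--fix"], check=True)
    subprocess.run(["ruff", "format", *CHECK_DIRS], check=True)
```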
README.md (18 changed lines)
@@ -502,10 +502,13 @@ use_cpu: false
 
 </details>
 
+> [!TIP]
+> We recommend using Accelerate for LoRA tuning.
+
 #### Use DeepSpeed
 
 ```bash
-deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
+deepspeed --num_gpus 8 src/train_bash.py \
     --deepspeed ds_config.json \
     ... # arguments (same as above)
 ```

@@ -522,25 +525,32 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
   "fp16": {
     "enabled": "auto",
     "loss_scale": 0,
-    "initial_scale_power": 16,
     "loss_scale_window": 1000,
+    "initial_scale_power": 16,
     "hysteresis": 2,
     "min_loss_scale": 1
   },
+  "bf16": {
+    "enabled": "auto"
+  },
   "zero_optimization": {
     "stage": 2,
     "allgather_partitions": true,
     "allgather_bucket_size": 5e8,
+    "overlap_comm": true,
     "reduce_scatter": true,
     "reduce_bucket_size": 5e8,
-    "overlap_comm": false,
-    "contiguous_gradients": true
+    "contiguous_gradients": true,
+    "round_robin_gradients": true
   }
 }
 ```
 
 </details>
 
+> [!TIP]
+> Refer to [examples](examples) for more training scripts.
+
 ### Merge LoRA weights and export model
 
 ```bash
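The updated `ds_config.json` keeps both `fp16.enabled` and `bf16.enabled` at `"auto"`; the HF Trainer resolves these from its own `--fp16`/`--bf16` arguments, enabling at most one of them. A small sanity-check sketch (not from the commit; assumes the file sits in the working directory):

```python
# Load and spot-check the DeepSpeed config before launching (illustrative).
import json

with open("ds_config.json") as f:
    cfg = json.load(f)

zero = cfg["zero_optimization"]
assert zero["stage"] == 2  # this config targets ZeRO stage 2
# Both may read "auto" here; the HF Trainer enables at most one at runtime.
print("fp16:", cfg["fp16"]["enabled"], "bf16:", cfg["bf16"]["enabled"])
print("overlap_comm:", zero["overlap_comm"])
```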
README_zh.md (18 changed lines)
@@ -501,10 +501,13 @@ use_cpu: false
 
 </details>
 
+> [!TIP]
+> We recommend using Accelerate for LoRA training.
+
 #### Use DeepSpeed
 
 ```bash
-deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
+deepspeed --num_gpus 8 src/train_bash.py \
     --deepspeed ds_config.json \
     ... # arguments (same as above)
 ```

@@ -521,25 +524,32 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
   "fp16": {
     "enabled": "auto",
     "loss_scale": 0,
-    "initial_scale_power": 16,
     "loss_scale_window": 1000,
+    "initial_scale_power": 16,
     "hysteresis": 2,
     "min_loss_scale": 1
   },
+  "bf16": {
+    "enabled": "auto"
+  },
   "zero_optimization": {
     "stage": 2,
     "allgather_partitions": true,
     "allgather_bucket_size": 5e8,
+    "overlap_comm": true,
     "reduce_scatter": true,
     "reduce_bucket_size": 5e8,
-    "overlap_comm": false,
-    "contiguous_gradients": true
+    "contiguous_gradients": true,
+    "round_robin_gradients": true
   }
 }
 ```
 
 </details>
 
+> [!TIP]
+> See [examples](examples) for more training scripts.
+
 ### Merge LoRA weights and export model
 
 ```bash
pyproject.toml

@@ -28,5 +28,6 @@ known-third-party = [
 [tool.ruff.format]
 quote-style = "double"
 indent-style = "space"
+docstring-code-format = true
 skip-magic-trailing-comma = false
 line-ending = "auto"
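`docstring-code-format = true` makes `ruff format` also reformat code examples embedded in docstrings, such as doctest blocks. A toy illustration (hypothetical function, not repo code):

```python
def scale(x: float, factor: float = 2.0) -> float:
    """Scale a value by a factor.

    With docstring-code-format enabled, ruff reformats the doctest below
    (normalizing quotes, spacing, and line breaks) when `ruff format` runs:

    >>> scale(3.0, factor=1.5)
    4.5
    """
    return x * factor
```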
src/llmtuner/data/formatter.py

@@ -75,8 +75,7 @@ class Formatter(ABC):
     tool_format: Literal["default"] = "default"
 
     @abstractmethod
-    def apply(self, **kwargs) -> SLOTS:
-        ...
+    def apply(self, **kwargs) -> SLOTS: ...
 
     def extract(self, content: str) -> Union[str, Tuple[str, str]]:
         raise NotImplementedError
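Folding the `...` stub onto the `def` line is purely stylistic (the one-line spelling the ruff formatter prefers for empty bodies); `apply` remains abstract and subclasses must still override it. A self-contained toy illustration of the idiom, with simplified types rather than the repo's `SLOTS`:

```python
from abc import ABC, abstractmethod
from typing import List

class Formatter(ABC):
    @abstractmethod
    def apply(self, **kwargs) -> List[str]: ...  # `...` is the entire stub body

class StringFormatter(Formatter):
    """Toy subclass: fills {placeholder} slots from keyword arguments."""

    def __init__(self, slots: List[str]) -> None:
        self.slots = slots

    def apply(self, **kwargs) -> List[str]:
        return [slot.format(**kwargs) for slot in self.slots]

print(StringFormatter(["User: {query}"]).apply(query="hi"))  # ['User: hi']
```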
src/llmtuner/extras/misc.py

@@ -14,6 +14,7 @@ from transformers.utils import (
     is_torch_npu_available,
     is_torch_xpu_available,
 )
+from transformers.utils.versions import require_version
 
 from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from .logging import get_logger

@@ -56,6 +57,17 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
+def check_dependencies() -> None:
+    if int(os.environ.get("DISABLE_VERSION_CHECK", "0")):
+        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
+    else:
+        require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
+        require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
+        require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
+        require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
+        require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
+
+
 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.
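With this function, version checking is governed by an environment variable rather than a CLI flag. A hedged usage sketch (the `llmtuner` import path is inferred from the hunk context, not stated in the diff):

```python
import os

# Must be set before the first llmtuner import: loader.py (further down)
# calls check_dependencies() at module import time.
os.environ["DISABLE_VERSION_CHECK"] = "1"

from llmtuner.extras.misc import check_dependencies  # assumed path

check_dependencies()  # logs a warning instead of enforcing minimum versions
```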
src/llmtuner/hparams/finetuning_args.py

@@ -173,10 +173,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         default=False,
         metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
     )
-    disable_version_checking: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether or not to disable version checking."},
-    )
     plot_loss: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
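Dropping the dataclass field also drops the `--disable_version_checking` flag, since `HfArgumentParser` derives one CLI flag per field. A minimal sketch of that mechanism with a toy dataclass (not repo code):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ToyArguments:
    plot_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to save the training loss curves."},
    )

parser = HfArgumentParser(ToyArguments)
(toy_args,) = parser.parse_args_into_dataclasses(["--plot_loss", "true"])
print(toy_args.plot_loss)  # True
```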
src/llmtuner/hparams/parser.py

@@ -7,7 +7,6 @@ import torch
 import transformers
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils.versions import require_version
 
 from ..extras.logging import get_logger
 from ..extras.packages import is_unsloth_available

@@ -29,17 +28,6 @@ _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 
 
-def _check_dependencies(disabled: bool) -> None:
-    if disabled:
-        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
-    else:
-        require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
-        require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
-        require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
-        require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
-        require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
-
-
 def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)

@@ -152,7 +140,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         raise ValueError("Unsloth does not support DoRA.")
 
     _verify_model_args(model_args, finetuning_args)
-    _check_dependencies(disabled=finetuning_args.disable_version_checking)
 
     if (
         training_args.do_train

@@ -249,7 +236,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
 
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
-    _check_dependencies(disabled=finetuning_args.disable_version_checking)
 
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

@@ -262,7 +248,6 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
 
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
-    _check_dependencies(disabled=finetuning_args.disable_version_checking)
     model_args.aqlm_optimization = True
 
     if data_args.template is None:
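After this cleanup none of the `get_*_args` entry points gate dependency versions themselves; they rely on the import-time `check_dependencies()` call added in `loader.py` below. A hypothetical invocation showing that the public signature is unchanged (keys and values are illustrative, not a tested recipe):

```python
from llmtuner.hparams.parser import get_train_args  # assumed path

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(
    {
        "stage": "sft",
        "do_train": True,
        "model_name_or_path": "meta-llama/Llama-2-7b-hf",
        "dataset": "alpaca_gpt4_en",
        "template": "default",
        "finetuning_type": "lora",
        "output_dir": "out",
    }
)
```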
src/llmtuner/model/loader.py

@@ -5,7 +5,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead
 
 from ..extras.logging import get_logger
-from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
+from ..extras.misc import check_dependencies, count_parameters, get_current_device, try_download_model_from_ms
 from .adapter import init_adapter
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
 from .utils import load_valuehead_params, register_autoclass

@@ -20,6 +20,9 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+check_dependencies()
+
+
 def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     return {
         "trust_remote_code": True,