update examples
This commit is contained in:
parent
09735ed30c
commit
cce52351b5
|
@ -3,41 +3,42 @@ We provide diverse examples about fine-tuning LLMs.
|
||||||
```
|
```
|
||||||
examples/
|
examples/
|
||||||
├── lora_single_gpu/
|
├── lora_single_gpu/
|
||||||
│ ├── pretrain.sh: Do pre-training
|
│ ├── pretrain.sh: Do pre-training using LoRA
|
||||||
│ ├── sft.sh: Do supervised fine-tuning
|
│ ├── sft.sh: Do supervised fine-tuning using LoRA
|
||||||
│ ├── reward.sh: Do reward modeling
|
│ ├── reward.sh: Do reward modeling using LoRA
|
||||||
│ ├── ppo.sh: Do PPO training
|
│ ├── ppo.sh: Do PPO training using LoRA
|
||||||
│ ├── dpo.sh: Do DPO training
|
│ ├── dpo.sh: Do DPO training using LoRA
|
||||||
│ ├── orpo.sh: Do ORPO training
|
│ ├── orpo.sh: Do ORPO training using LoRA
|
||||||
│ ├── prepare.sh: Save tokenized dataset
|
│ ├── prepare.sh: Save tokenized dataset
|
||||||
│ └── predict.sh: Do batch predict
|
│ └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after LoRA tuning
|
||||||
├── qlora_single_gpu/
|
├── qlora_single_gpu/
|
||||||
│ ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models
|
│ ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models using QLoRA
|
||||||
│ ├── gptq.sh: Fine-tune 4/8-bit GPTQ models
|
│ ├── gptq.sh: Fine-tune 4/8-bit GPTQ models using QLoRA
|
||||||
│ ├── awq.sh: Fine-tune 4-bit AWQ models
|
│ ├── awq.sh: Fine-tune 4-bit AWQ models using QLoRA
|
||||||
│ └── aqlm.sh: Fine-tune 2-bit AQLM models
|
│ └── aqlm.sh: Fine-tune 2-bit AQLM models using QLoRA
|
||||||
├── lora_multi_gpu/
|
├── lora_multi_gpu/
|
||||||
│ ├── single_node.sh: Fine-tune model with Accelerate on single node
|
│ ├── single_node.sh: Fine-tune model with Accelerate on single node using LoRA
|
||||||
│ └── multi_node.sh: Fine-tune model with Accelerate on multiple nodes
|
│ └── multi_node.sh: Fine-tune model with Accelerate on multiple nodes using LoRA
|
||||||
├── full_multi_gpu/
|
├── full_multi_gpu/
|
||||||
│ ├── single_node.sh: Fine-tune model with DeepSpeed on single node
|
│ ├── single_node.sh: Full fine-tune model with DeepSpeed on single node
|
||||||
│ └── multi_node.sh: Fine-tune model with DeepSpeed on multiple nodes
|
│ ├── multi_node.sh: Full fine-tune model with DeepSpeed on multiple nodes
|
||||||
|
│ └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after full tuning
|
||||||
├── merge_lora/
|
├── merge_lora/
|
||||||
│ ├── merge.sh: Merge LoRA weights into the pre-trained models
|
│ ├── merge.sh: Merge LoRA weights into the pre-trained models
|
||||||
│ └── quantize.sh: Quantize fine-tuned model with AutoGPTQ
|
│ └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ
|
||||||
├── inference/
|
├── inference/
|
||||||
│ ├── cli_demo.sh: Launch a command line interface
|
│ ├── cli_demo.sh: Launch a command line interface with LoRA adapters
|
||||||
│ ├── api_demo.sh: Launch an OpenAI-style API
|
│ ├── api_demo.sh: Launch an OpenAI-style API with LoRA adapters
|
||||||
│ ├── web_demo.sh: Launch a web interface
|
│ ├── web_demo.sh: Launch a web interface with LoRA adapters
|
||||||
│ └── evaluate.sh: Evaluate model on the MMLU benchmark
|
│ └── evaluate.sh: Evaluate model on the MMLU/CMMLU/C-Eval benchmarks with LoRA adapters
|
||||||
└── extras/
|
└── extras/
|
||||||
├── galore/
|
├── galore/
|
||||||
│ └── sft.sh: Fine-tune model with GaLore
|
│ └── sft.sh: Fine-tune model with GaLore
|
||||||
├── loraplus/
|
├── loraplus/
|
||||||
│ └── sft.sh: Fine-tune model with LoRA+
|
│ └── sft.sh: Fine-tune model using LoRA+
|
||||||
├── llama_pro/
|
├── llama_pro/
|
||||||
│ ├── expand.sh: Expand layers in the model
|
│ ├── expand.sh: Expand layers in the model
|
||||||
│ └── sft.sh: Fine-tune expanded model
|
│ └── sft.sh: Fine-tune the expanded model
|
||||||
└── fsdp_qlora/
|
└── fsdp_qlora/
|
||||||
└── sft.sh: Fine-tune quantized model with FSDP
|
└── sft.sh: Fine-tune quantized model with FSDP+QLoRA
|
||||||
```
|
```
|
||||||
|
|
|
@ -1,36 +1,36 @@
|
||||||
我们提供了多样化的示例脚本。
|
我们提供了多样化的大模型微调示例脚本。
|
||||||
|
|
||||||
```
|
```
|
||||||
examples/
|
examples/
|
||||||
├── lora_single_gpu/
|
├── lora_single_gpu/
|
||||||
│ ├── pretrain.sh: 进行预训练
|
│ ├── pretrain.sh: 基于 LoRA 进行预训练
|
||||||
│ ├── sft.sh: 进行指令监督微调
|
│ ├── sft.sh: 基于 LoRA 进行指令监督微调
|
||||||
│ ├── reward.sh: 进行奖励模型训练
|
│ ├── reward.sh: 基于 LoRA 进行奖励模型训练
|
||||||
│ ├── ppo.sh: 进行 PPO 训练
|
│ ├── ppo.sh: 基于 LoRA 进行 PPO 训练
|
||||||
│ ├── dpo.sh: 进行 DPO 训练
|
│ ├── dpo.sh: 基于 LoRA 进行 DPO 训练
|
||||||
│ ├── orpo.sh: 进行 ORPO 训练
|
│ ├── orpo.sh: 基于 LoRA 进行 ORPO 训练
|
||||||
│ ├── prepare.sh: 保存预处理后的数据集
|
│ ├── prepare.sh: 保存预处理后的数据集
|
||||||
│ └── predict.sh: 进行批量预测
|
│ └── predict.sh: 基于 LoRA 进行批量预测并计算 BLEU 和 ROUGE 分数
|
||||||
├── qlora_single_gpu/
|
├── qlora_single_gpu/
|
||||||
│ ├── bitsandbytes.sh: 微调 4/8 比特 BNB 模型
|
│ ├── bitsandbytes.sh: 基于 QLoRA 微调 4/8 比特 BNB 模型
|
||||||
│ ├── gptq.sh: 微调 4/8 比特 GPTQ 模型
|
│ ├── gptq.sh: 基于 QLoRA 微调 4/8 比特 GPTQ 模型
|
||||||
│ ├── awq.sh: 微调 4 比特 AWQ 模型
|
│ ├── awq.sh: 基于 QLoRA 微调 4 比特 AWQ 模型
|
||||||
│ └── aqlm.sh: 微调 2 比特 AQLM 模型
|
│ └── aqlm.sh: 基于 QLoRA 微调 2 比特 AQLM 模型
|
||||||
├── lora_multi_gpu/
|
├── lora_multi_gpu/
|
||||||
│ ├── single_node.sh: 使用 Accelerate 进行单节点训练
|
│ ├── single_node.sh: 使用 Accelerate 进行单节点 LoRA 训练
|
||||||
│ └── multi_node.sh: 使用 Accelerate 进行多节点训练
|
│ └── multi_node.sh: 使用 Accelerate 进行多节点 LoRA 训练
|
||||||
├── full_multi_gpu/
|
├── full_multi_gpu/
|
||||||
│ ├── single_node.sh: 使用 DeepSpeed 进行单节点训练
|
│ ├── single_node.sh: 使用 DeepSpeed 进行单节点全量训练
|
||||||
│ └── multi_node.sh: 使用 DeepSpeed 进行多节点训练
|
│ ├── multi_node.sh: 使用 DeepSpeed 进行多节点全量训练
|
||||||
| └── predict.sh: 使用单卡做全参批量预测
|
│ └── predict.sh: 基于全量训练进行批量预测并计算 BLEU 和 ROUGE 分数
|
||||||
├── merge_lora/
|
├── merge_lora/
|
||||||
│ ├── merge.sh: 将 LoRA 权重合并到预训练模型中
|
│ ├── merge.sh: 将 LoRA 权重合并到预训练模型中
|
||||||
│ └── quantize.sh: 使用 AutoGPTQ 量化模型
|
│ └── quantize.sh: 使用 AutoGPTQ 量化微调后的模型
|
||||||
├── inference/
|
├── inference/
|
||||||
│ ├── cli_demo.sh: 启动命令行推理接口
|
│ ├── cli_demo.sh: 启动 LoRA 模型的命令行推理接口
|
||||||
│ ├── api_demo.sh: 启动 OpenAI 风格 API
|
│ ├── api_demo.sh: 启动 LoRA 模型的 OpenAI 风格 API
|
||||||
│ ├── web_demo.sh: 启动浏览器推理接口
|
│ ├── web_demo.sh: 启动 LoRA 模型的浏览器推理接口
|
||||||
│ └── evaluate.sh: 在 MMLU 数据集上评测模型
|
│ └── evaluate.sh: 在 MMLU/CMMLU/C-Eval 数据集上评测 LoRA 模型
|
||||||
└── extras/
|
└── extras/
|
||||||
├── galore/
|
├── galore/
|
||||||
│ └── sft.sh: 使用 GaLore 训练模型
|
│ └── sft.sh: 使用 GaLore 训练模型
|
||||||
|
@ -40,5 +40,5 @@ examples/
|
||||||
│ ├── expand.sh: 扩展模型中的层
|
│ ├── expand.sh: 扩展模型中的层
|
||||||
│ └── sft.sh: 训练扩展后的模型
|
│ └── sft.sh: 训练扩展后的模型
|
||||||
└── fsdp_qlora/
|
└── fsdp_qlora/
|
||||||
└── sft.sh: 使用 FSDP 微调量化模型
|
└── sft.sh: 使用 FSDP+QLoRA 微调量化模型
|
||||||
```
|
```
|
||||||
|
|
|
@ -9,6 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
--lora_target q_proj,v_proj \
|
--lora_target q_proj,v_proj \
|
||||||
|
--loraplus_lr_ratio 16.0 \
|
||||||
--output_dir ../../saves/LLaMA2-7B/loraplus/sft \
|
--output_dir ../../saves/LLaMA2-7B/loraplus/sft \
|
||||||
--overwrite_cache \
|
--overwrite_cache \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
|
@ -29,5 +30,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
--max_samples 3000 \
|
--max_samples 3000 \
|
||||||
--val_size 0.1 \
|
--val_size 0.1 \
|
||||||
--plot_loss \
|
--plot_loss \
|
||||||
--fp16 \
|
--fp16
|
||||||
--loraplus_lr_ratio 16.0
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
--stage sft \
|
--stage sft \
|
||||||
--do_predict \
|
--do_predict \
|
||||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
--model_name_or_path ../../saves/LLaMA2-7B/full/sft \
|
||||||
--dataset alpaca_gpt4_en,glaive_toolcall \
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
--dataset_dir ../../data \
|
--dataset_dir ../../data \
|
||||||
--template default \
|
--template default \
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# DO NOT use quantized model or quantization_bit when merging lora weights
|
# DO NOT use quantized model or quantization_bit when merging lora weights
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES= python ../../src/export_model.py \
|
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
|
||||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
--template default \
|
--template default \
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
|
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from transformers.utils.versions import require_version
|
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
from ..extras.misc import get_device_count
|
from ..extras.misc import get_device_count
|
||||||
from ..extras.packages import is_vllm_available
|
from ..extras.packages import is_vllm_available
|
||||||
|
@ -25,7 +23,6 @@ class VllmEngine(BaseEngine):
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
) -> None:
|
) -> None:
|
||||||
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
|
|
||||||
self.can_generate = finetuning_args.stage == "sft"
|
self.can_generate = finetuning_args.stage == "sft"
|
||||||
engine_args = AsyncEngineArgs(
|
engine_args = AsyncEngineArgs(
|
||||||
model=model_args.model_name_or_path,
|
model=model_args.model_name_or_path,
|
||||||
|
|
|
@ -49,10 +49,6 @@ def is_starlette_available():
|
||||||
return _is_package_available("sse_starlette")
|
return _is_package_available("sse_starlette")
|
||||||
|
|
||||||
|
|
||||||
def is_unsloth_available():
|
|
||||||
return _is_package_available("unsloth")
|
|
||||||
|
|
||||||
|
|
||||||
def is_uvicorn_available():
|
def is_uvicorn_available():
|
||||||
return _is_package_available("uvicorn")
|
return _is_package_available("uvicorn")
|
||||||
|
|
||||||
|
|
|
@ -8,10 +8,10 @@ import transformers
|
||||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||||
from transformers.trainer_utils import get_last_checkpoint
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..extras.misc import check_dependencies, get_current_device
|
from ..extras.misc import check_dependencies, get_current_device
|
||||||
from ..extras.packages import is_unsloth_available
|
|
||||||
from .data_args import DataArguments
|
from .data_args import DataArguments
|
||||||
from .evaluation_args import EvaluationArguments
|
from .evaluation_args import EvaluationArguments
|
||||||
from .finetuning_args import FinetuningArguments
|
from .finetuning_args import FinetuningArguments
|
||||||
|
@ -74,6 +74,26 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
|
||||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||||
|
|
||||||
|
|
||||||
|
def _check_extra_dependencies(
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
training_args: Optional["Seq2SeqTrainingArguments"] = None,
|
||||||
|
) -> None:
|
||||||
|
if model_args.use_unsloth:
|
||||||
|
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
|
||||||
|
|
||||||
|
if model_args.infer_backend == "vllm":
|
||||||
|
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
|
||||||
|
|
||||||
|
if finetuning_args.use_galore:
|
||||||
|
require_version("galore_torch", "To fix: pip install galore_torch")
|
||||||
|
|
||||||
|
if training_args is not None and training_args.predict_with_generate:
|
||||||
|
require_version("jieba", "To fix: pip install jieba")
|
||||||
|
require_version("nltk", "To fix: pip install nltk")
|
||||||
|
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
|
||||||
|
|
||||||
|
|
||||||
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||||
return _parse_args(parser, args)
|
return _parse_args(parser, args)
|
||||||
|
@ -131,9 +151,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||||
if training_args.do_train and training_args.predict_with_generate:
|
if training_args.do_train and training_args.predict_with_generate:
|
||||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||||
|
|
||||||
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
|
|
||||||
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
|
|
||||||
|
|
||||||
if finetuning_args.use_dora and model_args.use_unsloth:
|
if finetuning_args.use_dora and model_args.use_unsloth:
|
||||||
raise ValueError("Unsloth does not support DoRA.")
|
raise ValueError("Unsloth does not support DoRA.")
|
||||||
|
|
||||||
|
@ -158,6 +175,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||||
|
|
||||||
_verify_model_args(model_args, finetuning_args)
|
_verify_model_args(model_args, finetuning_args)
|
||||||
|
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
training_args.do_train
|
training_args.do_train
|
||||||
|
@ -277,6 +295,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||||
raise ValueError("vLLM engine does not support RoPE scaling.")
|
raise ValueError("vLLM engine does not support RoPE scaling.")
|
||||||
|
|
||||||
_verify_model_args(model_args, finetuning_args)
|
_verify_model_args(model_args, finetuning_args)
|
||||||
|
_check_extra_dependencies(model_args, finetuning_args)
|
||||||
|
|
||||||
if model_args.export_dir is not None:
|
if model_args.export_dir is not None:
|
||||||
model_args.device_map = {"": torch.device(model_args.export_device)}
|
model_args.device_map = {"": torch.device(model_args.export_device)}
|
||||||
|
@ -298,6 +317,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||||
|
|
||||||
_verify_model_args(model_args, finetuning_args)
|
_verify_model_args(model_args, finetuning_args)
|
||||||
|
_check_extra_dependencies(model_args, finetuning_args)
|
||||||
|
|
||||||
model_args.device_map = "auto"
|
model_args.device_map = "auto"
|
||||||
|
|
||||||
|
|
|
@ -85,7 +85,9 @@ def load_model(
|
||||||
logger.warning("Unsloth does not support loading adapters.")
|
logger.warning("Unsloth does not support loading adapters.")
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
|
init_kwargs["config"] = config
|
||||||
|
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
|
||||||
|
model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(**init_kwargs)
|
||||||
|
|
||||||
patch_model(model, tokenizer, model_args, is_trainable)
|
patch_model(model, tokenizer, model_args, is_trainable)
|
||||||
register_autoclass(config, model, tokenizer)
|
register_autoclass(config, model, tokenizer)
|
||||||
|
|
|
@ -2,7 +2,6 @@ from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers.utils.versions import require_version
|
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
|
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
|
||||||
|
@ -33,10 +32,6 @@ class ComputeMetrics:
|
||||||
r"""
|
r"""
|
||||||
Uses the model predictions to compute metrics.
|
Uses the model predictions to compute metrics.
|
||||||
"""
|
"""
|
||||||
require_version("jieba", "To fix: pip install jieba")
|
|
||||||
require_version("nltk", "To fix: pip install nltk")
|
|
||||||
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
|
|
||||||
|
|
||||||
preds, labels = eval_preds
|
preds, labels = eval_preds
|
||||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ from transformers import Trainer
|
||||||
from transformers.optimization import get_scheduler
|
from transformers.optimization import get_scheduler
|
||||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
from transformers.utils.versions import require_version
|
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..extras.packages import is_galore_available
|
from ..extras.packages import is_galore_available
|
||||||
|
@ -168,8 +167,6 @@ def _create_galore_optimizer(
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
) -> "torch.optim.Optimizer":
|
) -> "torch.optim.Optimizer":
|
||||||
require_version("galore_torch", "To fix: pip install galore_torch")
|
|
||||||
|
|
||||||
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
|
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
|
||||||
galore_targets = find_all_linear_modules(model)
|
galore_targets = find_all_linear_modules(model)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue