hiyouga 2024-01-04 23:19:08 +08:00
parent cc275abe09
commit 33f2c0d4f8
2 changed files with 6 additions and 1 deletion


@@ -1,5 +1,6 @@
 import torch
 from typing import TYPE_CHECKING
+from transformers.integrations import is_deepspeed_zero3_enabled
 from peft import PeftModel, TaskType, LoraConfig, get_peft_model
 from llmtuner.extras.logging import get_logger
@@ -71,6 +72,10 @@ def init_adapter(
                 assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
                 is_mergeable = False

+            if is_deepspeed_zero3_enabled():
+                assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
+                is_mergeable = False
+
             if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
                 adapter_to_merge = model_args.adapter_name_or_path[:-1]
                 adapter_to_resume = model_args.adapter_name_or_path[-1]
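
For context on the hunk above: under DeepSpeed ZeRO-3 the model parameters are partitioned across ranks, so merging several LoRA adapters into the base weights is not supported and only a single adapter can be loaded for resuming, mirroring the existing restriction for quantized models. Below is a minimal, self-contained sketch of the resulting selection logic; split_adapters and its parameters (adapter_paths, is_quantized, zero3_enabled, is_trainable, create_new_adapter) are illustrative stand-ins, not the project's actual signature.

from typing import List, Optional, Tuple


def split_adapters(
    adapter_paths: List[str],
    is_quantized: bool,
    zero3_enabled: bool,
    is_trainable: bool,
    create_new_adapter: bool,
) -> Tuple[List[str], Optional[str]]:
    """Hypothetical helper mirroring the adapter-selection logic in the hunk above.

    Returns (adapters_to_merge, adapter_to_resume).
    """
    is_mergeable = True

    if is_quantized:
        # Merging LoRA weights into a quantized model is unstable,
        # so only a single adapter is accepted.
        assert len(adapter_paths) == 1, "Quantized model only accepts a single adapter."
        is_mergeable = False

    if zero3_enabled:
        # New in this commit: DeepSpeed ZeRO-3 shards the parameters across ranks,
        # so merging multiple adapters is likewise disallowed.
        assert len(adapter_paths) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
        is_mergeable = False

    if (is_trainable and not create_new_adapter) or (not is_mergeable):
        # Merge every adapter except the last, then resume training from the last.
        return adapter_paths[:-1], adapter_paths[-1]

    # Otherwise (branch not shown in the hunk) all adapters are presumably
    # merged and none is resumed.
    return adapter_paths, None


if __name__ == "__main__":
    # A single adapter under ZeRO-3: nothing to merge, resume from it directly.
    print(split_adapters(["lora-a"], is_quantized=False, zero3_enabled=True,
                         is_trainable=True, create_new_adapter=False))
    # -> ([], 'lora-a')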


@@ -3,7 +3,7 @@ import math
 import torch
 import random
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 from datasets import load_dataset
 from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase