support distributed quantized training

commit 4eb17bcf6c (parent 3d8d5ee5d5)

```diff
@@ -9,7 +9,7 @@
 
 ## Changelog
 
-[23/06/03] Now we support quantized training and inference (aka QLoRA). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
+[23/06/03] Now we support quantized training and inference (aka [QLoRA](https://github.com/artidoro/qlora)). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
 
 [23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
 
```
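For context on the QLoRA changelog entry: below is a minimal sketch of what 4-bit loading looks like with the stock `transformers`/`bitsandbytes` API. It is not this repo's exact code, and the model name is simply the one quoted in the changelog above.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Sketch only: standard 4-bit (QLoRA-style) loading, roughly what
# `--quantization_bit 4` enables under the hood.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # weights stay 4-bit, matmuls run in fp16
)
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloomz-7b1-mt",            # model name taken from the changelog entry above
    quantization_config=bnb_config,
)
```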
```diff
@@ -38,7 +38,7 @@ from .config import (
 )
 
 from .other import (
-    get_logger,
+    get_main_logger,
     load_trainable_params,
     load_valuehead_params,
     print_trainable_params,
```

```diff
@@ -53,7 +53,7 @@ require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
 require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
 
 
-logger = get_logger(__name__)
+logger = get_main_logger(__name__)
 
 
 def _init_adapter(
```
```diff
@@ -190,9 +190,10 @@ def load_pretrained(
         else:
             raise NotImplementedError
         is_mergeable = False
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if model_args.quantization_bit is not None or (not is_trainable): # automatically load in CUDA
+    if not is_trainable:
         config_kwargs["device_map"] = "auto"
 
     # Load and prepare pretrained models (without valuehead).
```
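The added `device_map` line is the core of this commit: when launched with `torchrun`, each DDP worker pins the entire quantized model to its own GPU (indexed by `LOCAL_RANK`) instead of auto-sharding it across every visible device in every process. A minimal sketch of the idea using the stock `transformers` API (the model name is a placeholder, not the repo's surrounding code):

```python
import os

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Sketch: under torchrun, each process owns one GPU, identified by LOCAL_RANK.
# Mapping the root module "" to that index places the whole quantized model on
# this rank's device; device_map="auto" would instead spread it over every
# visible GPU in every process.
local_rank = int(os.environ.get("LOCAL_RANK") or 0)
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloomz-7b1-mt",                        # placeholder; any supported causal LM
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map={"": local_rank},                       # "" == the entire model
)
```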
```diff
@@ -288,7 +289,7 @@ def prepare_args(
     logger.info(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
         + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
+    , main_process_only=False)
     logger.info(f"Training/evaluation parameters {training_args}")
 
     # Set seed before initializing model.
```
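The `main_process_only=False` argument matters because the logger now comes from `accelerate.logging` (see the hunks below), which suppresses records on non-main processes by default; passing the flag keeps the per-rank banner so every worker still reports its own rank and device. A short usage sketch of that `accelerate` behavior:

```python
from accelerate.logging import get_logger

logger = get_logger(__name__)

logger.info("printed on the main process only")                 # default: main_process_only=True
logger.info("printed by every rank", main_process_only=False)   # per-process diagnostics, as above
```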
```diff
@@ -10,6 +10,8 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor
 
+from accelerate.logging import get_logger
+
 from peft.utils.other import WEIGHTS_NAME
 
 
```

```diff
@@ -18,17 +20,16 @@ VALUE_HEAD_FILE_NAME = "value_head.bin"
 FINETUNING_ARGS_NAME = "finetuning_args.json"
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__, log_level="INFO")
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
     datefmt="%m/%d/%Y %H:%M:%S",
-    level=logging.INFO,
     handlers=[logging.StreamHandler(sys.stdout)]
 )
 
 
-def get_logger(name: str) -> logging.Logger:
-    return logging.getLogger(name)
+def get_main_logger(name: str) -> logging.Logger:
+    return get_logger(name, log_level="INFO")
 
 
 class AverageMeter:
```
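Renaming the local helper to `get_main_logger` avoids shadowing `accelerate.logging.get_logger`, and callers now receive accelerate's multi-process-aware adapter, so under `torchrun` a message is emitted once rather than once per GPU. A hedged sketch of the effective behavior (standard `accelerate` API, not the repo's exact module):

```python
import logging
import sys

from accelerate.logging import get_logger

# Mirror of the renamed helper: get_logger wraps the stdlib logger in
# accelerate's MultiProcessAdapter, which drops records on non-main ranks
# unless main_process_only=False is passed at the call site.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)

def get_main_logger(name: str):
    return get_logger(name, log_level="INFO")

logger = get_main_logger(__name__)
logger.info("emitted once per run, not once per GPU")
```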
```diff
@@ -57,7 +58,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
-            scores[:, 0] = 1.0
+            scores[..., 0] = 1.0
         return scores
 
 
```
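The switch from `scores[:, 0]` to `scores[..., 0]` makes the NaN/Inf fallback indifferent to how many leading dimensions the scores tensor carries. A small standalone illustration (not tied to this repo's code):

```python
import torch

scores = torch.full((2, 3, 5), float("nan"))  # e.g. an extra beam dimension
scores.zero_()
scores[..., 0] = 1.0   # always hits the last (vocab) axis; scores[:, 0] would hit dim 1 here
print(scores[0, 0])    # tensor([1., 0., 0., 0., 0.])
```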
```diff
@@ -5,9 +5,9 @@ from .data_collator import DynamicDataCollatorWithPadding
 
 from .peft_trainer import PeftTrainer
 
-from .other import get_logger
+from .other import get_main_logger
 
-logger = get_logger(__name__)
+logger = get_main_logger(__name__)
 
 
 class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
```

```diff
@@ -21,7 +21,7 @@ from peft.utils.other import WEIGHTS_NAME
 from .config import FinetuningArguments
 
 from .other import (
-    get_logger,
+    get_main_logger,
     get_state_dict,
     load_trainable_params,
     load_valuehead_params,
```

```diff
@@ -30,7 +30,7 @@ from .other import (
 )
 
 
-logger = get_logger(__name__)
+logger = get_main_logger(__name__)
 
 
 class LogCallback(TrainerCallback):
```

```diff
@@ -16,12 +16,12 @@ from .config import FinetuningArguments
 
 from .other import (
     AverageMeter,
-    get_logger,
+    get_main_logger,
     get_logits_processor
 )
 
 
-logger = get_logger(__name__)
+logger = get_main_logger(__name__)
 
 
 def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
```

```diff
@@ -13,10 +13,10 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 
 from .peft_trainer import PeftTrainer
 
-from .other import get_logger, IGNORE_INDEX
+from .other import get_main_logger, IGNORE_INDEX
 
 
-logger = get_logger(__name__)
+logger = get_main_logger(__name__)
 
 
 @dataclass
```