diff --git a/README.md b/README.md
index 7e5359b6..4f0ff5b6 100644
--- a/README.md
+++ b/README.md
@@ -128,6 +128,7 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
 ### Dependence Installation (optional)
 
 ```bash
+git lfs install
 git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 conda create -n llama_etuning python=3.10
 conda activate llama_etuning
diff --git a/README_zh.md b/README_zh.md
index 2e44197a..88246178 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -128,6 +128,7 @@ huggingface-cli login
 ### 环境搭建(可跳过)
 
 ```bash
+git lfs install
 git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 conda create -n llama_etuning python=3.10
 conda activate llama_etuning
diff --git a/src/llmtuner/extras/logging.py b/src/llmtuner/extras/logging.py
index 4b4f647e..0b1a68f6 100644
--- a/src/llmtuner/extras/logging.py
+++ b/src/llmtuner/extras/logging.py
@@ -16,8 +16,16 @@ class LoggerHandler(logging.Handler):
         self.log += "\n\n"
 
 
-def get_logger(name: str) -> logging.Logger:
+def reset_logging():
+    r"""
+    Removes the basic config of the root logger.
+    """
+    root = logging.getLogger()
+    list(map(root.removeHandler, root.handlers[:]))
+    list(map(root.removeFilter, root.filters[:]))
+
+def get_logger(name: str) -> logging.Logger:
     formatter = logging.Formatter(
         fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S"
     )
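Aside, not part of the patch: `reset_logging` strips every handler and filter that earlier code left on the root logger, for example a stray `logging.basicConfig` call made by a dependency, which can otherwise surface each record from the project's child loggers twice. Iterating over list copies (`[:]`) matters here, because removing entries while mapping over the live list would skip every other one. A minimal stdlib-only sketch of the effect:

```python
import logging

logging.basicConfig(level=logging.INFO)   # a dependency configures the root logger
root = logging.getLogger()
assert len(root.handlers) == 1            # root would now duplicate child-logger output

# What reset_logging() does: detach root handlers/filters, iterating over copies.
list(map(root.removeHandler, root.handlers[:]))
list(map(root.removeFilter, root.filters[:]))
assert len(root.handlers) == 0            # child loggers log once again
```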
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 3997c7ef..d4ce6e50 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -15,7 +15,7 @@ from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead
 
-from llmtuner.extras.logging import get_logger
+from llmtuner.extras.logging import reset_logging, get_logger
 from llmtuner.extras.misc import count_parameters, prepare_model_for_training
 from llmtuner.extras.save_and_load import load_valuehead_params
 from llmtuner.hparams import FinetuningArguments
@@ -95,7 +95,10 @@ def load_model_and_tokenizer(
         is_mergeable = False
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if model_args.quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
+    if (
+        model_args.quantization_bit is not None
+        or (os.environ.get("LOCAL_RANK") is not None and not is_deepspeed_zero3_enabled())
+    ):
         config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
 
     if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
@@ -126,6 +129,7 @@ def load_model_and_tokenizer(
 
     if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        reset_logging()
 
     if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
         logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index f9f38058..d872afcc 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -85,6 +85,9 @@ def get_train_args(
     assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
         "Streaming mode does not support evaluation currently."
 
+    assert not (general_args.stage == "ppo" and data_args.streaming), \
+        "Streaming mode does not support PPO training currently."
+
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
             assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
@@ -107,8 +110,8 @@ def get_train_args(
     training_args.ddp_find_unused_parameters = False
 
     if data_args.max_samples is not None and data_args.streaming:
-        logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.")
-        data_args.streaming = False
+        logger.warning("`max_samples` is incompatible with `streaming`. Disabling `max_samples`.")
+        data_args.max_samples = None
 
     if data_args.dev_ratio > 1e-6 and data_args.streaming:
         logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
diff --git a/src/llmtuner/tuner/core/trainer.py b/src/llmtuner/tuner/core/trainer.py
index 805a3553..3d6a2d4b 100644
--- a/src/llmtuner/tuner/core/trainer.py
+++ b/src/llmtuner/tuner/core/trainer.py
@@ -47,20 +47,19 @@ class PeftTrainer(Seq2SeqTrainer):
         logger.info(f"Saving model checkpoint to {output_dir}")
         model = unwrap_model(self.model)
-        state_dict = state_dict or get_state_dict(model)
 
         if isinstance(model, PreTrainedModelWrapper):
-            model_params, v_head_params = {}, {}
-            for name in state_dict.keys():
-                if name.startswith("pretrained_model."):
-                    model_params[name.replace("pretrained_model.", "")] = state_dict[name]
-                elif name.startswith("v_head."):
-                    v_head_params[name.replace("v_head.", "")] = state_dict[name]
+            # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
+            model_state_dict = state_dict or model.state_dict()
+            v_head_state_dict = {
+                name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
+                for name in model_state_dict.keys() if name.startswith("v_head.")
+            }
 
-            torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
-            state_dict = model_params
+            torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
             model = model.pretrained_model
 
+        state_dict = state_dict or get_state_dict(model)
 
         if isinstance(model, (PeftModel, PreTrainedModel)):
             model.config.use_cache = True
             model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
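Aside, not part of the patch: because the `v_head.` prefix is stripped before `torch.save`, the stored weights can later be loaded straight into the wrapper's `v_head` submodule. A hypothetical sketch of that restore path; this is not the project's `load_valuehead_params` implementation, and the file-name constant is an assumption:

```python
import os

import torch
from trl import AutoModelForCausalLMWithValueHead

# Assumed value; the project defines VALUE_HEAD_FILE_NAME in its extras module.
VALUE_HEAD_FILE_NAME = "value_head.bin"

def load_value_head(model: AutoModelForCausalLMWithValueHead, checkpoint_dir: str) -> None:
    # Keys were saved with the "v_head." prefix stripped (see the dict
    # comprehension in PeftTrainer._save), so they map directly onto the
    # wrapper's `v_head` submodule.
    v_head_state_dict = torch.load(
        os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME), map_location="cpu"
    )
    model.v_head.load_state_dict(v_head_state_dict)
```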
diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index f28cb93f..6c7769ef 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -11,7 +11,7 @@ from trl import PPOTrainer
 from trl.core import LengthSampler
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import AverageMeter, get_logits_processor
+from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
 
@@ -29,6 +29,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     r"""
     Inherits PPOTrainer.
     """
+
     def __init__(
         self,
         training_args: "Seq2SeqTrainingArguments",
@@ -70,7 +71,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
         logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
         logger.info(f"  Total optimization steps = {max_steps}")
-        logger.info(f"  Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
+        logger.info(f"  Number of trainable parameters = {count_parameters(self.model)[0]}")
 
         # Keyword arguments for `model.generate`
         gen_kwargs = {
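Aside, not part of the patch: `count_parameters` comes from `llmtuner.extras.misc` (also imported in `loader.py` above), and the call site takes element `[0]`, so the helper evidently returns the trainable count first. A sketch consistent with that contract; the real implementation may handle additional cases:

```python
from typing import Tuple

import torch.nn as nn

def count_parameters(model: nn.Module) -> Tuple[int, int]:
    # Returns (trainable parameters, total parameters); the PPO trainer
    # above logs only the first element.
    trainable_params, all_params = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        all_params += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_params
```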