fix RM save model

Author: hiyouga
Date: 2023-08-01 11:56:17 +08:00
Parent: 82e793ddb4
Commit: ac88ce5233
7 changed files with 33 additions and 16 deletions

View File

@@ -128,6 +128,7 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
 ### Dependence Installation (optional)
 
 ```bash
+git lfs install
 git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 conda create -n llama_etuning python=3.10
 conda activate llama_etuning
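A note on the added `git lfs install`: it registers the Git LFS filters once per machine, so that the subsequent `git clone` materializes LFS-tracked weight files instead of tiny pointer stubs. A minimal sketch of the same precaution in Python (illustrative only; the commands are standard git-lfs usage, not part of this commit):

```python
# Illustrative: fail fast if git-lfs is missing before cloning a repo whose
# large files are tracked with LFS.
import shutil
import subprocess

if shutil.which("git-lfs") is None:
    raise SystemExit("git-lfs not found: install it first (e.g. apt install git-lfs)")

subprocess.run(["git", "lfs", "install"], check=True)  # registers LFS filters in ~/.gitconfig
subprocess.run(["git", "clone", "https://github.com/hiyouga/LLaMA-Efficient-Tuning.git"], check=True)
```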

View File

@@ -128,6 +128,7 @@ huggingface-cli login
 ### 环境搭建(可跳过)
 
 ```bash
+git lfs install
 git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 conda create -n llama_etuning python=3.10
 conda activate llama_etuning

(Same change as in the English README above; the heading "环境搭建(可跳过)" means "Environment Setup (optional)".)

View File

@@ -16,8 +16,16 @@ class LoggerHandler(logging.Handler):
         self.log += "\n\n"
 
 
-def get_logger(name: str) -> logging.Logger:
+def reset_logging():
+    r"""
+    Removes basic config of root logger
+    """
+    root = logging.getLogger()
+    list(map(root.removeHandler, root.handlers))
+    list(map(root.removeFilter, root.filters))
+
+
+def get_logger(name: str) -> logging.Logger:
     formatter = logging.Formatter(
         fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S"
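The new `reset_logging` exists because an earlier import may have called `logging.basicConfig`, attaching a handler to the root logger; every record emitted through llmtuner's own handler would then also be printed by the root, i.e. twice. A minimal, self-contained demonstration of the idea (the helper body mirrors the diff; the logger name is arbitrary):

```python
import logging

# Simulate a third-party library that already configured the root logger.
logging.basicConfig(level=logging.INFO)

def reset_logging() -> None:
    # Same idea as the new helper: strip all handlers and filters off the
    # root logger so it no longer re-handles propagated records.
    root = logging.getLogger()
    list(map(root.removeHandler, root.handlers))
    list(map(root.removeFilter, root.filters))

reset_logging()

logger = logging.getLogger("demo")
logger.addHandler(logging.StreamHandler())
logger.info("printed exactly once")  # root has no handler left to duplicate it
```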

View File

@@ -15,7 +15,7 @@ from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead
 
-from llmtuner.extras.logging import get_logger
+from llmtuner.extras.logging import reset_logging, get_logger
 from llmtuner.extras.misc import count_parameters, prepare_model_for_training
 from llmtuner.extras.save_and_load import load_valuehead_params
 from llmtuner.hparams import FinetuningArguments
@@ -95,7 +95,10 @@ def load_model_and_tokenizer(
         is_mergeable = False
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if model_args.quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
+    if (
+        model_args.quantization_bit is not None
+        or (os.environ.get("LOCAL_RANK") is not None and not is_deepspeed_zero3_enabled())
+    ):
         config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
 
     if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
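The reworked condition narrows when a per-rank `device_map` is forced: quantized models still need it, and plain DDP runs still pin one full replica to each `LOCAL_RANK`, but DeepSpeed ZeRO-3 is now excluded because DeepSpeed shards parameters across ranks itself and a fixed device map would conflict with that. A standalone sketch of the same rule, assuming the transformers version of this era (where `is_deepspeed_zero3_enabled` is importable from `transformers.deepspeed`; `quantization_bit` stands in for `model_args.quantization_bit`):

```python
import os

# In transformers ~4.30 this helper lives in transformers.deepspeed
# (later releases moved it to transformers.integrations).
from transformers.deepspeed import is_deepspeed_zero3_enabled

config_kwargs = {}
quantization_bit = None  # stand-in for model_args.quantization_bit

if (
    quantization_bit is not None
    or (os.environ.get("LOCAL_RANK") is not None and not is_deepspeed_zero3_enabled())
):
    # Pin the whole model to this rank's GPU; skipped under ZeRO-3 because
    # DeepSpeed places and partitions parameters on its own.
    config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
```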
@@ -126,6 +129,7 @@ def load_model_and_tokenizer(
     if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        reset_logging()
 
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
             logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")

View File

@@ -85,6 +85,9 @@ def get_train_args(
     assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
         "Streaming mode does not support evaluation currently."
 
+    assert not (general_args.stage == "ppo" and data_args.streaming), \
+        "Streaming mode does not support PPO training currently."
+
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
             assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
@@ -107,8 +110,8 @@ def get_train_args(
         training_args.ddp_find_unused_parameters = False
 
     if data_args.max_samples is not None and data_args.streaming:
-        logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.")
-        data_args.streaming = False
+        logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
+        data_args.max_samples = None
 
     if data_args.dev_ratio > 1e-6 and data_args.streaming:
         logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")

View File

@@ -47,20 +47,19 @@ class PeftTrainer(Seq2SeqTrainer):
         logger.info(f"Saving model checkpoint to {output_dir}")
         model = unwrap_model(self.model)
-        state_dict = state_dict or get_state_dict(model)
 
         if isinstance(model, PreTrainedModelWrapper):
-            model_params, v_head_params = {}, {}
-            for name in state_dict.keys():
-                if name.startswith("pretrained_model."):
-                    model_params[name.replace("pretrained_model.", "")] = state_dict[name]
-                elif name.startswith("v_head."):
-                    v_head_params[name.replace("v_head.", "")] = state_dict[name]
+            # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
+            model_state_dict = state_dict or model.state_dict()
+            v_head_state_dict = {
+                name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
+                for name in model_state_dict.keys() if name.startswith("v_head.")
+            }
 
-            torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
-            state_dict = model_params
+            torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
             model = model.pretrained_model
 
+        state_dict = state_dict or get_state_dict(model)
+
         if isinstance(model, (PeftModel, PreTrainedModel)):
             model.config.use_cache = True
             model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
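The rewrite saves only the `v_head.*` tensors, with the prefix stripped, moved to CPU, and cloned/detached so no autograd graph or GPU storage leaks into the file, and leaves the wrapped model to the ordinary `save_pretrained` path below. For context, a hedged sketch of the matching load side, assuming trl's `ValueHead` layout (`summary.weight` / `summary.bias`) and that `VALUE_HEAD_FILE_NAME` is `"value_head.bin"`; this is illustrative, not necessarily how the repo's `load_valuehead_params` is implemented:

```python
import os
import torch

VALUE_HEAD_FILE_NAME = "value_head.bin"  # assumption: mirrors the repo constant

def load_v_head(model, checkpoint_dir: str) -> None:
    # Keys were saved without the "v_head." prefix, so they map directly onto
    # the wrapper's v_head submodule (a trl ValueHead: summary.weight/bias).
    state = torch.load(os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME), map_location="cpu")
    model.v_head.load_state_dict(state)
```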

View File

@@ -11,7 +11,7 @@ from trl import PPOTrainer
 from trl.core import LengthSampler
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import AverageMeter, get_logits_processor
+from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
@@ -29,6 +29,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     r"""
     Inherits PPOTrainer.
     """
+
     def __init__(
         self,
         training_args: "Seq2SeqTrainingArguments",
@@ -70,7 +71,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
         logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
         logger.info(f" Total optimization steps = {max_steps}")
-        logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
+        logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
 
         # Keyword arguments for `model.generate`
         gen_kwargs = {
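Replacing the inline generator expression with `count_parameters(self.model)[0]` reuses the shared helper imported above. Given that the log line reads index `[0]` for the trainable count, its plausible shape is the following (a sketch, not necessarily the repo's exact code):

```python
import torch

def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts for a module."""
    trainable, total = 0, 0
    for param in model.parameters():
        num = param.numel()
        total += num
        if param.requires_grad:
            trainable += num
    return trainable, total
```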