fix RM save model

This commit is contained in:
hiyouga 2023-08-01 11:56:17 +08:00
parent 82e793ddb4
commit ac88ce5233
7 changed files with 33 additions and 16 deletions

View File

@ -128,6 +128,7 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
### Dependence Installation (optional)
```bash
git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning

View File

@ -128,6 +128,7 @@ huggingface-cli login
### 环境搭建(可跳过)
```bash
git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning

View File

@ -16,8 +16,16 @@ class LoggerHandler(logging.Handler):
self.log += "\n\n"
def get_logger(name: str) -> logging.Logger:
def reset_logging():
r"""
Removes basic config of root logger
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))
def get_logger(name: str) -> logging.Logger:
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S"

View File

@ -15,7 +15,7 @@ from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizerBase
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.logging import get_logger
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import FinetuningArguments
@ -95,7 +95,10 @@ def load_model_and_tokenizer(
is_mergeable = False
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if model_args.quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
if (
model_args.quantization_bit is not None
or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
):
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
@ -126,6 +129,7 @@ def load_model_and_tokenizer(
if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")

View File

@ -85,6 +85,9 @@ def get_train_args(
assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
"Streaming mode does not support evaluation currently."
assert not (general_args.stage == "ppo" and data_args.streaming), \
"Streaming mode does not suppport PPO training currently."
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
@ -107,8 +110,8 @@ def get_train_args(
training_args.ddp_find_unused_parameters = False
if data_args.max_samples is not None and data_args.streaming:
logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.")
data_args.streaming = False
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
data_args.max_samples = None
if data_args.dev_ratio > 1e-6 and data_args.streaming:
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")

View File

@ -47,20 +47,19 @@ class PeftTrainer(Seq2SeqTrainer):
logger.info(f"Saving model checkpoint to {output_dir}")
model = unwrap_model(self.model)
state_dict = state_dict or get_state_dict(model)
if isinstance(model, PreTrainedModelWrapper):
model_params, v_head_params = {}, {}
for name in state_dict.keys():
if name.startswith("pretrained_model."):
model_params[name.replace("pretrained_model.", "")] = state_dict[name]
elif name.startswith("v_head."):
v_head_params[name.replace("v_head.", "")] = state_dict[name]
# Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
model_state_dict = state_dict or model.state_dict()
v_head_state_dict = {
name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
for name in model_state_dict.keys() if name.startswith("v_head.")
}
torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
state_dict = model_params
torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
model = model.pretrained_model
state_dict = state_dict or get_state_dict(model)
if isinstance(model, (PeftModel, PreTrainedModel)):
model.config.use_cache = True
model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)

View File

@ -11,7 +11,7 @@ from trl import PPOTrainer
from trl.core import LengthSampler
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, get_logits_processor
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
@ -29,6 +29,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r"""
Inherits PPOTrainer.
"""
def __init__(
self,
training_args: "Seq2SeqTrainingArguments",
@ -70,7 +71,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
# Keyword arguments for `model.generate`
gen_kwargs = {