fix RM save model
This commit is contained in:
parent
82e793ddb4
commit
ac88ce5233
|
@ -128,6 +128,7 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
|
|||
### Dependence Installation (optional)
|
||||
|
||||
```bash
|
||||
git lfs install
|
||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
|
||||
conda create -n llama_etuning python=3.10
|
||||
conda activate llama_etuning
|
||||
|
|
|
@ -128,6 +128,7 @@ huggingface-cli login
|
|||
### 环境搭建(可跳过)
|
||||
|
||||
```bash
|
||||
git lfs install
|
||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
|
||||
conda create -n llama_etuning python=3.10
|
||||
conda activate llama_etuning
|
||||
|
|
|
@ -16,8 +16,16 @@ class LoggerHandler(logging.Handler):
|
|||
self.log += "\n\n"
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
def reset_logging():
|
||||
r"""
|
||||
Removes basic config of root logger
|
||||
"""
|
||||
root = logging.getLogger()
|
||||
list(map(root.removeHandler, root.handlers))
|
||||
list(map(root.removeFilter, root.filters))
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S"
|
||||
|
|
|
@ -15,7 +15,7 @@ from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
|
|||
from transformers.tokenization_utils import PreTrainedTokenizerBase
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
|
||||
from llmtuner.extras.save_and_load import load_valuehead_params
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
@ -95,7 +95,10 @@ def load_model_and_tokenizer(
|
|||
is_mergeable = False
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
if model_args.quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
|
||||
if (
|
||||
model_args.quantization_bit is not None
|
||||
or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
|
||||
):
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||
|
||||
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
||||
|
@ -126,6 +129,7 @@ def load_model_and_tokenizer(
|
|||
|
||||
if stage == "rm" or stage == "ppo": # add value head
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
reset_logging()
|
||||
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
|
||||
|
|
|
@ -85,6 +85,9 @@ def get_train_args(
|
|||
assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
|
||||
"Streaming mode does not support evaluation currently."
|
||||
|
||||
assert not (general_args.stage == "ppo" and data_args.streaming), \
|
||||
"Streaming mode does not suppport PPO training currently."
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
|
@ -107,8 +110,8 @@ def get_train_args(
|
|||
training_args.ddp_find_unused_parameters = False
|
||||
|
||||
if data_args.max_samples is not None and data_args.streaming:
|
||||
logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.")
|
||||
data_args.streaming = False
|
||||
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
|
||||
data_args.max_samples = None
|
||||
|
||||
if data_args.dev_ratio > 1e-6 and data_args.streaming:
|
||||
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
|
||||
|
|
|
@ -47,20 +47,19 @@ class PeftTrainer(Seq2SeqTrainer):
|
|||
logger.info(f"Saving model checkpoint to {output_dir}")
|
||||
|
||||
model = unwrap_model(self.model)
|
||||
state_dict = state_dict or get_state_dict(model)
|
||||
|
||||
if isinstance(model, PreTrainedModelWrapper):
|
||||
model_params, v_head_params = {}, {}
|
||||
for name in state_dict.keys():
|
||||
if name.startswith("pretrained_model."):
|
||||
model_params[name.replace("pretrained_model.", "")] = state_dict[name]
|
||||
elif name.startswith("v_head."):
|
||||
v_head_params[name.replace("v_head.", "")] = state_dict[name]
|
||||
# Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
|
||||
model_state_dict = state_dict or model.state_dict()
|
||||
v_head_state_dict = {
|
||||
name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
|
||||
for name in model_state_dict.keys() if name.startswith("v_head.")
|
||||
}
|
||||
|
||||
torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
state_dict = model_params
|
||||
torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
model = model.pretrained_model
|
||||
|
||||
state_dict = state_dict or get_state_dict(model)
|
||||
if isinstance(model, (PeftModel, PreTrainedModel)):
|
||||
model.config.use_cache = True
|
||||
model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
|
||||
|
|
|
@ -11,7 +11,7 @@ from trl import PPOTrainer
|
|||
from trl.core import LengthSampler
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, get_logits_processor
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
|
||||
|
@ -29,6 +29,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
r"""
|
||||
Inherits PPOTrainer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
|
@ -70,7 +71,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {max_steps}")
|
||||
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
|
||||
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
gen_kwargs = {
|
||||
|
|
Loading…
Reference in New Issue