fix RM save model

parent 82e793ddb4
commit ac88ce5233

@@ -128,6 +128,7 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
 ### Dependence Installation (optional)
 
 ```bash
+git lfs install
 git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 conda create -n llama_etuning python=3.10
 conda activate llama_etuning

@@ -128,6 +128,7 @@ huggingface-cli login
 ### 环境搭建(可跳过)
 
 ```bash
+git lfs install
 git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 conda create -n llama_etuning python=3.10
 conda activate llama_etuning

@@ -16,8 +16,16 @@ class LoggerHandler(logging.Handler):
         self.log += "\n\n"
 
 
+def reset_logging():
+    r"""
+    Removes basic config of root logger
+    """
+    root = logging.getLogger()
+    list(map(root.removeHandler, root.handlers))
+    list(map(root.removeFilter, root.filters))
+
+
 def get_logger(name: str) -> logging.Logger:
     formatter = logging.Formatter(
         fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S"
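
The new `reset_logging` helper strips whatever handlers and filters have accumulated on the root logger (for example from a stray `logging.basicConfig` call in a dependency), so loggers created via `get_logger` do not emit duplicated records. A minimal usage sketch, assuming the `llmtuner` package is importable; the duplicated-handler scenario is illustrative:

```python
import logging

from llmtuner.extras.logging import get_logger, reset_logging

# A dependency (or an earlier code path) may have configured the root logger,
# which would cause every record to be printed twice.
logging.basicConfig(level=logging.INFO)

reset_logging()               # drop the root logger's handlers and filters
logger = get_logger(__name__)
logger.info("single, cleanly formatted log line")
```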

@@ -15,7 +15,7 @@ from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead
 
-from llmtuner.extras.logging import get_logger
+from llmtuner.extras.logging import reset_logging, get_logger
 from llmtuner.extras.misc import count_parameters, prepare_model_for_training
 from llmtuner.extras.save_and_load import load_valuehead_params
 from llmtuner.hparams import FinetuningArguments

@@ -95,7 +95,10 @@ def load_model_and_tokenizer(
         is_mergeable = False
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if model_args.quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
+    if (
+        model_args.quantization_bit is not None
+        or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
+    ):
         config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
 
     if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
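
The per-rank `device_map` is now skipped when DeepSpeed ZeRO-3 is active, since ZeRO-3 partitions and places parameters itself and a fixed `{"": local_rank}` map would conflict with it. A standalone sketch of the same decision, assuming `is_deepspeed_zero3_enabled` comes from `transformers.deepspeed` (the loader's actual import line is not shown in this hunk):

```python
import os

from transformers.deepspeed import is_deepspeed_zero3_enabled  # assumed import path

def device_map_kwargs(quantization_bit=None):
    """Pin the model to the local rank's GPU unless DeepSpeed ZeRO-3 manages placement."""
    kwargs = {}
    if (
        quantization_bit is not None
        or (os.environ.get("LOCAL_RANK") is not None and not is_deepspeed_zero3_enabled())
    ):
        kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
    return kwargs
```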

@@ -126,6 +129,7 @@ def load_model_and_tokenizer(
 
     if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        reset_logging()
 
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
             logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
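
For the reward-model and PPO stages the causal LM is wrapped with trl's value head, and `reset_logging()` is now called right after the wrap so that any root-logger configuration picked up during that step is cleared before further logging. A minimal sketch of the wrapping step, with "gpt2" standing in for the model that `load_model_and_tokenizer` actually builds:

```python
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

from llmtuner.extras.logging import reset_logging

base_model = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in base model

# Attach a scalar value head on top of the language model (needed for RM and PPO).
model = AutoModelForCausalLMWithValueHead.from_pretrained(base_model)
reset_logging()  # mirror the loader: clean up the root logger after wrapping

print(type(model.v_head).__name__)  # the value head module added by trl
```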

@@ -85,6 +85,9 @@ def get_train_args(
     assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
         "Streaming mode does not support evaluation currently."
 
+    assert not (general_args.stage == "ppo" and data_args.streaming), \
+        "Streaming mode does not support PPO training currently."
+
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
             assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
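
The new assertion rejects PPO training in streaming mode before any data loading starts. A toy reproduction of how the check fires (the argument objects are stand-ins for llmtuner's parsed dataclasses):

```python
from types import SimpleNamespace

general_args = SimpleNamespace(stage="ppo")
data_args = SimpleNamespace(streaming=True)

# Same check as get_train_args: PPO and streaming together fail fast.
assert not (general_args.stage == "ppo" and data_args.streaming), \
    "Streaming mode does not support PPO training currently."
```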

@@ -107,8 +110,8 @@ def get_train_args(
         training_args.ddp_find_unused_parameters = False
 
     if data_args.max_samples is not None and data_args.streaming:
-        logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.")
-        data_args.streaming = False
+        logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
+        data_args.max_samples = None
 
     if data_args.dev_ratio > 1e-6 and data_args.streaming:
         logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
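
When `max_samples` and `streaming` are both set, the parser now keeps streaming enabled and drops `max_samples`, rather than silently turning streaming off as before. A small sketch of the reconciliation on toy argument objects (the `dev_ratio` reset is an assumption, since the hunk only shows the warning for that case):

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)

def reconcile_streaming(data_args):
    """Mirror the parser's streaming-related fixes on a stand-in namespace."""
    if data_args.max_samples is not None and data_args.streaming:
        logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
        data_args.max_samples = None
    if data_args.dev_ratio > 1e-6 and data_args.streaming:
        logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
        data_args.dev_ratio = 0.0  # assumed follow-up action, not shown in the hunk
    return data_args

args = reconcile_streaming(SimpleNamespace(streaming=True, max_samples=1000, dev_ratio=0.0))
print(args.max_samples)  # None: streaming stays on, max_samples is dropped
```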

@@ -47,20 +47,19 @@ class PeftTrainer(Seq2SeqTrainer):
         logger.info(f"Saving model checkpoint to {output_dir}")
 
         model = unwrap_model(self.model)
-        state_dict = state_dict or get_state_dict(model)
 
         if isinstance(model, PreTrainedModelWrapper):
-            model_params, v_head_params = {}, {}
-            for name in state_dict.keys():
-                if name.startswith("pretrained_model."):
-                    model_params[name.replace("pretrained_model.", "")] = state_dict[name]
-                elif name.startswith("v_head."):
-                    v_head_params[name.replace("v_head.", "")] = state_dict[name]
+            # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
+            model_state_dict = state_dict or model.state_dict()
+            v_head_state_dict = {
+                name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
+                for name in model_state_dict.keys() if name.startswith("v_head.")
+            }
 
-            torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
-            state_dict = model_params
+            torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
             model = model.pretrained_model
 
+        state_dict = state_dict or get_state_dict(model)
         if isinstance(model, (PeftModel, PreTrainedModel)):
             model.config.use_cache = True
             model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
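
This hunk is the heart of "fix RM save model": instead of splitting a full `get_state_dict(model)` result into `pretrained_model.*` and `v_head.*` buckets, `_save` now takes the value-head tensors straight from the wrapped model's own state dict (cloned and detached on CPU, following trl's convention), writes them to the value-head file, and then saves the unwrapped `pretrained_model` with its regular state dict. A hedged sketch of the save/restore round trip; the file name constant is assumed to match `llmtuner.extras.constants`, and the loading half is an illustration rather than the project's `load_valuehead_params`:

```python
import os

import torch
from trl import AutoModelForCausalLMWithValueHead

VALUE_HEAD_FILE_NAME = "value_head.bin"  # assumed value of the project constant

def save_value_head(model: AutoModelForCausalLMWithValueHead, output_dir: str) -> None:
    # Same filtering as the new _save: keep only v_head.* tensors, detached on CPU.
    model_state_dict = model.state_dict()
    v_head_state_dict = {
        name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
        for name in model_state_dict if name.startswith("v_head.")
    }
    torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))

def load_value_head(model: AutoModelForCausalLMWithValueHead, checkpoint_dir: str) -> None:
    # Restore the scalar value head saved above; the base weights travel separately
    # through pretrained_model.save_pretrained / from_pretrained.
    state_dict = torch.load(os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME), map_location="cpu")
    model.v_head.load_state_dict(state_dict)
```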

@@ -11,7 +11,7 @@ from trl import PPOTrainer
 from trl.core import LengthSampler
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import AverageMeter, get_logits_processor
+from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 
 from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model

@@ -29,6 +29,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     r"""
     Inherits PPOTrainer.
     """
 
     def __init__(
         self,
         training_args: "Seq2SeqTrainingArguments",

@@ -70,7 +71,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
         logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
         logger.info(f" Total optimization steps = {max_steps}")
-        logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
+        logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
 
         # Keyword arguments for `model.generate`
         gen_kwargs = {
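
The trainable-parameter line now goes through `count_parameters`, keeping it consistent with the count reported elsewhere in the project; indexing the result with `[0]` shows the helper returns the trainable count first. A minimal equivalent sketch (not the project's exact implementation):

```python
from typing import Tuple

import torch.nn as nn

def count_parameters(model: nn.Module) -> Tuple[int, int]:
    """Return (trainable parameter count, total parameter count)."""
    trainable, total = 0, 0
    for param in model.parameters():
        num = param.numel()
        total += num
        if param.requires_grad:
            trainable += num
    return trainable, total

# count_parameters(model)[0] reproduces the old inline expression
# sum(p.numel() for p in model.parameters() if p.requires_grad).
```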