hiyouga 2024-01-09 18:31:27 +08:00
parent ebee4f6a2a
commit 4571068e1e
9 changed files with 78 additions and 50 deletions

View File

@@ -457,7 +457,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
     "loss_scale_window": 1000,
     "hysteresis": 2,
     "min_loss_scale": 1
   },
   "zero_optimization": {
     "stage": 2,
     "allgather_partitions": true,

View File

@@ -457,7 +457,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
     "loss_scale_window": 1000,
     "hysteresis": 2,
     "min_loss_scale": 1
   },
   "zero_optimization": {
     "stage": 2,
     "allgather_partitions": true,

View File

@@ -3,7 +3,6 @@
 import os
 import json
 import torch
-import inspect
 import tiktoken
 import numpy as np
 from tqdm import tqdm, trange
@@ -46,16 +45,11 @@ class Evaluator:
         return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]

     def eval(self) -> None:
-        if "token" in inspect.signature(cached_file).parameters:
-            kwargs = {"token": self.model_args.hf_hub_token}
-        elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
-            kwargs = {"use_auth_token": self.model_args.hf_hub_token}
-
         mapping = cached_file(
             path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
             filename="mapping.json",
             cache_dir=self.model_args.cache_dir,
-            **kwargs
+            token=self.model_args.hf_hub_token
         )

         with open(mapping, "r", encoding="utf-8") as f:

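Side note on the change above: the deleted branch was a compatibility shim for transformers 4.31.0, where the argument was still named `use_auth_token`; passing `token=` directly assumes a newer transformers release whose `cached_file` exposes that keyword. A minimal sketch of the resulting call (not part of the commit; the task layout and token are placeholders):

# Hedged sketch, not part of the commit: resolve a task's mapping.json the way
# the evaluator now does. Assumes a transformers version whose cached_file
# accepts `token`; the task directory below is a placeholder.
import os
import json
from transformers.utils import cached_file

task_dir, task = "evaluation", "mmlu"  # hypothetical local task layout
mapping = cached_file(
    path_or_repo_id=os.path.join(task_dir, task),
    filename="mapping.json",
    cache_dir=None,
    token=None,  # or a Hugging Face access token for gated/private repos
)
with open(mapping, "r", encoding="utf-8") as f:
    categories = json.load(f)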
View File

@@ -1,17 +1,19 @@
 import os
 import json
 import time
-from typing import TYPE_CHECKING
+import torch
+from typing import TYPE_CHECKING, Dict
 from datetime import timedelta
 from transformers import PreTrainedModel, TrainerCallback
-from transformers.modeling_utils import custom_object_save, unwrap_model
+from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
 from peft import PeftModel

-from llmtuner.extras.constants import LOG_FILE_NAME
+from llmtuner.extras.constants import LOG_FILE_NAME, V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
 from llmtuner.extras.logging import get_logger

 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
     from trl import AutoModelForCausalLMWithValueHead
@@ -20,31 +22,66 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)

-def _save_model_with_valuehead(
+def _fix_valuehead_checkpoint(
     model: "AutoModelForCausalLMWithValueHead",
     output_dir: str,
     safe_serialization: bool
 ) -> None:
-    if isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
-        model.pretrained_model.config.save_pretrained(output_dir)
-        if model.pretrained_model.can_generate():
-            model.pretrained_model.generation_config.save_pretrained(output_dir)
-
-    if getattr(model, "is_peft_model", False):
-        model.pretrained_model.save_pretrained(output_dir, safe_serialization=safe_serialization)
-    elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
-        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
+    r"""
+    The model is already unwrapped.
+
+    There are three cases:
+    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
+    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
+    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
+
+    We assume `stage3_gather_16bit_weights_on_model_save=true`.
+    """
+    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+        return
+
+    if safe_serialization:
+        from safetensors import safe_open
+        from safetensors.torch import save_file
+        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
+        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
+            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+    else:
+        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
+        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+
+    decoder_state_dict = {}
+    v_head_state_dict = {}
+    for name, param in state_dict.items():
+        if name.startswith("v_head."):
+            v_head_state_dict[name] = param
+        else:
+            decoder_state_dict[name.replace("pretrained_model.", "")] = param
+
+    os.remove(path_to_checkpoint)
+    model.pretrained_model.save_pretrained(
+        output_dir,
+        state_dict=decoder_state_dict or None,
+        safe_serialization=safe_serialization
+    )
+
+    if safe_serialization:
+        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+    else:
+        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
+
+    logger.info("Value head model saved at: {}".format(output_dir))

-class SavePeftModelCallback(TrainerCallback):
+class FixValueHeadModelCallback(TrainerCallback):

     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a checkpoint save.
         """
         if args.should_save:
-            _save_model_with_valuehead(
-                model=unwrap_model(kwargs.pop("model")),
+            _fix_valuehead_checkpoint(
+                model=kwargs.pop("model"),
                 output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                 safe_serialization=args.save_safetensors
             )
@@ -54,10 +91,8 @@ class SavePeftModelCallback(TrainerCallback):
         Event called at the end of training.
         """
         if args.should_save:
-            _save_model_with_valuehead(
-                model=unwrap_model(kwargs.pop("model")),
-                output_dir=args.output_dir,
-                safe_serialization=args.save_safetensors
+            _fix_valuehead_checkpoint(
+                model=kwargs.pop("model"), output_dir=args.output_dir, safe_serialization=args.save_safetensors
             )

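To illustrate the splitting rule `_fix_valuehead_checkpoint` applies, here is a standalone sketch (not part of the commit; the tensors are dummies): `v_head.*` entries go to the value-head file, everything else is kept as the decoder, and any DeepSpeed ZeRO-3 `pretrained_model.` prefix is stripped.

# Illustrative sketch only: reproduce the key-splitting rule on a toy state dict.
from typing import Dict, Tuple
import torch


def split_valuehead_state_dict(
    state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    decoder_state_dict, v_head_state_dict = {}, {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            decoder_state_dict[name.replace("pretrained_model.", "")] = param
    return decoder_state_dict, v_head_state_dict


# Toy tensors covering the three cases listed in the docstring above.
merged = {
    "pretrained_model.model.layers.0.mlp.up_proj.weight": torch.zeros(2, 2),  # zero3-style key
    "model.embed_tokens.weight": torch.zeros(2, 2),                           # full-tuning key
    "v_head.summary.weight": torch.zeros(1, 2),
    "v_head.summary.bias": torch.zeros(1),
}
decoder_sd, v_head_sd = split_valuehead_state_dict(merged)
assert set(v_head_sd) == {"v_head.summary.weight", "v_head.summary.bias"}
assert not any(key.startswith("pretrained_model.") for key in decoder_sd)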
View File

@@ -40,6 +40,10 @@ TRAINING_STAGES = {
     "Pre-Training": "pt"
 }

+V_HEAD_WEIGHTS_NAME = "v_head.bin"
+
+V_HEAD_SAFE_WEIGHTS_NAME = "v_head.safetensors"
+
 class DownloadSource(str, Enum):
     DEFAULT = "hf"
     MODELSCOPE = "ms"

View File

@@ -3,8 +3,8 @@ import inspect
 from typing import TYPE_CHECKING, Any, Dict, List
 from transformers import PreTrainedModel
 from transformers.utils import cached_file
-from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME

+from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import get_current_device
@@ -103,22 +103,20 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
     try:
         from safetensors import safe_open
-        vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
+        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
         with safe_open(vhead_file, framework="pt", device="cpu") as f:
-            return {
-                "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
-                "v_head.summary.bias": f.get_tensor("v_head.summary.bias")
-            }
+            return {key: f.get_tensor(key) for key in f.keys()}
     except Exception as err:
-        logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
+        logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))

     try:
-        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
         return torch.load(vhead_file, map_location="cpu")
     except Exception as err:
-        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
+        logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))

-    logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
+    logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
+    logger.info("Ignore these messages if you are not resuming the training of a value head model.")
     return None

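A hedged sketch of what a caller might do with the dictionary returned by `load_valuehead_params` (the consumer side is not shown in this diff; the model id and file path are placeholders): because the keys are already prefixed with `v_head.`, loading with `strict=False` only touches the value-head parameters.

# Hedged sketch, not part of the commit: apply value-head weights loaded from a
# v_head checkpoint file. "path/to/base-model" and "v_head.bin" are placeholders.
import torch
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained("path/to/base-model")
vhead_params = torch.load("v_head.bin", map_location="cpu")  # or read v_head.safetensors instead
if vhead_params is not None:
    model.load_state_dict(vhead_params, strict=False)  # only v_head.* keys are matched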
View File

@@ -8,11 +8,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits

-from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
+from llmtuner.extras.callbacks import LogCallback, FixValueHeadModelCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
@@ -60,7 +61,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.accelerator.state, "deepspeed_plugin"
         )
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
-        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
+        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)

         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -369,9 +370,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                     " use zero_to_fp32.py to recover weights"
                 )
                 self._save(output_dir, state_dict={})
-                for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]: # remove dummy checkpoint
-                    file = os.path.join(output_dir, filename)
-                    if os.path.isfile(file):
-                        os.remove(file)
-
-                self.model.save_checkpoint(output_dir) # wrapped model
+                remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                self.model.save_checkpoint(output_dir)

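For readers unfamiliar with the helper imported above: `remove_dummy_checkpoint` comes from `transformers.trainer_pt_utils` and stands in for the hand-written loop deleted here. An approximation of the behavior being relied on, for illustration only (not the transformers source):

# Approximation for illustration: on the main process, delete any placeholder
# weight files left in output_dir, mirroring the loop removed by this commit.
import os
from typing import List


def _remove_dummy_checkpoint_like(is_main_process: bool, output_dir: str, filenames: List[str]) -> None:
    if is_main_process:
        for filename in filenames:
            file = os.path.join(output_dir, filename)
            if os.path.isfile(file):
                os.remove(file)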
View File

@@ -8,7 +8,7 @@ from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler

 from llmtuner.data import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import SavePeftModelCallback
+from llmtuner.extras.callbacks import FixValueHeadModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.utils import create_ref_model, create_reward_model
@@ -79,7 +79,7 @@ def run_ppo(
         training_args=training_args,
         finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks + [SavePeftModelCallback()],
+        callbacks=callbacks + [FixValueHeadModelCallback()],
         reward_model=reward_model,
         config=ppo_config,
         model=model,

View File

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments

 from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import SavePeftModelCallback
+from llmtuner.extras.callbacks import FixValueHeadModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
@@ -40,7 +40,7 @@ def run_rm(
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=callbacks + [SavePeftModelCallback()],
+        callbacks=callbacks + [FixValueHeadModelCallback()],
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args)
     )