fix #1789
parent ebee4f6a2a
commit 4571068e1e
@@ -457,7 +457,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
     "loss_scale_window": 1000,
     "hysteresis": 2,
     "min_loss_scale": 1
   },
   "zero_optimization": {
     "stage": 2,
     "allgather_partitions": true,
@@ -457,7 +457,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
     "loss_scale_window": 1000,
     "hysteresis": 2,
     "min_loss_scale": 1
   },
   "zero_optimization": {
     "stage": 2,
     "allgather_partitions": true,
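The two identical hunks above come from the DeepSpeed configuration example embedded in the project README (the commit appears to apply the same change to more than one README variant). For orientation only, a plausible complete ZeRO-2 config containing the keys visible in these hunks could be written like this; every value outside the visible lines is an assumption, not taken from the commit:

    import json

    # Illustrative ZeRO-2 config: only the keys shown in the hunks above are confirmed
    # by the diff; the remaining entries are common defaults for the HF Trainer integration.
    ds_config = {
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "fp16": {
            "enabled": "auto",
            "loss_scale": 0,
            "initial_scale_power": 16,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1
        },
        "zero_optimization": {
            "stage": 2,
            "allgather_partitions": True,
            "allgather_bucket_size": 5e8,
            "overlap_comm": False,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "contiguous_gradients": True
        }
    }

    with open("ds_config.json", "w", encoding="utf-8") as f:
        json.dump(ds_config, f, indent=2)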
@@ -3,7 +3,6 @@
 import os
 import json
 import torch
-import inspect
 import tiktoken
 import numpy as np
 from tqdm import tqdm, trange
@@ -46,16 +45,11 @@ class Evaluator:
         return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]

     def eval(self) -> None:
-        if "token" in inspect.signature(cached_file).parameters:
-            kwargs = {"token": self.model_args.hf_hub_token}
-        elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
-            kwargs = {"use_auth_token": self.model_args.hf_hub_token}
-
         mapping = cached_file(
             path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
             filename="mapping.json",
             cache_dir=self.model_args.cache_dir,
-            **kwargs
+            token=self.model_args.hf_hub_token
         )

         with open(mapping, "r", encoding="utf-8") as f:
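The hunk above drops the signature-probing shim and passes token= to cached_file directly, which assumes a transformers release whose cached_file accepts token. For reference, the pattern being removed can be kept as a small compatibility helper when older releases such as 4.31.0 must still be supported; the helper name below is illustrative, not part of the commit:

    import inspect
    from transformers.utils import cached_file

    def hub_token_kwargs(hf_hub_token):
        # Pick whichever keyword this transformers version understands:
        # recent releases accept "token", transformers==4.31.0 only accepts "use_auth_token".
        if "token" in inspect.signature(cached_file).parameters:
            return {"token": hf_hub_token}
        return {"use_auth_token": hf_hub_token}

    # mapping = cached_file(path_or_repo_id=..., filename="mapping.json", **hub_token_kwargs(token))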
@@ -1,17 +1,19 @@
 import os
 import json
 import time
-from typing import TYPE_CHECKING
+import torch
+from typing import TYPE_CHECKING, Dict
 from datetime import timedelta

 from transformers import PreTrainedModel, TrainerCallback
-from transformers.modeling_utils import custom_object_save, unwrap_model
+from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
 from peft import PeftModel

-from llmtuner.extras.constants import LOG_FILE_NAME
+from llmtuner.extras.constants import LOG_FILE_NAME, V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
 from llmtuner.extras.logging import get_logger

+
 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
     from trl import AutoModelForCausalLMWithValueHead
@@ -20,31 +22,66 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def _save_model_with_valuehead(
+def _fix_valuehead_checkpoint(
     model: "AutoModelForCausalLMWithValueHead",
     output_dir: str,
     safe_serialization: bool
 ) -> None:
-    if isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
-        model.pretrained_model.config.save_pretrained(output_dir)
-        if model.pretrained_model.can_generate():
-            model.pretrained_model.generation_config.save_pretrained(output_dir)
-
-    if getattr(model, "is_peft_model", False):
-        model.pretrained_model.save_pretrained(output_dir, safe_serialization=safe_serialization)
-    elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
-        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
+    r"""
+    The model is already unwrapped.
+
+    There are three cases:
+    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
+    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
+    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
+
+    We assume `stage3_gather_16bit_weights_on_model_save=true`.
+    """
+    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+        return
+
+    if safe_serialization:
+        from safetensors import safe_open
+        from safetensors.torch import save_file
+        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
+        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
+            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+    else:
+        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
+        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+
+    decoder_state_dict = {}
+    v_head_state_dict = {}
+    for name, param in state_dict.items():
+        if name.startswith("v_head."):
+            v_head_state_dict[name] = param
+        else:
+            decoder_state_dict[name.replace("pretrained_model.", "")] = param
+
+    os.remove(path_to_checkpoint)
+    model.pretrained_model.save_pretrained(
+        output_dir,
+        state_dict=decoder_state_dict or None,
+        safe_serialization=safe_serialization
+    )
+
+    if safe_serialization:
+        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+    else:
+        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
+
+    logger.info("Value head model saved at: {}".format(output_dir))


-class SavePeftModelCallback(TrainerCallback):
+class FixValueHeadModelCallback(TrainerCallback):

     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a checkpoint save.
         """
         if args.should_save:
-            _save_model_with_valuehead(
-                model=unwrap_model(kwargs.pop("model")),
+            _fix_valuehead_checkpoint(
+                model=kwargs.pop("model"),
                 output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                 safe_serialization=args.save_safetensors
             )
@@ -54,10 +91,8 @@ class SavePeftModelCallback(TrainerCallback):
         Event called at the end of training.
         """
         if args.should_save:
-            _save_model_with_valuehead(
-                model=unwrap_model(kwargs.pop("model")),
-                output_dir=args.output_dir,
-                safe_serialization=args.save_safetensors
+            _fix_valuehead_checkpoint(
+                model=kwargs.pop("model"), output_dir=args.output_dir, safe_serialization=args.save_safetensors
             )


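A minimal standalone sketch of the partition rule described in the docstring added above, applied to the three state_dict shapes it lists; the helper name and the dummy tensors are illustrative, not part of the commit:

    import torch

    def split_valuehead_state_dict(state_dict):
        # Same rule as _fix_valuehead_checkpoint: v_head.* keys are set aside,
        # everything else is renamed back to a plain decoder key.
        decoder_state_dict, v_head_state_dict = {}, {}
        for name, param in state_dict.items():
            if name.startswith("v_head."):
                v_head_state_dict[name] = param
            else:
                decoder_state_dict[name.replace("pretrained_model.", "")] = param
        return decoder_state_dict, v_head_state_dict

    cases = {
        "full":  {"model.layers.0.weight": torch.zeros(1), "v_head.summary.weight": torch.zeros(1)},
        "lora":  {"v_head.summary.weight": torch.zeros(1)},
        "zero3": {"pretrained_model.model.layers.0.weight": torch.zeros(1), "v_head.summary.weight": torch.zeros(1)},
    }
    for name, state_dict in cases.items():
        decoder_state_dict, v_head_state_dict = split_valuehead_state_dict(state_dict)
        print(name, sorted(decoder_state_dict), sorted(v_head_state_dict))
    # full  ['model.layers.0.weight'] ['v_head.summary.weight']
    # lora  [] ['v_head.summary.weight']   -> decoder part empty, so save_pretrained(state_dict=None) keeps the adapter weights
    # zero3 ['model.layers.0.weight'] ['v_head.summary.weight']

The ZeRO-3 prefix "pretrained_model." is stripped so the decoder file can be reloaded with from_pretrained, while the value head always lands in its own v_head file.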
@@ -40,6 +40,10 @@ TRAINING_STAGES = {
     "Pre-Training": "pt"
 }

+V_HEAD_WEIGHTS_NAME = "v_head.bin"
+
+V_HEAD_SAFE_WEIGHTS_NAME = "v_head.safetensors"
+
 class DownloadSource(str, Enum):
     DEFAULT = "hf"
     MODELSCOPE = "ms"
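The two new constants give the value head weights a file name of their own, instead of being looked up inside the decoder's WEIGHTS_NAME / SAFE_WEIGHTS_NAME files. An illustrative check (not part of the commit) for whether a checkpoint directory was written by the new callback:

    import os

    from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME

    def has_valuehead_weights(checkpoint_dir: str) -> bool:
        # True if either v_head.safetensors or v_head.bin sits next to the decoder weights.
        return any(
            os.path.isfile(os.path.join(checkpoint_dir, name))
            for name in (V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME)
        )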
@@ -3,8 +3,8 @@ import inspect
 from typing import TYPE_CHECKING, Any, Dict, List
 from transformers import PreTrainedModel
 from transformers.utils import cached_file
-from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME

+from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import get_current_device

@@ -103,22 +103,20 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->

     try:
         from safetensors import safe_open
-        vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
+        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
         with safe_open(vhead_file, framework="pt", device="cpu") as f:
-            return {
-                "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
-                "v_head.summary.bias": f.get_tensor("v_head.summary.bias")
-            }
+            return {key: f.get_tensor(key) for key in f.keys()}
     except Exception as err:
-        logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
+        logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))

     try:
-        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
         return torch.load(vhead_file, map_location="cpu")
     except Exception as err:
-        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
+        logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))

-    logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
+    logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
+    logger.info("Ignore these messages if you are not resuming the training of a value head model.")
     return None


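A hedged caller-side sketch of how the returned dict is typically consumed; the module path, checkpoint path, and the model_args object are placeholders inferred from the surrounding imports, not confirmed by this diff:

    from trl import AutoModelForCausalLMWithValueHead

    from llmtuner.model.utils import load_valuehead_params  # assumed location of the function above

    # `model_args` stands in for the ModelArguments instance used elsewhere in llmtuner.
    model = AutoModelForCausalLMWithValueHead.from_pretrained("path/to/reward_model")
    vhead_params = load_valuehead_params("path/to/reward_model", model_args)
    if vhead_params is not None:
        # The dict now holds only v_head.* tensors, so strict=False fills just the value head.
        model.load_state_dict(vhead_params, strict=False)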
@@ -8,11 +8,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_pt_utils import remove_dummy_checkpoint

 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits

-from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
+from llmtuner.extras.callbacks import LogCallback, FixValueHeadModelCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
@@ -60,7 +61,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.accelerator.state, "deepspeed_plugin"
         )
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
-        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
+        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)

         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -369,9 +370,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                     " use zero_to_fp32.py to recover weights"
                 )
                 self._save(output_dir, state_dict={})
-                for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]: # remove dummy checkpoint
-                    file = os.path.join(output_dir, filename)
-                    if os.path.isfile(file):
-                        os.remove(file)
-
-                self.model.save_checkpoint(output_dir) # wrapped model
+                remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                self.model.save_checkpoint(output_dir)
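This hunk replaces a hand-rolled cleanup loop with remove_dummy_checkpoint from transformers.trainer_pt_utils, which, as far as the removed code shows, does the same thing behind an is_main_process guard: delete the near-empty weight files written by self._save(output_dir, state_dict={}) so that only the DeepSpeed checkpoint produced by save_checkpoint() remains. A sketch equivalent to the deleted loop, with an illustrative helper name:

    import os

    def remove_dummy_files(output_dir, filenames):
        # Drop the placeholder weight files before the wrapped DeepSpeed engine
        # writes the real checkpoint shards into output_dir.
        for filename in filenames:
            path = os.path.join(output_dir, filename)
            if os.path.isfile(path):
                os.remove(path)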
@@ -8,7 +8,7 @@ from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler

 from llmtuner.data import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import SavePeftModelCallback
+from llmtuner.extras.callbacks import FixValueHeadModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.utils import create_ref_model, create_reward_model
@@ -79,7 +79,7 @@ def run_ppo(
         training_args=training_args,
         finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks + [SavePeftModelCallback()],
+        callbacks=callbacks + [FixValueHeadModelCallback()],
         reward_model=reward_model,
         config=ppo_config,
         model=model,
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments

 from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import SavePeftModelCallback
+from llmtuner.extras.callbacks import FixValueHeadModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
@@ -40,7 +40,7 @@ def run_rm(
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=callbacks + [SavePeftModelCallback()],
+        callbacks=callbacks + [FixValueHeadModelCallback()],
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args)
     )