refactor model_dtype, fix PPO trainer
parent 5310e4d182
commit 2818af0b09
@@ -3,6 +3,19 @@ import torch
 from typing import TYPE_CHECKING, Tuple
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 
+try:
+    from transformers.utils import (
+        is_torch_bf16_cpu_available,
+        is_torch_bf16_gpu_available,
+        is_torch_cuda_available,
+        is_torch_npu_available
+    )
+    _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
+    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
+except ImportError:
+    _is_fp16_available = torch.cuda.is_available()
+    _is_bf16_available = torch.cuda.is_bf16_supported()
+
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
 
@@ -49,7 +62,22 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
 
 
+def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
+    r"""
+    Infers the optimal dtype according to the model_dtype and device compatibility.
+    """
+    if _is_bf16_available and model_dtype == torch.bfloat16:
+        return torch.bfloat16
+    elif _is_fp16_available:
+        return torch.float16
+    else:
+        return torch.float32
+
+
 def get_logits_processor() -> LogitsProcessorList:
+    r"""
+    Gets logits processor that removes NaN and Inf logits.
+    """
     logits_processor = LogitsProcessorList()
     logits_processor.append(InfNanRemoveLogitsProcessor())
     return logits_processor
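The new helper keeps the checkpoint's bf16 only when the device supports it and otherwise degrades to fp16 or fp32. A minimal sketch of that behaviour, with the availability flags hard-coded to illustrative values rather than detected as in the diff:

```python
import torch

# Illustrative stand-ins for the module-level flags set in the try/except above.
_is_fp16_available = True   # e.g. a CUDA or NPU device is present
_is_bf16_available = False  # e.g. a pre-Ampere GPU without bf16 support

def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
    # Mirrors the function added in this commit: keep bf16 only if both the
    # checkpoint and the device support it, otherwise fall back to fp16/fp32.
    if _is_bf16_available and model_dtype == torch.bfloat16:
        return torch.bfloat16
    elif _is_fp16_available:
        return torch.float16
    else:
        return torch.float32

print(infer_optim_dtype(torch.bfloat16))  # torch.float16 on this hypothetical device
print(infer_optim_dtype(torch.float32))   # torch.float16 (fp16 is available)
```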
@@ -138,11 +138,11 @@ class LlamaFlashAttention2(LlamaAttention):
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
             logger.warning_once("The input hidden states seems to be silently casted in float32.")
-            query_states = query_states.to(torch.float16)
-            key_states = key_states.to(torch.float16)
-            value_states = value_states.to(torch.float16)
+            query_states = query_states.to(self.config.torch_dtype)
+            key_states = key_states.to(self.config.torch_dtype)
+            value_states = value_states.to(self.config.torch_dtype)
 
-        if getattr(self, "num_key_value_groups"):
+        if getattr(self, "num_key_value_groups", None):
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
 
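The second change here is easy to miss: two-argument `getattr` raises when the attribute is absent, while the new three-argument form falls through quietly. A small illustration with a hypothetical attention object:

```python
class Attn:
    pass  # hypothetical module that lacks a num_key_value_groups attribute

attn = Attn()

try:
    getattr(attn, "num_key_value_groups")        # old form: raises AttributeError
except AttributeError as err:
    print("old form fails:", err)

if getattr(attn, "num_key_value_groups", None):  # new form: treats "missing" like "zero"
    print("would repeat_kv here")
else:
    print("new form skips repeat_kv when the attribute is absent")
```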
@@ -67,9 +67,9 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
-    layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field(
-        default="auto",
-        metadata={"help": "Data type of the layer norm weights."}
+    upcast_layernorm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
     )
 
     def __post_init__(self):
@@ -24,7 +24,7 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
     from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from llmtuner.extras.logging import reset_logging, get_logger
-from llmtuner.extras.misc import count_parameters
+from llmtuner.extras.misc import count_parameters, infer_optim_dtype
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.extras.save_and_load import load_valuehead_params
 from llmtuner.hparams import FinetuningArguments
@@ -86,11 +86,17 @@ def load_model_and_tokenizer(
     if getattr(config, "model_type", None) == "chatglm":
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
+    # Set model dtype
+    if model_args.compute_dtype is not None:
+        setattr(config, "torch_dtype", model_args.compute_dtype)
+    else: # priority: bf16 > fp16 > fp32
+        optim_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+        setattr(config, "torch_dtype", optim_dtype)
+
     # Fix config (for Qwen)
     if getattr(config, "model_type", None) == "qwen":
-        setattr(config, "fp16", model_args.compute_dtype == torch.float16)
-        setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
-        setattr(config, "fp32", model_args.compute_dtype == torch.float32)
+        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
+            setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
 
     # Set RoPE scaling
     if model_args.rope_scaling is not None:
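A minimal sketch of what the new loop does to the config, using a plain namespace in place of the real Hugging Face config object:

```python
import torch
from types import SimpleNamespace

# Hypothetical stand-in for the Hugging Face PretrainedConfig.
config = SimpleNamespace(torch_dtype=torch.bfloat16)

# Equivalent of the new loop: exactly one of fp16/bf16/fp32 ends up True,
# now driven by config.torch_dtype rather than model_args.compute_dtype.
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
    setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)

print(config.fp16, config.bf16, config.fp32)  # False True False
```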
@@ -131,9 +137,7 @@ def load_model_and_tokenizer(
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
-            LlamaModule.LlamaModel._prepare_decoder_attention_mask = (
-                LlamaPatches._prepare_decoder_attention_mask
-            )
+            LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
             logger.info("Using FlashAttention-2 for faster training and inference.")
         elif getattr(config, "model_type", None) == "qwen":
             logger.info("Qwen models automatically enable FlashAttention if installed.")
@@ -180,7 +184,6 @@ def load_model_and_tokenizer(
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
-        torch_dtype=model_args.compute_dtype,
         low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
@@ -203,7 +206,7 @@ def load_model_and_tokenizer(
 
     # Initialize adapters
     if is_trainable:
-        model = prepare_model_for_training(model, model_args.layernorm_dtype, finetuning_args.finetuning_type)
+        model = prepare_model_for_training(model, model_args.upcast_layernorm, finetuning_args.finetuning_type)
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
     model = model.train() if is_trainable else model.eval()
 
@@ -8,16 +8,6 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint
 
-try:
-    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available
-    is_fp16_available = is_torch_cuda_available()
-    is_bf16_available = is_torch_bf16_gpu_available()
-    is_npu_available = is_torch_npu_available()
-except ImportError:
-    is_fp16_available = torch.cuda.is_available()
-    is_bf16_available = torch.cuda.is_bf16_supported()
-    is_npu_available = False
-
 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import (
     ModelArguments,
@@ -31,17 +21,6 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)
 
 
-def _infer_dtype() -> torch.dtype:
-    if is_npu_available:
-        return torch.float16
-    elif is_bf16_available:
-        return torch.bfloat16
-    elif is_fp16_available:
-        return torch.float16
-    else:
-        return torch.float32
-
-
 def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
@@ -178,12 +157,15 @@ def get_train_args(
         if not finetuning_args.resume_lora_training:
             raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")
 
-    if model_args.quantization_bit is not None and (not training_args.do_train):
-        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+    if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
+        logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
 
     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
         logger.warning("We recommend enable mixed precision training.")
 
+    if (not training_args.do_train) and model_args.quantization_bit is not None:
+        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+
     # postprocess data_args
     if data_args.max_samples is not None and data_args.streaming:
         logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
@@ -206,10 +188,9 @@ def get_train_args(
         and os.path.isdir(training_args.output_dir)
         and not training_args.overwrite_output_dir
     ):
-        require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
-            raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
+            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
 
         if last_checkpoint is not None:
             training_args_dict = training_args.to_dict()
@@ -220,26 +201,7 @@ def get_train_args(
         )
 
     # postprocess model_args
-    if training_args.bf16:
-        if not is_bf16_available:
-            raise ValueError("Current device does not support bf16 training.")
-        model_args.compute_dtype = torch.bfloat16
-    elif training_args.fp16:
-        model_args.compute_dtype = torch.float16
-    else:
-        model_args.compute_dtype = _infer_dtype()
-
-    if model_args.layernorm_dtype == "bf16":
-        if not is_bf16_available:
-            raise ValueError("Current device does not support bf16 type.")
-        model_args.layernorm_dtype = torch.bfloat16
-    elif model_args.layernorm_dtype == "fp16":
-        model_args.layernorm_dtype = torch.float16
-    elif model_args.layernorm_dtype == "fp32":
-        model_args.layernorm_dtype = torch.float32
-    else:
-        model_args.layernorm_dtype = model_args.compute_dtype
+    model_args.compute_dtype = torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
 
     model_args.model_max_length = data_args.cutoff_len
 
     # Log on each process the small summary:
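The long dtype ladder collapses into one expression whose fall-through value is `None`, leaving the final choice to `infer_optim_dtype` in the loader. A sketch of the resulting precedence (the helper name below is illustrative, not from the diff):

```python
import torch

def resolve_compute_dtype(bf16: bool, fp16: bool):
    # Mirrors the new one-liner: explicit training flags win, otherwise None
    # so the loader can fall back to infer_optim_dtype() on the checkpoint dtype.
    return torch.bfloat16 if bf16 else (torch.float16 if fp16 else None)

for flags in [(True, False), (False, True), (False, False)]:
    print(flags, "->", resolve_compute_dtype(*flags))
# (True, False) -> torch.bfloat16
# (False, True) -> torch.float16
# (False, False) -> None
```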
@@ -278,7 +240,4 @@ def get_infer_args(
     if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
-    # auto-detect cuda capability
-    model_args.compute_dtype = _infer_dtype()
-
     return model_args, data_args, finetuning_args, generating_args
@@ -31,11 +31,11 @@ def find_all_linear_modules(
 
 def prepare_model_for_training(
     model: "PreTrainedModel",
-    layernorm_dtype: torch.dtype,
+    upcast_layernorm: bool,
     finetuning_type: str,
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
-    layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
+    layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
 ) -> "PreTrainedModel":
     r"""
     Includes:
@@ -44,9 +44,10 @@ def prepare_model_for_training(
         (3) upcast the lm_head to fp32
     Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
     """
-    for name, param in model.named_parameters():
-        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            param.data = param.data.to(layernorm_dtype)
+    if upcast_layernorm:
+        for name, param in model.named_parameters():
+            if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
+                param.data = param.data.to(torch.float32)
 
     if use_gradient_checkpointing:
         if hasattr(model, "enable_input_require_grads"):
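A minimal sketch of the upcasting branch on a toy module; `LAYERNORM_NAMES` below is an illustrative subset, not the real constant from `llmtuner.extras.constants`:

```python
import torch
import torch.nn as nn

# Illustrative subset of the LAYERNORM_NAMES constant.
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn"]

def upcast_layernorm_weights(model: nn.Module, layernorm_names=LAYERNORM_NAMES) -> None:
    # Same idea as the new `if upcast_layernorm:` branch: only 1-D parameters whose
    # names match a known layernorm pattern are promoted to fp32.
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
            param.data = param.data.to(torch.float32)

# Toy half-precision model; in nn.Sequential the LayerNorm params are named "1.weight"/"1.bias".
model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4)).to(torch.float16)
upcast_layernorm_weights(model, layernorm_names=["1."])
print(model[0].weight.dtype, model[1].weight.dtype)  # torch.float16 torch.float32
```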
@@ -10,14 +10,15 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 
+from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
+from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import GeneratingArguments
+    from llmtuner.hparams import ModelArguments, GeneratingArguments
 
 
 logger = get_logger(__name__)
@@ -30,10 +31,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
     def __init__(
         self,
+        model_args: "ModelArguments",
         training_args: "Seq2SeqTrainingArguments",
         generating_args: "GeneratingArguments",
         callbacks: List["TrainerCallback"],
-        compute_dtype: torch.dtype,
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
@@ -41,11 +42,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
 
         self.args = training_args
-        self.generating_args = generating_args
-        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
-        self.compute_dtype = compute_dtype
+        self.model_args = model_args
+        self.generation_config = GenerationConfig(
+            pad_token_id=self.tokenizer.pad_token_id,
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+            **generating_args.to_dict()
+        )
         self.state = TrainerState()
         self.control = TrainerControl()
+        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
+        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
 
     def ppo_train(self) -> None:
         r"""
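Building the `GenerationConfig` once in `__init__` replaces the per-step keyword assembly that is removed further down. A sketch with hypothetical token ids and generating arguments:

```python
from transformers import GenerationConfig

# Hypothetical values standing in for the tokenizer and GeneratingArguments.to_dict().
generating_args = {"do_sample": True, "temperature": 0.95, "top_p": 0.7, "max_new_tokens": 256}
pad_token_id, eos_token_id, extra_special_ids = 0, 2, [32000]

# Built once at trainer construction instead of being reassembled on every PPO step.
generation_config = GenerationConfig(
    pad_token_id=pad_token_id,
    eos_token_id=[eos_token_id] + extra_special_ids,
    **generating_args
)
print(generation_config.temperature, generation_config.eos_token_id)  # 0.95 [2, 32000]
```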
@@ -74,13 +80,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         logger.info(f"  Total optimization steps = {max_steps}")
         logger.info(f"  Number of trainable parameters = {count_parameters(self.model)[0]}")
 
-        # Keyword arguments for `model.generate`
-        generating_args = self.generating_args.to_dict()
-        generating_args.update(dict(
-            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            pad_token_id=self.tokenizer.pad_token_id
-        ))
-
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         dataiter = iter(self.dataloader)
         steps_trained = 0
@@ -98,7 +97,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.model.eval()
 
             # Get inputs
-            queries, responses = self.get_inputs(batch, generating_args)
+            queries, responses = self.get_inputs(batch)
             self.tokenizer.padding_side = "right" # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)
 
@@ -159,27 +158,24 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )
 
     @torch.no_grad()
-    def get_inputs(
-        self,
-        batch: Dict[str, torch.Tensor],
-        generating_args: Dict[str, Any]
-    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
-        gen_kwargs = dict(
-            generation_config=GenerationConfig(**generating_args),
+        if self.model_args.upcast_layernorm:
+            layernorm_params = dump_layernorm(self.model)
+
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+        response: torch.Tensor = unwrapped_model.generate(
+            generation_config=self.generation_config,
             logits_processor=get_logits_processor(),
             **batch
         )
 
-        input_ids = batch["input_ids"]
-        self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
-        self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
-        query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
+        if self.model_args.upcast_layernorm:
+            restore_layernorm(self.model, layernorm_params)
 
+        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
@@ -1,40 +1,35 @@
 import torch
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Literal, Optional
 
-from llmtuner.extras.constants import LAYERNORM_NAMES
-
 
 if TYPE_CHECKING:
+    from transformers import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
 
 
 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
     if target == "reward": # save default head temporarily
-        valuehead_state_dict = model.v_head.state_dict()
+        valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
         setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
         setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
 
     model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
     model.v_head.load_state_dict({
-        "summary.weight": getattr(model, "{}_head_weight".format(target)),
-        "summary.bias": getattr(model, "{}_head_bias".format(target))
+        "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
+        "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
     })
 
 
-def cast_layernorm_dtype(
-    model: "AutoModelForCausalLMWithValueHead",
-    compute_dtype: torch.dtype,
-    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None,
-    layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
-) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
-
-    layer_norm_state_dict = {}
-
+def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
+    layer_norm_params = {}
     for name, param in model.named_parameters():
-        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            if layer_norm_params is None:
-                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
-                param.data = param.data.to(compute_dtype)
-            else:
-                param.data = layer_norm_params[name] # restore float32 weights
+        if param.data.dtype == torch.float32:
+            layer_norm_params[name] = param.data.detach().clone()
+            param.data = param.data.to(model.config.torch_dtype)
 
-    return model, layer_norm_state_dict
+    return layer_norm_params
+
+
+def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
+    for name, param in model.named_parameters():
+        if name in layernorm_params:
+            param.data = layernorm_params[name]
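A self-contained sketch of the dump/restore round trip on a toy module (the `config` attribute is attached by hand here; a real `PreTrainedModel` already carries one):

```python
import torch
import torch.nn as nn
from types import SimpleNamespace

def dump_layernorm(model):
    # Same idea as the new helper: stash every fp32 parameter (the upcast
    # layernorms) and cast it down to the model's working dtype for generation.
    params = {}
    for name, param in model.named_parameters():
        if param.data.dtype == torch.float32:
            params[name] = param.data.detach().clone()
            param.data = param.data.to(model.config.torch_dtype)
    return params

def restore_layernorm(model, layernorm_params):
    # Put the saved fp32 copies back after generation.
    for name, param in model.named_parameters():
        if name in layernorm_params:
            param.data = layernorm_params[name]

# Toy stand-in for a half-precision model whose layernorm was upcast to fp32.
model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4)).to(torch.float16)
model[1].float()
model.config = SimpleNamespace(torch_dtype=torch.float16)  # hypothetical config attribute

saved = dump_layernorm(model)    # layernorm weights are now fp16 for generation
# ... unwrapped_model.generate(...) would run here ...
restore_layernorm(model, saved)  # layernorm weights are back in fp32
print(model[1].weight.dtype)     # torch.float32
```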
@@ -65,10 +65,10 @@ def run_ppo(
 
     # Initialize our Trainer
     ppo_trainer = CustomPPOTrainer(
+        model_args=model_args,
        training_args=training_args,
         generating_args=generating_args,
         callbacks=callbacks + [SavePeftModelCallback()],
-        compute_dtype=model_args.compute_dtype,
         config=ppo_config,
         model=model,
         ref_model=None,
@@ -145,6 +145,9 @@ class Runner:
             )
             args[compute_type] = True
 
+        if args["quantization_bit"] is not None:
+            args["upcast_layernorm"] = True
+
         if args["stage"] == "ppo":
             args["reward_model"] = reward_model
             val_size = 0
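A tiny sketch of the web UI guard, using a hypothetical slice of the runner's argument dict:

```python
# Hypothetical slice of the args dict assembled by the web UI Runner.
args = {"stage": "sft", "quantization_bit": 4, "upcast_layernorm": False}

# Mirrors the new guard: quantized (QLoRA-style) runs get layernorm upcasting
# switched on automatically.
if args["quantization_bit"] is not None:
    args["upcast_layernorm"] = True

print(args["upcast_layernorm"])  # True
```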