hiyouga 2024-04-01 22:53:52 +08:00
parent 54b7d34908
commit 4a6ca621c0
4 changed files with 23 additions and 15 deletions


@@ -235,6 +235,12 @@ def _configure_quantization(
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
 
+def _fp32_forward_post_hook(
+    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+) -> "torch.Tensor":
+    return output.to(torch.float32)
+
+
 def _prepare_model_for_training(
     model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
 ) -> None:
@@ -263,14 +269,10 @@ def _prepare_model_for_training
         logger.info("Gradient checkpointing enabled.")
 
     if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
-
-        def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
-            return output.to(torch.float32)
-
         logger.info("Upcasting lm_head outputs in float32.")
         output_layer = getattr(model, output_layer_name)
         if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
-            output_layer.register_forward_hook(fp32_forward_post_hook)
+            output_layer.register_forward_hook(_fp32_forward_post_hook)
 
 
 def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
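The refactor in these two hunks hoists the fp32 hook out of `_prepare_model_for_training`, so it is defined once at module level instead of being re-created as a closure on every call. A forward post-hook that returns a non-None value replaces the module's output, which is what lets the hook upcast `lm_head` logits without touching the half-precision weights. A minimal sketch of the same pattern, using a stand-alone `nn.Linear` as a stand-in for the output layer:

import torch
from typing import Tuple


def _fp32_forward_post_hook(
    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
    # returning a value from a forward post-hook replaces the module's output
    return output.to(torch.float32)


# stand-in for lm_head: a bf16 projection layer
lm_head = torch.nn.Linear(16, 32, dtype=torch.bfloat16)
lm_head.register_forward_hook(_fp32_forward_post_hook)

x = torch.randn(2, 16, dtype=torch.bfloat16)
assert lm_head(x).dtype == torch.float32  # logits arrive upcast; weights stay bf16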
@@ -316,13 +318,6 @@ def patch_config(
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
-    if "GenerationMixin" not in str(model.generate.__func__):
-        model.generate = MethodType(PreTrainedModel.generate, model)
-
-    if getattr(model.config, "model_type", None) == "chatglm":
-        setattr(model, "lm_head", model.transformer.output_layer)
-        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
-
     gen_config = model.generation_config  # check and fix generation config
     if not gen_config.do_sample and (
         (gen_config.temperature is not None and gen_config.temperature != 1.0)
@@ -331,6 +326,16 @@ def patch_model(
     ):
         gen_config.do_sample = True
 
+    if "GenerationMixin" not in str(model.generate.__func__):
+        model.generate = MethodType(PreTrainedModel.generate, model)
+
+    if is_trainable and getattr(model.config, "model_type", None) == "chatglm":
+        setattr(model, "lm_head", model.transformer.output_layer)
+        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
+
+    if is_trainable and getattr(model.config, "model_type", None) == "qwen2" and model_args.flash_attn:
+        setattr(model.config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
+
     if is_trainable and model_args.resize_vocab:
         _resize_embedding_layer(model, tokenizer)
 
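For context, the generation-config check that these blocks were moved below forces `do_sample=True` whenever a sampling parameter is set on a greedy config; the hunk only shows the temperature clause, and the remaining clauses are truncated here. A hedged sketch of just the visible check, using the standard transformers `GenerationConfig` class:

from transformers import GenerationConfig

gen_config = GenerationConfig(do_sample=False, temperature=0.7)
if not gen_config.do_sample and (
    gen_config.temperature is not None and gen_config.temperature != 1.0
):
    # a non-default temperature has no effect under greedy decoding,
    # so sampling is switched on to honor it
    gen_config.do_sample = True

assert gen_config.do_sample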


@@ -95,7 +95,10 @@ class CustomDPOTrainer(DPOTrainer):
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
 
         all_logits: "torch.Tensor" = model(
-            input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
+            input_ids=batch_copied["input_ids"],
+            attention_mask=batch_copied["attention_mask"],
+            return_dict=True,
+            use_cache=False,
         ).logits.to(torch.float32)
 
         all_logps = self.get_batch_logps(
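The same `use_cache=False` flag is threaded through the ORPO and PPO trainers below. A training forward pass never consumes the past key/value cache, so disabling it avoids allocating useless KV tensors and sidesteps the qwen2 + FlashAttention incompatibility noted in the patcher. A minimal sketch of the pattern; the checkpoint name is an illustrative choice, not taken from the commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen1.5-0.5B"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

batch = tokenizer(["hello world"], return_tensors="pt")
outputs = model(**batch, return_dict=True, use_cache=False)

assert outputs.past_key_values is None  # no KV cache was allocated
logits = outputs.logits.to(torch.float32)  # upcast before computing log-probs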


@@ -73,7 +73,7 @@ class CustomORPOTrainer(DPOTrainer):
         Computes the average log probabilities of the labels under the given logits.
         """
         all_logits: "torch.Tensor" = model(
-            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
+            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
         ).logits.to(torch.float32)
 
         all_logps = self.get_batch_logps(


@@ -353,7 +353,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         batch = self.prepare_model_inputs(queries, responses)
 
         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
-            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
+            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
 
         if getattr(unwrapped_model.config, "model_type", None) == "chatglm":  # assume same architecture
             values = torch.transpose(values, 0, 1)
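The autocast context in this last hunk runs the reward model in the user-configured `compute_dtype` (e.g. bf16) while the fp32 master weights stay untouched. A minimal sketch of that pattern with a fixed bf16 dtype; it assumes a CUDA device and uses a toy linear layer as a stand-in for the reward model:

import torch

reward_head = torch.nn.Linear(8, 1).cuda()  # toy stand-in for a reward model
x = torch.randn(4, 8, device="cuda")

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    values = reward_head(x)  # matmul runs in bf16 under autocast

assert values.dtype == torch.bfloat16
assert reward_head.weight.dtype == torch.float32  # weights remain fp32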