hiyouga 2023-07-28 17:02:26 +08:00
parent 553b97a9d5
commit 91dd17d8a6
1 changed file with 11 additions and 7 deletions


@@ -28,7 +28,7 @@ class AverageMeter:
         self.avg = self.sum / self.count


-# Avoid runtime error in model.generate(do_sample=True).
+# Avoids runtime error in model.generate(do_sample=True).
 class InvalidScoreLogitsProcessor(LogitsProcessor):

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
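The hunk above shows only the class header and the `__call__` signature of `InvalidScoreLogitsProcessor`; its body lies outside the diff context. For orientation, here is a minimal self-contained sketch of what such a processor typically does (the class name, fill value, and exact logic are illustrative assumptions, not taken from this commit): with `do_sample=True`, `generate()` samples from the softmaxed scores via `torch.multinomial`, which raises a runtime error when the scores contain NaN or Inf, so the processor replaces an invalid score tensor with a harmless one before sampling.

```python
import torch
from transformers import LogitsProcessor


class SanitizeScoresLogitsProcessor(LogitsProcessor):  # illustrative name only
    """Replace NaN/Inf logits so torch.multinomial inside generate(do_sample=True) cannot fail."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()        # drop the invalid distribution entirely
            scores[..., 0] = 1.0  # put all mass on a single token so sampling still succeeds
        return scores
```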
@@ -63,7 +63,7 @@ def print_trainable_params(model: torch.nn.Module) -> None:
 def prepare_model_for_training(
     model: PreTrainedModel,
     finetuning_type: str,
-    output_embedding_layer_name: Optional[str] = "lm_head",
+    output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
 ) -> PreTrainedModel:
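This hunk only renames a keyword argument, so the call-site impact is limited to that keyword. A hypothetical call for illustration (the model name and the `finetuning_type` value are assumptions; `prepare_model_for_training` is assumed to be imported from the module this diff edits, whose path is not shown in the extract):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = prepare_model_for_training(
    model,
    finetuning_type="lora",
    output_layer_name="lm_head",  # formerly output_embedding_layer_name
)
```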
@@ -83,19 +83,23 @@ def prepare_model_for_training(
         model.gradient_checkpointing_enable()
         model.config.use_cache = False # turn off when gradient checkpointing is enabled

-    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
-        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
-        input_dtype = output_embedding_layer.weight.dtype
+    if finetuning_type != "full" and hasattr(model, output_layer_name):
+        output_layer: torch.nn.Linear = getattr(model, output_layer_name)
+        input_dtype = output_layer.weight.dtype

         class CastOutputToFloat(torch.nn.Sequential):

             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return super().forward(x.to(input_dtype)).to(torch.float32)

-        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
+        new_output_layer = CastOutputToFloat(output_layer)
+        # adapt to LLaMA-2's pretraining_tp (actually LLaMA models can automatically do casting but BLOOM models cannot)
+        # (https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py#L819)
+        setattr(new_output_layer, "weight", output_layer.weight)
+        setattr(model, output_layer_name, new_output_layer)

     return model


 def torch_gc() -> None:
     r"""
     Collects GPU memory.
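Why the weight is re-attached: as the linked line shows, `LlamaForCausalLM.forward` in transformers v4.31 reads `self.lm_head.weight` directly when `config.pretraining_tp > 1`, and a bare `torch.nn.Sequential` wrapper exposes no `weight` attribute. A standalone toy sketch (not the repository's code; the layer sizes and dtype are arbitrary) of what the extra `setattr` buys:

```python
import torch

lm_head = torch.nn.Linear(16, 32, bias=False)  # stand-in for the model's output layer
input_dtype = lm_head.weight.dtype


class CastOutputToFloat(torch.nn.Sequential):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.to(input_dtype)).to(torch.float32)


wrapped = CastOutputToFloat(lm_head)
# Without the next line, wrapped.weight raises AttributeError, which breaks code paths
# (e.g. LLaMA-2 with pretraining_tp > 1) that read the output layer's weight directly.
setattr(wrapped, "weight", lm_head.weight)
assert wrapped.weight is lm_head.weight     # shared, not copied
out = wrapped(torch.randn(2, 16, dtype=input_dtype))
assert out.dtype == torch.float32           # outputs are still upcast to fp32
```

Because the attribute is the same `nn.Parameter` object rather than a copy, any training update to the wrapped layer is visible through both `wrapped.weight` and the original linear layer.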