fix #268
This commit is contained in:
parent
553b97a9d5
commit
91dd17d8a6
|
@ -28,7 +28,7 @@ class AverageMeter:
|
|||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
# Avoid runtime error in model.generate(do_sample=True).
|
||||
# Avoids runtime error in model.generate(do_sample=True).
|
||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
@ -63,7 +63,7 @@ def print_trainable_params(model: torch.nn.Module) -> None:
|
|||
def prepare_model_for_training(
|
||||
model: PreTrainedModel,
|
||||
finetuning_type: str,
|
||||
output_embedding_layer_name: Optional[str] = "lm_head",
|
||||
output_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||
) -> PreTrainedModel:
|
||||
|
@ -83,19 +83,23 @@ def prepare_model_for_training(
|
|||
model.gradient_checkpointing_enable()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
|
||||
if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
|
||||
output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
|
||||
input_dtype = output_embedding_layer.weight.dtype
|
||||
if finetuning_type != "full" and hasattr(model, output_layer_name):
|
||||
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
|
||||
input_dtype = output_layer.weight.dtype
|
||||
|
||||
class CastOutputToFloat(torch.nn.Sequential):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(x.to(input_dtype)).to(torch.float32)
|
||||
|
||||
setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
|
||||
|
||||
new_output_layer = CastOutputToFloat(output_layer)
|
||||
# adapt to LLaMA-2's pretraining_tp (actually LLaMA models can automatically do casting but BLOOM models cannot)
|
||||
# (https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py#L819)
|
||||
setattr(new_output_layer, "weight", output_layer.weight)
|
||||
setattr(model, output_layer_name, new_output_layer)
|
||||
return model
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
|
|
Loading…
Reference in New Issue