export use balanced gpu
This commit is contained in:
parent
9658c63cd9
commit
3e84f430b1
|
@ -228,10 +228,10 @@ def _prepare_model_for_training(
|
|||
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
|
||||
"""
|
||||
if model_args.upcast_layernorm:
|
||||
logger.info("Upcasting layernorm weights in float32.")
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
|
||||
param.data = param.data.to(torch.float32)
|
||||
logger.info("Upcasting layernorm weights in float32.")
|
||||
|
||||
if not model_args.disable_gradient_checkpointing:
|
||||
if not getattr(model, "supports_gradient_checkpointing", False):
|
||||
|
@ -249,6 +249,7 @@ def _prepare_model_for_training(
|
|||
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
return output.to(torch.float32)
|
||||
|
||||
logger.info("Upcasting lm_head outputs in float32.")
|
||||
output_layer = getattr(model, output_layer_name)
|
||||
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
||||
output_layer.register_forward_hook(fp32_forward_post_hook)
|
||||
|
@ -287,12 +288,7 @@ def patch_config(
|
|||
if not is_deepspeed_zero3_enabled():
|
||||
init_kwargs["low_cpu_mem_usage"] = True
|
||||
if "device_map" not in init_kwargs:
|
||||
if is_trainable:
|
||||
init_kwargs["device_map"] = {"": get_current_device()}
|
||||
elif model_args.export_dir is None:
|
||||
init_kwargs["device_map"] = "auto"
|
||||
else:
|
||||
init_kwargs["device_map"] = {"": "cpu"}
|
||||
init_kwargs["device_map"] = {"": get_current_device()} if is_trainable else "auto"
|
||||
|
||||
|
||||
def patch_model(
|
||||
|
|
Loading…
Reference in New Issue