Export: use a balanced GPU device map (device_map="auto") instead of loading the model on CPU
This commit is contained in:
parent
9658c63cd9
commit
3e84f430b1
|
@ -228,10 +228,10 @@ def _prepare_model_for_training(
|
||||||
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
|
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
|
||||||
"""
|
"""
|
||||||
if model_args.upcast_layernorm:
|
if model_args.upcast_layernorm:
|
||||||
|
logger.info("Upcasting layernorm weights in float32.")
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
|
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
|
||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
logger.info("Upcasting layernorm weights in float32.")
|
|
||||||
|
|
||||||
if not model_args.disable_gradient_checkpointing:
|
if not model_args.disable_gradient_checkpointing:
|
||||||
if not getattr(model, "supports_gradient_checkpointing", False):
|
if not getattr(model, "supports_gradient_checkpointing", False):
|
||||||
|
@ -249,6 +249,7 @@ def _prepare_model_for_training(
|
||||||
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||||
return output.to(torch.float32)
|
return output.to(torch.float32)
|
||||||
|
|
||||||
|
logger.info("Upcasting lm_head outputs in float32.")
|
||||||
output_layer = getattr(model, output_layer_name)
|
output_layer = getattr(model, output_layer_name)
|
||||||
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
||||||
output_layer.register_forward_hook(fp32_forward_post_hook)
|
output_layer.register_forward_hook(fp32_forward_post_hook)
|
||||||
|
@ -287,12 +288,7 @@ def patch_config(
|
||||||
if not is_deepspeed_zero3_enabled():
|
if not is_deepspeed_zero3_enabled():
|
||||||
init_kwargs["low_cpu_mem_usage"] = True
|
init_kwargs["low_cpu_mem_usage"] = True
|
||||||
if "device_map" not in init_kwargs:
|
if "device_map" not in init_kwargs:
|
||||||
if is_trainable:
|
init_kwargs["device_map"] = {"": get_current_device()} if is_trainable else "auto"
|
||||||
init_kwargs["device_map"] = {"": get_current_device()}
|
|
||||||
elif model_args.export_dir is None:
|
|
||||||
init_kwargs["device_map"] = "auto"
|
|
||||||
else:
|
|
||||||
init_kwargs["device_map"] = {"": "cpu"}
|
|
||||||
|
|
||||||
|
|
||||||
def patch_model(
|
def patch_model(
|
||||||
|
|
Loading…
Reference in New Issue