export use balanced gpu

This commit is contained in:
hiyouga 2024-03-06 16:33:14 +08:00
parent 9658c63cd9
commit 3e84f430b1
1 changed file with 3 additions and 7 deletions

View File

@ -228,10 +228,10 @@ def _prepare_model_for_training(
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72 Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
""" """
if model_args.upcast_layernorm: if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)
logger.info("Upcasting layernorm weights in float32.")
if not model_args.disable_gradient_checkpointing: if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False): if not getattr(model, "supports_gradient_checkpointing", False):
@ -249,6 +249,7 @@ def _prepare_model_for_training(
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor): def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32) return output.to(torch.float32)
logger.info("Upcasting lm_head outputs in float32.")
output_layer = getattr(model, output_layer_name) output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(fp32_forward_post_hook) output_layer.register_forward_hook(fp32_forward_post_hook)
@ -287,12 +288,7 @@ def patch_config(
if not is_deepspeed_zero3_enabled(): if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = True init_kwargs["low_cpu_mem_usage"] = True
if "device_map" not in init_kwargs: if "device_map" not in init_kwargs:
if is_trainable: init_kwargs["device_map"] = {"": get_current_device()} if is_trainable else "auto"
init_kwargs["device_map"] = {"": get_current_device()}
elif model_args.export_dir is None:
init_kwargs["device_map"] = "auto"
else:
init_kwargs["device_map"] = {"": "cpu"}
def patch_model( def patch_model(