diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index c61f28f0..81097257 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -228,10 +228,10 @@ def _prepare_model_for_training(
     Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
     """
     if model_args.upcast_layernorm:
+        logger.info("Upcasting layernorm weights in float32.")
         for name, param in model.named_parameters():
             if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                 param.data = param.data.to(torch.float32)
-        logger.info("Upcasting layernorm weights in float32.")
 
     if not model_args.disable_gradient_checkpointing:
         if not getattr(model, "supports_gradient_checkpointing", False):
@@ -249,6 +249,7 @@ def _prepare_model_for_training(
         def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
             return output.to(torch.float32)
 
+        logger.info("Upcasting lm_head outputs in float32.")
         output_layer = getattr(model, output_layer_name)
         if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
             output_layer.register_forward_hook(fp32_forward_post_hook)
@@ -287,12 +288,7 @@ def patch_config(
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = True
         if "device_map" not in init_kwargs:
-            if is_trainable:
-                init_kwargs["device_map"] = {"": get_current_device()}
-            elif model_args.export_dir is None:
-                init_kwargs["device_map"] = "auto"
-            else:
-                init_kwargs["device_map"] = {"": "cpu"}
+            init_kwargs["device_map"] = {"": get_current_device()} if is_trainable else "auto"
 
 
 def patch_model(