From 451b6693c0cb86cc9ac03d1a9389cf1fd2b918ec Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Thu, 6 Jun 2024 20:30:25 +0800 Subject: [PATCH] fix torch gc --- src/llamafactory/extras/misc.py | 11 ++++++++--- src/llamafactory/model/utils/embedding.py | 2 +- src/llamafactory/model/utils/valuehead.py | 9 +++++---- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 78f71847..48476f9c 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -212,12 +212,17 @@ def has_tokenized_data(path: os.PathLike) -> bool: def torch_gc() -> None: r""" - Collects GPU memory. + Collects GPU memory (CUDA, XPU, NPU or MPS). """ gc.collect() - if torch.cuda.is_available(): + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + elif is_torch_mps_available(): + torch.mps.empty_cache() + elif is_torch_cuda_available(): torch.cuda.empty_cache() - torch.cuda.ipc_collect() def try_download_model_from_ms(model_args: "ModelArguments") -> str: diff --git a/src/llamafactory/model/utils/embedding.py b/src/llamafactory/model/utils/embedding.py index 357c9cc0..3d9278e3 100644 --- a/src/llamafactory/model/utils/embedding.py +++ b/src/llamafactory/model/utils/embedding.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: logger = get_logger(__name__) -def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None: +def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: embedding_dim = embed_weight.size(1) avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) diff --git a/src/llamafactory/model/utils/valuehead.py b/src/llamafactory/model/utils/valuehead.py index d813729e..64333688 100644 --- a/src/llamafactory/model/utils/valuehead.py +++ b/src/llamafactory/model/utils/valuehead.py @@ -23,6 +23,7 @@ def 
load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. """ kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token} + err_text = "" try: from safetensors import safe_open @@ -31,16 +32,16 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> with safe_open(vhead_file, framework="pt", device="cpu") as f: return {key: f.get_tensor(key) for key in f.keys()} except Exception as err: - logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err))) + err_text = str(err) try: vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs) return torch.load(vhead_file, map_location="cpu") except Exception as err: - logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err))) + err_text = str(err) - logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id)) - logger.info("Ignore these messages if you are not resuming the training of a value head model.") + logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text)) + logger.info("Ignore the above message if you are not resuming the training of a value head model.") return None