fix torch gc

This commit is contained in:
hiyouga 2024-06-06 20:30:25 +08:00
parent 149610c636
commit 451b6693c0
3 changed files with 14 additions and 8 deletions

View File

@ -212,12 +212,17 @@ def has_tokenized_data(path: os.PathLike) -> bool:
def torch_gc() -> None:
r"""
Collects GPU memory.
Collects GPU or NPU memory.
"""
gc.collect()
if torch.cuda.is_available():
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available():
torch.mps.empty_cache()
elif is_torch_cuda_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def try_download_model_from_ms(model_args: "ModelArguments") -> str:

View File

@ -15,7 +15,7 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])

View File

@ -23,6 +23,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
err_text = ""
try:
from safetensors import safe_open
@ -31,16 +32,16 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
err_text = str(err)
try:
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
err_text = str(err)
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
logger.info("Ignore the above message if you are not resuming the training of a value head model.")
return None