forked from p04798526/LLaMA-Factory-Mirror
fix torch gc
This commit is contained in:
parent
149610c636
commit
451b6693c0
|
@ -212,12 +212,17 @@ def has_tokenized_data(path: os.PathLike) -> bool:
|
|||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
Collects GPU or NPU memory.
|
||||
"""
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
elif is_torch_mps_available():
|
||||
torch.mps.empty_cache()
|
||||
elif is_torch_cuda_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
|
||||
|
|
|
@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
|
||||
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
|
||||
embedding_dim = embed_weight.size(1)
|
||||
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
|
||||
|
|
|
@ -23,6 +23,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
|
|||
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
||||
"""
|
||||
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
|
||||
err_text = ""
|
||||
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
|
@ -31,16 +32,16 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
|
|||
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
||||
return {key: f.get_tensor(key) for key in f.keys()}
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
|
||||
err_text = str(err)
|
||||
|
||||
try:
|
||||
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
|
||||
err_text = str(err)
|
||||
|
||||
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
|
||||
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
|
||||
logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
|
||||
logger.info("Ignore the above message if you are not resuming the training of a value head model.")
|
||||
return None
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue