forked from p04798526/LLaMA-Factory-Mirror
fix torch gc
This commit is contained in:
parent
149610c636
commit
451b6693c0
|
@ -212,12 +212,17 @@ def has_tokenized_data(path: os.PathLike) -> bool:
|
||||||
|
|
||||||
def torch_gc() -> None:
|
def torch_gc() -> None:
|
||||||
r"""
|
r"""
|
||||||
Collects GPU memory.
|
Collects GPU or NPU memory.
|
||||||
"""
|
"""
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.cuda.is_available():
|
if is_torch_xpu_available():
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
elif is_torch_mps_available():
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
elif is_torch_cuda_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.ipc_collect()
|
|
||||||
|
|
||||||
|
|
||||||
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
|
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
|
||||||
|
|
|
@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
|
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
|
||||||
embedding_dim = embed_weight.size(1)
|
embedding_dim = embed_weight.size(1)
|
||||||
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||||
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
|
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
|
||||||
|
|
|
@ -23,6 +23,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
|
||||||
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
||||||
"""
|
"""
|
||||||
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
|
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
|
||||||
|
err_text = ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
@ -31,16 +32,16 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
|
||||||
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
||||||
return {key: f.get_tensor(key) for key in f.keys()}
|
return {key: f.get_tensor(key) for key in f.keys()}
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
|
err_text = str(err)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
|
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
|
||||||
return torch.load(vhead_file, map_location="cpu")
|
return torch.load(vhead_file, map_location="cpu")
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
|
err_text = str(err)
|
||||||
|
|
||||||
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
|
logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
|
||||||
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
|
logger.info("Ignore the above message if you are not resuming the training of a value head model.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue