fix evaluator and cached_file in 4.31.0

This commit is contained in:
hiyouga 2023-11-18 19:39:23 +08:00
parent a2019c8b61
commit ff6056405d
2 changed files with 37 additions and 12 deletions

View File

@ -3,6 +3,7 @@
import os import os
import json import json
import torch import torch
import inspect
import tiktoken import tiktoken
import numpy as np import numpy as np
from tqdm import tqdm, trange from tqdm import tqdm, trange
@ -45,13 +46,18 @@ class Evaluator:
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)] return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
def eval(self) -> None: def eval(self) -> None:
if "token" in inspect.signature(cached_file).parameters:
kwargs = {"token": self.model_args.hf_hub_token}
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
kwargs = {"use_auth_token": self.model_args.hf_hub_token}
mapping = cached_file( mapping = cached_file(
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task), path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
filename="mapping.json", filename="mapping.json",
cache_dir=self.model_args.cache_dir, cache_dir=self.model_args.cache_dir,
token=self.model_args.hf_hub_token, **kwargs
revision=self.model_args.model_revision
) )
with open(mapping, "r", encoding="utf-8") as f: with open(mapping, "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f) categorys: Dict[str, Dict[str, str]] = json.load(f)
@ -62,7 +68,9 @@ class Evaluator:
dataset = load_dataset( dataset = load_dataset(
path=os.path.join(self.eval_args.task_dir, self.eval_args.task), path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
name=subject, name=subject,
download_mode="force_redownload" cache_dir=self.model_args.cache_dir,
download_mode=self.eval_args.download_mode,
token=self.model_args.hf_hub_token
) )
pbar.set_postfix_str(categorys[subject]["name"]) pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], [] inputs, outputs, labels = [], [], []

View File

@ -1,4 +1,5 @@
import torch import torch
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from transformers.utils import cached_file from transformers.utils import cached_file
@ -94,20 +95,36 @@ def load_valuehead_params(
""" """
kwargs = { kwargs = {
"path_or_repo_id": path_or_repo_id, "path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir, "cache_dir": model_args.cache_dir
"token": model_args.hf_hub_token
} }
if "token" in inspect.signature(cached_file).parameters:
kwargs["token"] = model_args.hf_hub_token
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
kwargs["use_auth_token"] = model_args.hf_hub_token
else:
logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
try: try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs) vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except: return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
try: try:
from safetensors import safe_open
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs) vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except: with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {
"v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
"v_head.summary.bias": f.get_tensor("v_head.summary.bias")
}
except Exception as err:
logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id)) logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None return None
return torch.load(vhead_file, map_location="cpu")
def prepare_model_for_training( def prepare_model_for_training(
model: "PreTrainedModel", model: "PreTrainedModel",