fix evaluator and cached_file in 4.31.0

parent a2019c8b61 · commit ff6056405d

This commit adapts the evaluator and load_valuehead_params to transformers 4.31.0, whose cached_file takes use_auth_token rather than token. Both call sites now detect the available parameter at runtime instead of assuming the newer name.
@@ -3,6 +3,7 @@
 import os
 import json
 import torch
+import inspect
 import tiktoken
 import numpy as np
 from tqdm import tqdm, trange
@@ -45,13 +46,18 @@ class Evaluator:
         return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
 
     def eval(self) -> None:
+        if "token" in inspect.signature(cached_file).parameters:
+            kwargs = {"token": self.model_args.hf_hub_token}
+        elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
+            kwargs = {"use_auth_token": self.model_args.hf_hub_token}
+
         mapping = cached_file(
             path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
             filename="mapping.json",
             cache_dir=self.model_args.cache_dir,
-            token=self.model_args.hf_hub_token,
-            revision=self.model_args.model_revision
+            **kwargs
         )
 
         with open(mapping, "r", encoding="utf-8") as f:
             categorys: Dict[str, Dict[str, str]] = json.load(f)
 
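Both this hunk and the load_valuehead_params hunks further down rely on the same trick: probe the signature of cached_file with inspect at call time and pass whichever auth keyword the installed transformers release accepts. A minimal standalone sketch of the pattern follows; the helper name auth_kwargs is illustrative and not part of the commit.

import inspect

from transformers.utils import cached_file


def auth_kwargs(hf_hub_token):
    # Probe the installed transformers: newer releases name the auth
    # parameter `token`, while e.g. 4.31.0 still uses `use_auth_token`.
    params = inspect.signature(cached_file).parameters
    if "token" in params:
        return {"token": hf_hub_token}
    if "use_auth_token" in params:
        return {"use_auth_token": hf_hub_token}
    return {}  # no matching parameter; the caller proceeds without auth

Usage mirroring the hunk above (repo id and filename are placeholders): cached_file(path_or_repo_id="some/task", filename="mapping.json", **auth_kwargs(None)).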
@@ -62,7 +68,9 @@ class Evaluator:
             dataset = load_dataset(
                 path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
                 name=subject,
-                download_mode="force_redownload"
+                cache_dir=self.model_args.cache_dir,
+                download_mode=self.eval_args.download_mode,
+                token=self.model_args.hf_hub_token
             )
             pbar.set_postfix_str(categorys[subject]["name"])
             inputs, outputs, labels = [], [], []
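This hunk also drops the unconditional download_mode="force_redownload", which re-fetched the task on every run, in favor of the user-configured cache directory, download mode, and hub token. A hedged sketch of what the new call amounts to, with placeholder values for the eval_args/model_args fields; note that older datasets releases spell the last argument use_auth_token rather than token.

from datasets import load_dataset

# Placeholders standing in for eval_args/model_args fields in the hunk above.
dataset = load_dataset(
    path="evaluation/ceval",                  # task_dir joined with task name
    name="accountant",                        # one subject from mapping.json
    cache_dir=None,                           # reuse the default HF cache
    download_mode="reuse_dataset_if_exists",  # instead of forcing a redownload
    token=None,                               # hf_hub_token if the task is gated
)

The remaining two hunks apply the same 4.31.0 compatibility fix inside the module that loads value-head weights.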
@@ -1,4 +1,5 @@
 import torch
+import inspect
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
 from transformers.utils import cached_file
@@ -94,20 +95,36 @@ def load_valuehead_params(
     """
     kwargs = {
         "path_or_repo_id": path_or_repo_id,
-        "cache_dir": model_args.cache_dir,
-        "token": model_args.hf_hub_token
+        "cache_dir": model_args.cache_dir
     }
 
+    if "token" in inspect.signature(cached_file).parameters:
+        kwargs["token"] = model_args.hf_hub_token
+    elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
+        kwargs["use_auth_token"] = model_args.hf_hub_token
+    else:
+        logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
+
     try:
         vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
-    except:
-        try:
-            vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
-        except:
-            logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
-            return None
+        return torch.load(vhead_file, map_location="cpu")
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
 
-    return torch.load(vhead_file, map_location="cpu")
+    try:
+        from safetensors import safe_open
+        vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
+        with safe_open(vhead_file, framework="pt", device="cpu") as f:
+            return {
+                "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
+                "v_head.summary.bias": f.get_tensor("v_head.summary.bias")
+            }
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
+
+    logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
+    return None
 
 
 def prepare_model_for_training(
     model: "PreTrainedModel",
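The rewritten loader also fixes a latent bug visible in the removed lines: the old code called torch.load on whichever file the nested try found, which fails when the fallback resolved a safetensors checkpoint. The new code opens that case with safe_open and reads only the value-head tensors. A minimal sketch of the two read paths; the function name and the key-prefix filter are illustrative, whereas the commit reads the weight and bias keys explicitly.

import torch
from safetensors import safe_open


def read_valuehead(vhead_file: str) -> dict:
    # PyTorch pickle checkpoint: torch.load returns the state dict directly.
    if not vhead_file.endswith(".safetensors"):
        return torch.load(vhead_file, map_location="cpu")
    # Safetensors checkpoint: open lazily and pull only the v_head tensors,
    # mirroring the explicit weight/bias reads in the hunk above.
    with safe_open(vhead_file, framework="pt", device="cpu") as f:
        return {key: f.get_tensor(key) for key in f.keys() if key.startswith("v_head.")}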