Fix evaluator and cached_file compatibility with transformers 4.31.0

This commit is contained in:
hiyouga 2023-11-18 19:39:23 +08:00
parent a2019c8b61
commit ff6056405d
2 changed files with 37 additions and 12 deletions

View File

@ -3,6 +3,7 @@
import os
import json
import torch
import inspect
import tiktoken
import numpy as np
from tqdm import tqdm, trange
@ -45,13 +46,18 @@ class Evaluator:
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
def eval(self) -> None:
if "token" in inspect.signature(cached_file).parameters:
kwargs = {"token": self.model_args.hf_hub_token}
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
kwargs = {"use_auth_token": self.model_args.hf_hub_token}
mapping = cached_file(
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
filename="mapping.json",
cache_dir=self.model_args.cache_dir,
token=self.model_args.hf_hub_token,
revision=self.model_args.model_revision
**kwargs
)
with open(mapping, "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
@ -62,7 +68,9 @@ class Evaluator:
dataset = load_dataset(
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
name=subject,
download_mode="force_redownload"
cache_dir=self.model_args.cache_dir,
download_mode=self.eval_args.download_mode,
token=self.model_args.hf_hub_token
)
pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], []

View File

@ -1,4 +1,5 @@
import torch
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from transformers.utils import cached_file
@ -94,20 +95,36 @@ def load_valuehead_params(
"""
kwargs = {
"path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token
"cache_dir": model_args.cache_dir
}
if "token" in inspect.signature(cached_file).parameters:
kwargs["token"] = model_args.hf_hub_token
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
kwargs["use_auth_token"] = model_args.hf_hub_token
else:
logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
try:
from safetensors import safe_open
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {
"v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
"v_head.summary.bias": f.get_tensor("v_head.summary.bias")
}
except Exception as err:
logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None
return torch.load(vhead_file, map_location="cpu")
def prepare_model_for_training(
model: "PreTrainedModel",