fix #1742
commit 870426ff70
parent 7ae6919b9b

@@ -81,9 +81,16 @@ def load_model_and_tokenizer(
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         patcher.patch_valuehead_model(model)
-        vhead_params = load_valuehead_params(model_args)
+
+        if model_args.adapter_name_or_path is not None:
+            vhead_path = model_args.adapter_name_or_path[-1]
+        else:
+            vhead_path = model_args.model_name_or_path
+
+        vhead_params = load_valuehead_params(vhead_path, model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
+            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
 
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
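
The checkpoint loaded above holds only the two value-head tensors, so `strict=False` is what lets the `load_state_dict` call succeed despite every transformer weight being absent from the state dict. A minimal sketch of that behavior, using a made-up two-layer module in place of the real `AutoModelForCausalLMWithValueHead`:

import torch

class TinyValueHeadModel(torch.nn.Module):
    """Stand-in for an LM with a value head; all names here are illustrative only."""
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.backbone = torch.nn.Linear(hidden_size, hidden_size)  # plays the role of the LM
        self.summary = torch.nn.Linear(hidden_size, 1)             # plays the role of v_head.summary

model = TinyValueHeadModel()
vhead_params = {"summary.weight": torch.zeros(1, 8), "summary.bias": torch.zeros(1)}

# strict=False: missing backbone weights are reported back, not raised as errors
missing, unexpected = model.load_state_dict(vhead_params, strict=False)
print(missing)     # ['backbone.weight', 'backbone.bias']
print(unexpected)  # []
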
@@ -85,34 +85,21 @@ def get_modelcard_args(
 }
 
 
-def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
+def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
 
     Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
     """
-    if model_args.adapter_name_or_path is not None:
-        path_or_repo_id = model_args.adapter_name_or_path[-1]
-    else:
-        path_or_repo_id = model_args.model_name_or_path
-
     kwargs = {
         "path_or_repo_id": path_or_repo_id,
         "cache_dir": model_args.cache_dir,
         "token": model_args.hf_hub_token
     }
 
     try:
-        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
-        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
-        return torch.load(vhead_file, map_location="cpu")
-    except Exception as err:
-        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
-
-    try:
         from safetensors import safe_open
         vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
-        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
         with safe_open(vhead_file, framework="pt", device="cpu") as f:
             return {
                 "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
@@ -121,6 +108,12 @@ def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     except Exception as err:
         logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
 
+    try:
+        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+        return torch.load(vhead_file, map_location="cpu")
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
+
     logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
     return None
 
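
The two hunks above flip the fallback order so the safetensors file is tried before the pickle checkpoint. A standalone sketch of that order, reading from a local directory instead of going through `cached_file`; the file names are assumed to match the transformers constants `SAFE_WEIGHTS_NAME` and `WEIGHTS_NAME` used in the diff, and `checkpoint_dir` is hypothetical:

import os
import torch
from safetensors import safe_open

def read_valuehead(checkpoint_dir: str):
    # Prefer the safetensors checkpoint, mirroring the new try order above
    try:
        with safe_open(os.path.join(checkpoint_dir, "model.safetensors"),
                       framework="pt", device="cpu") as f:
            return {
                "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
                "v_head.summary.bias": f.get_tensor("v_head.summary.bias"),
            }
    except Exception:
        pass  # fall through to the pickle-based checkpoint

    try:
        return torch.load(os.path.join(checkpoint_dir, "pytorch_model.bin"),
                          map_location="cpu")
    except Exception:
        return None  # caller logs a warning, as in the diff
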
@@ -206,6 +206,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.finetuning_args.upcast_layernorm:
             layernorm_params = dump_layernorm(self.model)
 
+        if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
+            start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            for k, v in batch.items():
+                batch[k] = v[:, start_index:]
+
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         generate_output: torch.Tensor = unwrapped_model.generate(
             generation_config=self.generation_config,
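
The inserted block strips the shared left padding from a single-sample batch before generation. A tiny demonstration of the index arithmetic, with an assumed `pad_token_id` of 0 and made-up token ids:

import torch

pad_token_id = 0
input_ids = torch.tensor([[0, 0, 0, 5, 6, 7]])  # batch of one, left-padded

# first position that is not padding marks where the prompt begins
start_index = (input_ids[0] != pad_token_id).nonzero()[0].item()
print(start_index)                 # 3
print(input_ids[:, start_index:])  # tensor([[5, 6, 7]])
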
@@ -220,7 +225,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
-            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
             response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
 
             if len(response_index) == 0:
@@ -228,7 +233,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             else:
                 response_length = response_index[-1].item() + 1
 
-            queries.append(query[i, query_length:]) # remove padding from left
+            queries.append(query[i, query_start_index:]) # remove padding from left
             responses.append(response[i, :response_length]) # remove padding from right
 
         return queries, responses
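
A worked example of the renamed variables, again with an assumed `pad_token_id` of 0: the query is left-padded and sliced from its first real token, while the response is right-padded and sliced up to its last real token:

import torch

pad_token_id = 0
query = torch.tensor([0, 0, 11, 12])     # left-padded prompt
response = torch.tensor([21, 22, 0, 0])  # right-padded generation

query_start_index = (query != pad_token_id).nonzero()[0].item()  # 2
response_index = (response != pad_token_id).nonzero()
response_length = response_index[-1].item() + 1                  # 2

print(query[query_start_index:])   # tensor([11, 12]) -> padding removed from left
print(response[:response_length])  # tensor([21, 22]) -> padding removed from right
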