diff --git a/src/utils/common.py b/src/utils/common.py
index 2bb82d6e..c81915c4 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -103,8 +103,9 @@ def _init_adapter(
     lastest_checkpoint = None
 
     if model_args.checkpoint_dir is not None:
-        if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
-                not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
+        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)):
+            raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]))
+        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
             raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
                 please specify `--finetuning_type full/freeze` instead.")
 
@@ -449,7 +450,7 @@ def preprocess_data(
             yield dialog
 
     def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
+        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
         text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
         concatenated_ids = list(chain(*text_ids))
         total_length = len(concatenated_ids)
@@ -465,7 +466,7 @@ def preprocess_data(
         }
 
     def preprocess_supervised_dataset(examples):
-        # build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
+        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for input with history, we build multiple input-label pairs just like:
         # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
         model_inputs = {"input_ids": [], "labels": []}
@@ -475,15 +476,26 @@ def preprocess_data(
             for i in range(len(dialog) // 2):
                 source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
                 target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
-                input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
+
+                if len(source_ids) > data_args.max_source_length - 1: # bos token
+                    source_ids = source_ids[:data_args.max_source_length - 1]
+                if len(target_ids) > data_args.max_target_length - 1: # eos token
+                    target_ids = target_ids[:data_args.max_target_length - 1]
+
+                input_ids += [tokenizer.bos_token_id] + source_ids + target_ids + [tokenizer.eos_token_id]
                 labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
 
-            model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
-            model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
+            if len(input_ids) > data_args.max_source_length + data_args.max_target_length:
+                input_ids = input_ids[:data_args.max_source_length + data_args.max_target_length]
+            if len(labels) > data_args.max_source_length + data_args.max_target_length:
+                labels = labels[:data_args.max_source_length + data_args.max_target_length]
+
+            model_inputs["input_ids"].append(input_ids)
+            model_inputs["labels"].append(labels)
         return model_inputs
 
     def preprocess_unsupervised_dataset(examples):
-        # build inputs with format `X [BOS]` and labels with format `Y [BOS]`
+        # build inputs with format `<bos> X` and labels with format `<bos> Y`
         model_inputs = {"input_ids": [], "labels": []}
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
@@ -496,15 +508,15 @@ def preprocess_data(
             if len(target_ids) > data_args.max_target_length - 1: # bos token
                 target_ids = target_ids[:data_args.max_target_length - 1]
 
-            input_ids = source_ids + [tokenizer.bos_token_id]
-            labels = target_ids + [tokenizer.bos_token_id]
+            input_ids = [tokenizer.bos_token_id] + source_ids
+            labels = [tokenizer.bos_token_id] + target_ids
 
             model_inputs["input_ids"].append(input_ids)
             model_inputs["labels"].append(labels)
         return model_inputs
 
     def preprocess_pairwise_dataset(examples):
-        # build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
+        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
         model_inputs = {"accept_ids": [], "reject_ids": []}
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
@@ -520,8 +532,8 @@ def preprocess_data(
             if len(reject_ids) > data_args.max_target_length - 1: # eos token
                 reject_ids = reject_ids[:data_args.max_target_length - 1]
 
-            accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
-            reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
+            accept_ids = [tokenizer.bos_token_id] + source_ids + accept_ids + [tokenizer.eos_token_id]
+            reject_ids = [tokenizer.bos_token_id] + source_ids + reject_ids + [tokenizer.eos_token_id]
 
             model_inputs["accept_ids"].append(accept_ids)
             model_inputs["reject_ids"].append(reject_ids)
diff --git a/src/utils/other.py b/src/utils/other.py
index 25603c70..3e3d25a8 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -5,13 +5,11 @@ import torch
 import logging
 from typing import Dict, List, Optional
 
-from transformers.trainer import TRAINER_STATE_NAME
+from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME
 from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor
 
-from peft.utils import WEIGHTS_NAME
-
 IGNORE_INDEX = -100
 
 VALUE_HEAD_FILE_NAME = "value_head.bin"
diff --git a/src/utils/peft_trainer.py b/src/utils/peft_trainer.py
index 045278d3..5652aa99 100644
--- a/src/utils/peft_trainer.py
+++ b/src/utils/peft_trainer.py
@@ -16,8 +16,6 @@ from transformers import (
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.modeling_utils import unwrap_model
 
-from peft.utils.other import WEIGHTS_NAME
-
 from .config import FinetuningArguments
 
 from .other import (
@@ -98,18 +96,22 @@ class PeftTrainer(Seq2SeqTrainer):
         logger.info(f"Saving model checkpoint to {output_dir}")
 
         model = unwrap_model(self.model)
-        if hasattr(model, "pretrained_model"): # for models with valuehead
+        if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
             backbone_model = getattr(model, "pretrained_model")
+            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
         else:
             backbone_model = model
 
-        if hasattr(backbone_model, "peft_config"): # peft methods
-            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights
-        else:
-            torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights
-
-        if hasattr(model, "v_head"): # save valuehead weights
-            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
+        if self.finetuning_args.finetuning_type == "lora":
+            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
+        else: # freeze/full tuning
+            backbone_model.save_pretrained(
+                output_dir,
+                state_dict=get_state_dict(backbone_model),
+                safe_serialization=self.args.save_safetensors
+            )
+            if self.tokenizer is not None:
+                self.tokenizer.save_pretrained(output_dir)
 
         with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
             f.write(self.args.to_json_string() + "\n")
@@ -122,11 +124,14 @@ class PeftTrainer(Seq2SeqTrainer):
         Subclass and override to inject custom behavior. It should not be directly used by external scripts.
         """
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
-        model = unwrap_model(self.model)
-        if hasattr(model, "peft_config"): # peft methods
-            model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
-        else:
-            load_trainable_params(model, self.state.best_model_checkpoint)
-
-        if hasattr(model, "v_head"):
-            load_valuehead_params(model, self.state.best_model_checkpoint)
+        model = unwrap_model(self.model)
+
+        if self.finetuning_args.finetuning_type == "lora":
+            model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
+            if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
+                model.v_head.load_state_dict({
+                    "summary.weight": getattr(model, "reward_head_weight"),
+                    "summary.bias": getattr(model, "reward_head_bias")
+                })
+        else: # freeze/full-tuning
+            load_trainable_params(model, self.state.best_model_checkpoint)
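
Note (illustration, not part of the patch): the sketch below shows the sequence layout that the revised preprocess_supervised_dataset produces for a single turn. The helper name, the toy token ids, and the length limits are hypothetical; only the layout and the label masking mirror the patch. Inputs become `<bos> X Y <eos>` and labels mask the prompt with IGNORE_INDEX, so the loss covers only the response tokens and the trailing `<eos>`.

    IGNORE_INDEX = -100  # same sentinel the patch uses for masked label positions

    def build_supervised_example(source_ids, target_ids, bos_id, eos_id,
                                 max_source_length=8, max_target_length=8):
        # Truncate each side independently, reserving one slot for <bos> / <eos>.
        source_ids = source_ids[:max_source_length - 1]
        target_ids = target_ids[:max_target_length - 1]
        # Inputs: <bos> X Y <eos>; labels: <ignore> ... <ignore> Y <eos>.
        input_ids = [bos_id] + source_ids + target_ids + [eos_id]
        labels = [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [eos_id]
        return input_ids, labels

    # Toy ids standing in for a tokenized prompt ([11, 12, 13]) and response ([21, 22]).
    inputs, labels = build_supervised_example([11, 12, 13], [21, 22], bos_id=1, eos_id=2)
    print(inputs)  # [1, 11, 12, 13, 21, 22, 2]
    print(labels)  # [-100, -100, -100, -100, 21, 22, 2]

In the patch itself, multi-turn inputs are additionally capped at max_source_length + max_target_length after all turns are concatenated, and the unsupervised and pairwise branches follow the same bos-first layout.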