From c0e5df92d601966444956c65482441bd757fd7a1 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 29 May 2023 17:43:16 +0800
Subject: [PATCH] fix checkpoint loading

---
 src/cli_demo.py     |  8 ++++++-
 src/utils/common.py | 57 +++++++++++++++++++++++++++++----------------
 src/utils/config.py |  3 ++-
 src/web_demo.py     | 11 ++++++++-
 4 files changed, 56 insertions(+), 23 deletions(-)

diff --git a/src/cli_demo.py b/src/cli_demo.py
index 2ded2682..416d6600 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -21,8 +21,14 @@ def main():
     model = model.cuda()
     model.eval()
 
+    def format_example(query):
+        prompt = "Below is an instruction that describes a task. "
+        prompt += "Write a response that appropriately completes the request.\n"
+        prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
+        return prompt
+
     def predict(query, history: list):
-        inputs = tokenizer([query], return_tensors="pt")
+        inputs = tokenizer([format_example(query)], return_tensors="pt")
         inputs = inputs.to(model.device)
         gen_kwargs = {
             "do_sample": True,
diff --git a/src/utils/common.py b/src/utils/common.py
index a7b14471..7038d522 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -2,6 +2,7 @@ import os
 import sys
 import torch
 import hashlib
+from itertools import chain
 from typing import List, Literal, Optional, Tuple
 
 import transformers
@@ -84,6 +85,8 @@ def init_adapter(
                 param.data = param.data.to(torch.float32)
 
     if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
+        if len(model_args.checkpoint_dir) > 1:
+            logger.warning("Only LoRA tuning accepts multiple checkpoints.")
         load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
 
     if finetuning_args.finetuning_type == "lora":
@@ -117,6 +120,9 @@
         )
         model = get_peft_model(model, lora_config)
 
+    if model_args.checkpoint_dir is not None:
+        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
+
     return model
 
 
@@ -131,19 +137,14 @@ def load_pretrained(
 
     Support both training and inference.
""" - - if (not is_trainable) and (model_args.checkpoint_dir is None): - logger.warning("Checkpoint is not found at evaluation, load the original model.") - finetuning_args = FinetuningArguments(finetuning_type="none") - - if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint - for checkpoint_dir in model_args.checkpoint_dir: - if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)): - raise ValueError("The fine-tuning arguments are not found in the provided dictionary.") - logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) - finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)) - if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1: - logger.warning("Only LoRA tuning accepts multiple checkpoints.") + if finetuning_args is None: # load the fine-tuning arguments + if model_args.checkpoint_dir is None: + logger.warning("Checkpoint is not found at evaluation, load the original model.") + finetuning_args = FinetuningArguments(finetuning_type="none") + elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)): + finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)) + else: + raise ValueError("Missing fine-tuning arguments in the provided dictionary.") assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method." @@ -350,7 +351,7 @@ def preprocess_data( if examples["prompt"][i] and examples["response"][i]: query, answer = examples["prompt"][i], examples["response"][i] if examples["query"][i]: - query += examples["query"][i] + query += "\n" + examples["query"][i] prompt = "Below is an instruction that describes a task. " prompt += "Write a response that appropriately completes the request.\n" prompt += "Instruction:\n" + prefix @@ -361,6 +362,20 @@ def preprocess_data( prompt += "Human: {}\nAssistant: ".format(query) yield prompt, answer + def preprocess_pretrain_dataset(examples): + # build grouped texts with format `??` + text_ids = tokenizer(examples["prompt"])["input_ids"] + concatenated_ids = list(chain(*text_ids)) + total_length = len(concatenated_ids) + # we drop the small remainder, and if the total_length < block_size, we exclude this batch + total_length = (total_length // data_args.max_source_length) * data_args.max_source_length + # split by chunks of max_source_length + result = [concatenated_ids[i: i+data_args.max_source_length] for i in range(0, total_length, data_args.max_source_length)] + return { + "input_ids": result, + "labels": result.copy() + } + def preprocess_supervised_dataset(examples): # build inputs with format `X Y ` and labels with format ` ... 
Y ` model_inputs = {"input_ids": [], "labels": []} @@ -425,7 +440,9 @@ def preprocess_data( print("input_ids:\n{}".format(example["input_ids"])) print("inputs:\n{}".format(tokenizer.decode(example["input_ids"]))) print("label_ids:\n{}".format(example["labels"])) - print("labels:\n{}".format(tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]]))) + print("labels:\n{}".format( + tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]])) + ) def print_pairwise_dataset_example(example): print("accept_ids:\n{}".format(example["accept_ids"])) @@ -437,11 +454,11 @@ def preprocess_data( print("input_ids:\n{}".format(example["input_ids"])) print("inputs:\n{}".format(tokenizer.decode(example["input_ids"]))) - if stage == "sft": - if (not training_args.do_train) and training_args.predict_with_generate: # with generation - preprocess_function = preprocess_evaluation_dataset - else: # without generation - preprocess_function = preprocess_supervised_dataset + if stage == "pt": + preprocess_function = preprocess_pretrain_dataset + elif stage == "sft": + preprocess_function = preprocess_evaluation_dataset \ + if training_args.predict_with_generate else preprocess_supervised_dataset elif stage == "rm": preprocess_function = preprocess_pairwise_dataset elif stage == "ppo": diff --git a/src/utils/config.py b/src/utils/config.py index b7012a66..579905a6 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -194,7 +194,8 @@ class FinetuningArguments: if self.name_module_trainable == "mlp": self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids] elif self.name_module_trainable == "qkv": - self.trainable_layers = ["layers.{:d}.attention.query_key_value".format(idx) for idx in trainable_layer_ids] + self.trainable_layers = ["layers.{:d}.self_attn.{}".format(idx, proj) \ + for proj in ["k_proj", "q_proj", "v_proj", "o_proj"] for idx in trainable_layer_ids] assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method." diff --git a/src/web_demo.py b/src/web_demo.py index 8801ee18..5bbed0b1 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -9,8 +9,10 @@ import gradio as gr from utils import ModelArguments, auto_configure_device_map, load_pretrained from transformers import HfArgumentParser +from transformers.utils.versions import require_version +require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems parser = HfArgumentParser(ModelArguments) model_args, = parser.parse_args_into_dataclasses() model, tokenizer = load_pretrained(model_args) @@ -71,10 +73,17 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT return text +def format_example(query): + prompt = "Below is an instruction that describes a task. " + prompt += "Write a response that appropriately completes the request.\n" + prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query) + return prompt + + def predict(input, chatbot, max_length, top_p, temperature, history): chatbot.append((parse_text(input), "")) - inputs = tokenizer([input], return_tensors="pt") + inputs = tokenizer([format_example(input)], return_tensors="pt") inputs = inputs.to(model.device) gen_kwargs = { "do_sample": True,