diff --git a/README.md b/README.md index 010ade70..b14d4a12 100644 --- a/README.md +++ b/README.md @@ -158,7 +158,7 @@ huggingface-cli login - Python 3.8+ and PyTorch 1.13.1+ - 🤗Transformers, Datasets, Accelerate, PEFT and TRL - sentencepiece, protobuf and tiktoken -- fire, jieba, rouge-chinese and nltk (used at evaluation and predict) +- jieba, rouge-chinese and nltk (used at evaluation and predict) - gradio and matplotlib (used in web UI) - uvicorn, fastapi and sse-starlette (used in API) diff --git a/README_zh.md b/README_zh.md index d8c5fe1c..86b87b62 100644 --- a/README_zh.md +++ b/README_zh.md @@ -158,7 +158,7 @@ huggingface-cli login - Python 3.8+ 和 PyTorch 1.13.1+ - 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL - sentencepiece, protobuf 和 tiktoken -- fire, jieba, rouge-chinese 和 nltk (用于评估及预测) +- jieba, rouge-chinese 和 nltk (用于评估及预测) - gradio 和 matplotlib (用于网页端交互) - uvicorn, fastapi 和 sse-starlette (用于 API) diff --git a/requirements.txt b/requirements.txt index 1c1bd308..ca876344 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,13 +3,12 @@ transformers>=4.31.0,<4.35.0 datasets>=2.14.0 accelerate>=0.21.0 peft>=0.6.0 -trl==0.7.2 +trl>=0.7.4 gradio>=3.38.0,<4.0.0 scipy sentencepiece protobuf tiktoken -fire jieba rouge-chinese nltk diff --git a/src/evaluate.py b/src/evaluate.py index 8af8c12c..13796c0c 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -1,190 +1,10 @@ -# coding=utf-8 -# Evaluates the performance of pre-trained models. -# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla -# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result -# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py - -import os -import fire -import json -import torch -import numpy as np -import transformers -from collections import Counter -from datasets import load_dataset -from dataclasses import dataclass -from tqdm import tqdm, trange -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple - -from llmtuner import ChatModel - -if TYPE_CHECKING: - from datasets import Dataset +from llmtuner import Evaluator -choices = ["A", "B", "C", "D"] - - -@dataclass -class EvalTemplate: - - system: str - choice: str - answer: str - prefix: str - - def parse_example( - self, - example: Dict[str, str] - ) -> Tuple[str, str]: - candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in choices if ch in example] - return "".join([example["question"]] + candidates + [self.answer]), example["answer"] - - def format_example( - self, - target_data: Dict[str, str], - support_set: "Dataset", - subject_name: str, - use_history: bool - ) -> Tuple[str, str, List[Tuple[str, str]]]: - query, resp = self.parse_example(target_data) - history = [self.parse_example(support_set[k]) for k in range(len(support_set))] - - if len(history): - temp = history.pop(0) - history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1])) - else: - query = self.system.format(subject=subject_name) + query - - if not use_history: - query = "\n\n".join(["".join(item) for item in history] + [query]) - history = [] - return query.strip(), resp, history - - -eval_templates = { - "en": EvalTemplate( - system="The following are multiple choice questions (with answers) about {subject}.\n\n", - choice="\n{choice}. {content}", - answer="\nAnswer: ", - prefix=" " - ), - "zh": EvalTemplate( - system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", - choice="\n{choice}. 
{content}", - answer="\n答案:", - prefix="\n" - ) -} - - -@torch.inference_mode() -def batch_inference( - chat_model: ChatModel, - batch_input: Dict[str, torch.Tensor], - prefix_char: str -) -> List[str]: - logits = chat_model.model(**batch_input).logits - lengths = torch.sum(batch_input["attention_mask"], dim=-1) - nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0) - probs = torch.nn.functional.softmax( - torch.stack( - [ - nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]] - for choice in choices - ], - dim=-1 - ), - dim=-1 - ).detach() - return [chr(ord("A") + offset.item()) for offset in torch.argmax(probs, dim=-1)] - - -def evaluate( - model_name_or_path: str, - finetuning_type: Optional[str] = "lora", - checkpoint_dir: Optional[str] = None, - template: Optional[str] = "vanilla", - task: Optional[str] = "ceval", - dataset_dir: Optional[str] = "evaluation", - split: Optional[Literal["validation", "test"]] = "validation", - lang: Optional[Literal["zh", "en"]] = "zh", - n_shot: Optional[int] = 5, - n_avg: Optional[int] = 1, - batch_size: Optional[int] = 4, - save_name: Optional[str] = None, - seed: Optional[int] = 42 -): - with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f: - categorys: Dict[str, Dict[str, str]] = json.load(f) - - transformers.set_seed(seed) - chat_model = ChatModel(dict( - model_name_or_path=model_name_or_path, - finetuning_type=finetuning_type, - checkpoint_dir=checkpoint_dir, - template=template - )) - chat_model.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2 - eval_template = eval_templates[lang] - - category_corrects: Dict[str, np.ndarray] = { - subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"] - } - pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0) - results = {} - for subject in pbar: - dataset = load_dataset(os.path.join(dataset_dir, task), subject) - labels, answers, all_outputs = [], [], [] - for epoch in range(n_avg): - pbar.set_postfix_str("{} Trial: {}".format(categorys[subject]["name"], epoch)) - inputs, outputs = [], [] - for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False): - support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"])))) - query, resp, history = eval_template.format_example( - target_data=dataset[split][i], - support_set=support_set, - subject_name=categorys[subject]["name"], - use_history=chat_model.template.use_history - ) - input_ids, _ = chat_model.template.encode_oneturn( - tokenizer=chat_model.tokenizer, query=query, resp=resp, history=history - ) - inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}) - if epoch == 0: - labels.append(resp) - - for i in trange(0, len(inputs), batch_size, desc="Predicting batches", position=1, leave=False): - batch_input = chat_model.tokenizer.pad( - inputs[i : i + batch_size], return_attention_mask=True, return_tensors="pt" - ).to(chat_model.model.device) - preds = batch_inference(chat_model, batch_input, eval_template.prefix) - outputs += preds - all_outputs.append(outputs) - - for i in range(len(all_outputs[0])): - count = Counter([all_outputs[epoch][i] for epoch in range(n_avg)]) - answers.append(count.most_common(1)[0][0]) - - corrects = (np.array(answers) == np.array(labels)) - category_name = categorys[subject]["category"] - category_corrects[category_name] = 
np.concatenate([category_corrects[category_name], corrects], axis=0) - category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0) - results[subject] = {str(i): answers[i] for i in range(len(answers))} - - score_info = "\n".join([ - "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct)) - for category_name, category_correct in category_corrects.items() if len(category_correct) - ]) - - print(score_info) - if save_name is not None: - with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f: - json.dump(results, f, indent=2) - - with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f: - f.write(score_info) +def main(): + evaluator = Evaluator() + evaluator.eval() if __name__ == "__main__": - fire.Fire(evaluate) + main() diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index 93858cc5..fde80959 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -1,7 +1,8 @@ -# Level: api, webui > chat > tuner > dsets > extras, hparams +# Level: api, webui > chat, eval > tuner > dsets > extras, hparams from llmtuner.api import create_app from llmtuner.chat import ChatModel +from llmtuner.eval import Evaluator from llmtuner.tuner import export_model, run_exp from llmtuner.webui import create_ui, create_web_demo diff --git a/src/llmtuner/eval/__init__.py b/src/llmtuner/eval/__init__.py new file mode 100644 index 00000000..10584817 --- /dev/null +++ b/src/llmtuner/eval/__init__.py @@ -0,0 +1 @@ +from llmtuner.eval.engine import Evaluator diff --git a/src/llmtuner/eval/constants.py b/src/llmtuner/eval/constants.py new file mode 100644 index 00000000..433ad39b --- /dev/null +++ b/src/llmtuner/eval/constants.py @@ -0,0 +1,3 @@ +CHOICES = ["A", "B", "C", "D"] + +SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] diff --git a/src/llmtuner/eval/engine.py b/src/llmtuner/eval/engine.py new file mode 100644 index 00000000..bfb5e17e --- /dev/null +++ b/src/llmtuner/eval/engine.py @@ -0,0 +1,110 @@ +# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py + +import os +import json +import torch +import tiktoken +import numpy as np +from tqdm import tqdm, trange +from datasets import load_dataset +from typing import Any, Dict, List, Optional + +from llmtuner.eval.constants import CHOICES, SUBJECTS +from llmtuner.eval.parser import get_eval_args +from llmtuner.eval.template import get_eval_template +from llmtuner.extras.misc import dispatch_model +from llmtuner.extras.template import get_template_and_fix_tokenizer +from llmtuner.tuner.core import load_model_and_tokenizer + + +class Evaluator: + + def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: + model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args) + self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) + self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2 + self.model = dispatch_model(self.model) + self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer) + self.eval_template = get_eval_template(self.eval_args.lang) + self.choice_inputs = self._encode_choices() + + def _encode_choices(self) -> List[int]: + if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) + kwargs = dict(allowed_special="all") + else: + kwargs = dict(add_special_tokens=False) + + return [self.tokenizer.encode(self.eval_template.prefix + ch, 
**kwargs)[-1] for ch in CHOICES] + + @torch.inference_mode() + def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: + logits = self.model(**batch_input).logits + lengths = torch.sum(batch_input["attention_mask"], dim=-1) + word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0) + choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach() + return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)] + + def eval(self) -> None: + mapping = os.path.join(self.eval_args.task_dir, self.eval_args.task, "mapping.json") + with open(mapping, "r", encoding="utf-8") as f: + categorys: Dict[str, Dict[str, str]] = json.load(f) + + category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS} + pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0) + results = {} + for subject in pbar: + dataset = load_dataset( + path=os.path.join(self.eval_args.task_dir, self.eval_args.task), + name=subject, + download_mode="force_redownload" + ) + pbar.set_postfix_str(categorys[subject]["name"]) + inputs, outputs, labels = [], [], [] + for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False): + support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"])))) + query, resp, history = self.eval_template.format_example( + target_data=dataset[self.data_args.split][i], + support_set=support_set, + subject_name=categorys[subject]["name"], + use_history=self.template.use_history + ) + input_ids, _ = self.template.encode_oneturn( + tokenizer=self.tokenizer, query=query, resp=resp, history=history + ) + inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}) + labels.append(resp) + + for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False): + batch_input = self.tokenizer.pad( + inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt" + ).to(self.model.device) + preds = self.batch_inference(batch_input) + outputs += preds + + corrects = (np.array(outputs) == np.array(labels)) + category_name = categorys[subject]["category"] + category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0) + category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0) + results[subject] = {str(i): outputs[i] for i in range(len(outputs))} + + pbar.close() + self._save_results(category_corrects, results) + + def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None: + score_info = "\n".join([ + "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct)) + for category_name, category_correct in category_corrects.items() if len(category_correct) + ]) + print(score_info) + if self.eval_args.save_dir is not None: + os.makedirs(self.eval_args.save_dir, exist_ok=False) + with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f: + json.dump(results, f, indent=2) + + with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f: + f.write(score_info) + + +if __name__ == "__main__": + evaluator = Evaluator() + evaluator.eval() diff --git a/src/llmtuner/eval/parser.py b/src/llmtuner/eval/parser.py new file mode 100644 index 00000000..cef38048 --- /dev/null +++ 
b/src/llmtuner/eval/parser.py @@ -0,0 +1,49 @@ +import transformers +from typing import Any, Dict, Optional, Tuple +from transformers import HfArgumentParser + +from llmtuner.extras.misc import parse_args +from llmtuner.hparams import ( + ModelArguments, + DataArguments, + EvaluationArguments, + FinetuningArguments +) + + +def parse_eval_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ + ModelArguments, + DataArguments, + EvaluationArguments, + FinetuningArguments +]: + parser = HfArgumentParser(( + ModelArguments, + DataArguments, + EvaluationArguments, + FinetuningArguments + )) + return parse_args(parser, args) + + +def get_eval_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ + ModelArguments, + DataArguments, + EvaluationArguments, + FinetuningArguments +]: + model_args, data_args, eval_args, finetuning_args = parse_eval_args(args) + + if data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": + raise ValueError("Quantization is only compatible with the LoRA method.") + + transformers.set_seed(eval_args.seed) + + return model_args, data_args, eval_args, finetuning_args diff --git a/src/llmtuner/eval/template.py b/src/llmtuner/eval/template.py new file mode 100644 index 00000000..44cb3c6d --- /dev/null +++ b/src/llmtuner/eval/template.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Tuple + +from llmtuner.eval.constants import CHOICES + +if TYPE_CHECKING: + from datasets import Dataset + + +@dataclass +class EvalTemplate: + + system: str + choice: str + answer: str + prefix: str + + def parse_example( + self, + example: Dict[str, str] + ) -> Tuple[str, str]: + candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example] + return "".join([example["question"]] + candidates + [self.answer]), example["answer"] + + def format_example( + self, + target_data: Dict[str, str], + support_set: "Dataset", + subject_name: str, + use_history: bool + ) -> Tuple[str, str, List[Tuple[str, str]]]: + query, resp = self.parse_example(target_data) + history = [self.parse_example(support_set[k]) for k in range(len(support_set))] + + if len(history): + temp = history.pop(0) + history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1])) + else: + query = self.system.format(subject=subject_name) + query + + if not use_history: + query = "\n\n".join(["".join(item) for item in history] + [query]) + history = [] + return query.strip(), resp, history + + +eval_templates: Dict[str, EvalTemplate] = {} + + +def register_eval_template( + name: str, + system: str, + choice: str, + answer: str, + prefix: str +) -> None: + eval_templates[name] = EvalTemplate( + system=system, + choice=choice, + answer=answer, + prefix=prefix + ) + + +def get_eval_template(name: str) -> EvalTemplate: + eval_template = eval_templates.get(name, None) + assert eval_template is not None, "Template {} does not exist.".format(name) + return eval_template + + +register_eval_template( + name="en", + system="The following are multiple choice questions (with answers) about {subject}.\n\n", + choice="\n{choice}. {content}", + answer="\nAnswer: ", + prefix=" " +) + + +register_eval_template( + name="zh", + system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", + choice="\n{choice}. 
{content}", + answer="\n答案:", + prefix="\n" +) diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 960d43ee..6300bc75 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -1,6 +1,8 @@ import gc +import os +import sys import torch -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList try: @@ -17,6 +19,7 @@ except ImportError: _is_bf16_available = torch.cuda.is_bf16_supported() if TYPE_CHECKING: + from transformers import HfArgumentParser from transformers.modeling_utils import PreTrainedModel @@ -74,7 +77,7 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: return torch.float32 -def get_logits_processor() -> LogitsProcessorList: +def get_logits_processor() -> "LogitsProcessorList": r""" Gets logits processor that removes NaN and Inf logits. """ @@ -93,6 +96,17 @@ def torch_gc() -> None: torch.cuda.ipc_collect() +def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: + if args is not None: + return parser.parse_dict(args) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + return parser.parse_json_file(os.path.abspath(sys.argv[1])) + else: + return parser.parse_args_into_dataclasses() + + def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": r""" Dispatches a pre-trained model to GPUs with balanced memory. diff --git a/src/llmtuner/hparams/__init__.py b/src/llmtuner/hparams/__init__.py index f0547cc5..623d6517 100644 --- a/src/llmtuner/hparams/__init__.py +++ b/src/llmtuner/hparams/__init__.py @@ -1,4 +1,5 @@ from .data_args import DataArguments +from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 4c67dd65..fb8a0abc 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -42,7 +42,7 @@ class DataArguments: ) dataset_dir: Optional[str] = field( default="data", - metadata={"help": "The name of the folder containing datasets."} + metadata={"help": "Path to the folder containing the datasets."} ) split: Optional[str] = field( default="train", diff --git a/src/llmtuner/hparams/evaluation_args.py b/src/llmtuner/hparams/evaluation_args.py new file mode 100644 index 00000000..5f507698 --- /dev/null +++ b/src/llmtuner/hparams/evaluation_args.py @@ -0,0 +1,55 @@ +import os +from typing import Literal, Optional +from dataclasses import dataclass, field + +from datasets import DownloadMode + + +@dataclass +class EvaluationArguments: + r""" + Arguments pertaining to specify the evaluation parameters. 
+ """ + task: str = field( + metadata={"help": "Name of the evaluation task."} + ) + task_dir: Optional[str] = field( + default="evaluation", + metadata={"help": "Path to the folder containing the evaluation datasets."} + ) + batch_size: Optional[int] = field( + default=4, + metadata={"help": "The batch size per GPU for evaluation."} + ) + seed: Optional[int] = field( + default=42, + metadata={"help": "Random seed to be used with data loaders."} + ) + lang: Optional[Literal["en", "zh"]] = field( + default="en", + metadata={"help": "Language used at evaluation."} + ) + n_shot: Optional[int] = field( + default=5, + metadata={"help": "Number of examplars for few-shot learning."} + ) + save_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to save the evaluation results."} + ) + download_mode: Optional[DownloadMode] = field( + default=DownloadMode.REUSE_DATASET_IF_EXISTS, + metadata={"help": "Download mode used for the evaluation datasets."} + ) + + def __post_init__(self): + task_available = [] + for folder in os.listdir(self.task_dir): + if os.path.isdir(os.path.join(self.task_dir, folder)): + task_available.append(folder) + + if self.task not in task_available: + raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir)) + + if self.save_dir is not None and os.path.exists(self.save_dir): + raise ValueError("`save_dir` already exists, use another one.") diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index 28a69c8d..eb949626 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -12,7 +12,7 @@ class FinetuningArguments: default="sft", metadata={"help": "Which stage will be performed in training."} ) - finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field( + finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field( default="lora", metadata={"help": "Which fine-tuning method to use."} ) diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py index 8a771567..d3799f24 100644 --- a/src/llmtuner/tuner/core/adapter.py +++ b/src/llmtuner/tuner/core/adapter.py @@ -38,12 +38,13 @@ def init_adapter( if (not is_trainable) and model_args.checkpoint_dir is None: logger.info("Checkpoint is not found at evaluation, load the original model.") + return model if finetuning_args.finetuning_type == "full" and is_trainable: logger.info("Fine-tuning method: Full") model = model.float() - if finetuning_args.finetuning_type == "freeze": + if finetuning_args.finetuning_type == "freeze" and is_trainable: logger.info("Fine-tuning method: Freeze") num_layers = getattr(model.config, "num_layers") if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index 34bc2a6e..3a51a698 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -42,7 +42,7 @@ require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transform require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0") -require_version("trl==0.7.2", "To fix: pip install trl==0.7.2") +require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4") def load_model_and_tokenizer( diff --git a/src/llmtuner/tuner/core/parser.py 
b/src/llmtuner/tuner/core/parser.py index 0ec59127..04fc884b 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -1,5 +1,4 @@ import os -import sys import torch import datasets import transformers @@ -8,6 +7,7 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers.trainer_utils import get_last_checkpoint from llmtuner.extras.logging import get_logger +from llmtuner.extras.misc import parse_args from llmtuner.hparams import ( ModelArguments, DataArguments, @@ -19,17 +19,6 @@ from llmtuner.hparams import ( logger = get_logger(__name__) -def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: - if args is not None: - return parser.parse_dict(args) - elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): - return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) - elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - return parser.parse_json_file(os.path.abspath(sys.argv[1])) - else: - return parser.parse_args_into_dataclasses() - - def parse_train_args( args: Optional[Dict[str, Any]] = None ) -> Tuple[ @@ -46,7 +35,7 @@ def parse_train_args( FinetuningArguments, GeneratingArguments )) - return _parse_args(parser, args) + return parse_args(parser, args) def parse_infer_args( @@ -63,7 +52,7 @@ def parse_infer_args( FinetuningArguments, GeneratingArguments )) - return _parse_args(parser, args) + return parse_args(parser, args) def get_train_args( diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py index 75fc4e25..c2b0b581 100644 --- a/src/llmtuner/tuner/dpo/trainer.py +++ b/src/llmtuner/tuner/dpo/trainer.py @@ -1,5 +1,4 @@ import torch -from copy import deepcopy from collections import defaultdict from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union from transformers import BatchEncoding, Trainer @@ -10,7 +9,6 @@ from llmtuner.extras.constants import IGNORE_INDEX if TYPE_CHECKING: from transformers import PreTrainedModel - from trl import PreTrainedModelWrapper class CustomDPOTrainer(DPOTrainer): @@ -49,39 +47,6 @@ class CustomDPOTrainer(DPOTrainer): else: self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - def _prepare_deepspeed(self, model: "PreTrainedModelWrapper"): - # adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) - if model is not None: - if hasattr(model, "config"): - hidden_size = ( - max(model.config.hidden_sizes) - if getattr(model.config, "hidden_sizes", None) - else getattr(model.config, "hidden_size", None) - ) - if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: - # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` - # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 - config_kwargs.update( - { - "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, - "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, - "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, - } - ) - - # If ZeRO-3 is used, we shard both the active and reference model. 
- # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) - if config_kwargs["zero_optimization"]["stage"] != 3: - config_kwargs["zero_optimization"]["stage"] = 0 - - # Lazy load - import deepspeed # type: ignore - model, *_ = deepspeed.initialize(model=model, config=config_kwargs) - model.eval() - return model - def concatenated_forward( self, model: Optional[torch.nn.Module] = None, diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 08e1175a..f9626b1d 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -226,7 +226,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): replace_model(unwrapped_model, target="default") return rewards - @PPODecorators.empty_cuda_cache() + @PPODecorators.empty_device_cache() def batched_forward_pass( self, model: "AutoModelForCausalLMWithValueHead", diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py index 4c35f628..9e5a5979 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -42,7 +42,7 @@ def run_ppo( ppo_epochs=1, max_grad_norm=training_args.max_grad_norm, seed=training_args.seed, - optimize_cuda_cache=True, + optimize_device_cache=True, target=finetuning_args.ppo_target, log_with=finetuning_args.ppo_logger, use_score_scaling=finetuning_args.ppo_score_norm,
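Note: a minimal usage sketch of the refactored evaluation entry point introduced in this patch. The argument names are taken from the ModelArguments, DataArguments and EvaluationArguments dataclasses shown above; the model path and save_dir values are placeholders, and the task folder is assumed to exist under the default task_dir ("evaluation"). Passing a dict routes through the new parse_args() helper to HfArgumentParser.parse_dict().

# Programmatic use of the new Evaluator (a sketch, not part of this diff).
from llmtuner import Evaluator

evaluator = Evaluator(dict(
    model_name_or_path="path_to_model",  # placeholder
    template="vanilla",
    task="ceval",             # must be a subfolder of `task_dir` (default: "evaluation")
    split="validation",
    lang="zh",
    n_shot=5,
    batch_size=4,
    save_dir="ceval_results"  # must not exist yet; results.json and results.log are written here
))
evaluator.eval()

The command-line form is equivalent, since parse_args() falls back to HfArgumentParser.parse_args_into_dataclasses() when no dict is given, e.g. `python src/evaluate.py --model_name_or_path path_to_model --template vanilla --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_dir ceval_results`; the fire dependency dropped from requirements.txt is no longer needed for this.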