From f75137661358f9070bc70c341dfa2cc5fd69cf94 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sat, 15 Jul 2023 16:54:28 +0800 Subject: [PATCH] modify code structure --- README.md | 19 +- pyproject.toml | 3 + requirements.txt | 5 +- setup.py | 55 ++ src/api_demo.py | 218 +----- src/cli_demo.py | 14 +- src/export_model.py | 8 +- src/llmtuner/__init__.py | 7 + src/llmtuner/api/__init__.py | 1 + src/llmtuner/api/app.py | 152 +++++ src/llmtuner/api/protocol.py | 73 +++ src/llmtuner/dsets/__init__.py | 2 + src/llmtuner/dsets/callbacks.py | 63 ++ src/llmtuner/dsets/loader.py | 106 +++ src/llmtuner/dsets/preprocess.py | 172 +++++ src/{ => llmtuner/extras}/__init__.py | 0 src/llmtuner/extras/callbacks.py | 72 ++ src/llmtuner/extras/constants.py | 7 + src/llmtuner/extras/logging.py | 18 + src/llmtuner/extras/misc.py | 105 +++ src/llmtuner/extras/ploting.py | 50 ++ src/llmtuner/extras/save_and_load.py | 49 ++ src/{utils => llmtuner/extras}/template.py | 0 src/llmtuner/hparams/__init__.py | 5 + src/llmtuner/hparams/data_args.py | 119 ++++ src/llmtuner/hparams/finetuning_args.py | 78 +++ src/llmtuner/hparams/general_args.py | 13 + src/llmtuner/hparams/generating_args.py | 51 ++ src/llmtuner/hparams/model_args.py | 72 ++ src/llmtuner/tuner/__init__.py | 5 + src/llmtuner/tuner/core/__init__.py | 2 + src/llmtuner/tuner/core/adapter.py | 94 +++ src/llmtuner/tuner/core/loader.py | 151 +++++ src/llmtuner/tuner/core/parser.py | 134 ++++ .../tuner/core/trainer.py} | 66 +- src/llmtuner/tuner/ppo/__init__.py | 1 + .../ppo.py => llmtuner/tuner/ppo/trainer.py} | 85 +-- src/llmtuner/tuner/ppo/utils.py | 37 ++ .../tuner/ppo/workflow.py} | 57 +- src/llmtuner/tuner/pt/__init__.py | 1 + .../tuner/pt/workflow.py} | 56 +- src/llmtuner/tuner/rm/__init__.py | 1 + src/llmtuner/tuner/rm/collator.py | 19 + src/llmtuner/tuner/rm/metric.py | 7 + src/llmtuner/tuner/rm/trainer.py | 38 ++ .../tuner/rm/workflow.py} | 49 +- src/llmtuner/tuner/sft/__init__.py | 1 + src/llmtuner/tuner/sft/metric.py | 51 ++ .../tuner/sft/trainer.py} | 65 +- .../tuner/sft/workflow.py} | 61 +- src/train_bash.py | 23 + src/utils/__init__.py | 17 - src/utils/common.py | 619 ------------------ src/utils/config.py | 312 --------- src/utils/other.py | 197 ------ src/utils/pairwise.py | 60 -- src/web_demo.py | 69 +- 57 files changed, 1999 insertions(+), 1816 deletions(-) create mode 100644 pyproject.toml create mode 100644 setup.py create mode 100644 src/llmtuner/__init__.py create mode 100644 src/llmtuner/api/__init__.py create mode 100644 src/llmtuner/api/app.py create mode 100644 src/llmtuner/api/protocol.py create mode 100644 src/llmtuner/dsets/__init__.py create mode 100644 src/llmtuner/dsets/callbacks.py create mode 100644 src/llmtuner/dsets/loader.py create mode 100644 src/llmtuner/dsets/preprocess.py rename src/{ => llmtuner/extras}/__init__.py (100%) create mode 100644 src/llmtuner/extras/callbacks.py create mode 100644 src/llmtuner/extras/constants.py create mode 100644 src/llmtuner/extras/logging.py create mode 100644 src/llmtuner/extras/misc.py create mode 100644 src/llmtuner/extras/ploting.py create mode 100644 src/llmtuner/extras/save_and_load.py rename src/{utils => llmtuner/extras}/template.py (100%) create mode 100644 src/llmtuner/hparams/__init__.py create mode 100644 src/llmtuner/hparams/data_args.py create mode 100644 src/llmtuner/hparams/finetuning_args.py create mode 100644 src/llmtuner/hparams/general_args.py create mode 100644 src/llmtuner/hparams/generating_args.py create mode 100644 src/llmtuner/hparams/model_args.py create mode 100644
src/llmtuner/tuner/__init__.py create mode 100644 src/llmtuner/tuner/core/__init__.py create mode 100644 src/llmtuner/tuner/core/adapter.py create mode 100644 src/llmtuner/tuner/core/loader.py create mode 100644 src/llmtuner/tuner/core/parser.py rename src/{utils/peft_trainer.py => llmtuner/tuner/core/trainer.py} (60%) create mode 100644 src/llmtuner/tuner/ppo/__init__.py rename src/{utils/ppo.py => llmtuner/tuner/ppo/trainer.py} (76%) create mode 100644 src/llmtuner/tuner/ppo/utils.py rename src/{train_ppo.py => llmtuner/tuner/ppo/workflow.py} (66%) create mode 100644 src/llmtuner/tuner/pt/__init__.py rename src/{train_pt.py => llmtuner/tuner/pt/workflow.py} (59%) create mode 100644 src/llmtuner/tuner/rm/__init__.py create mode 100644 src/llmtuner/tuner/rm/collator.py create mode 100644 src/llmtuner/tuner/rm/metric.py create mode 100644 src/llmtuner/tuner/rm/trainer.py rename src/{train_rm.py => llmtuner/tuner/rm/workflow.py} (64%) create mode 100644 src/llmtuner/tuner/sft/__init__.py create mode 100644 src/llmtuner/tuner/sft/metric.py rename src/{utils/seq2seq.py => llmtuner/tuner/sft/trainer.py} (51%) rename src/{train_sft.py => llmtuner/tuner/sft/workflow.py} (69%) create mode 100644 src/train_bash.py delete mode 100644 src/utils/__init__.py delete mode 100644 src/utils/common.py delete mode 100644 src/utils/config.py delete mode 100644 src/utils/other.py delete mode 100644 src/utils/pairwise.py diff --git a/README.md b/README.md index ec4bfe71..db0596e5 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ huggingface-cli login - Python 3.8+ and PyTorch 1.13.1+ - 🤗Transformers, Datasets, Accelerate, PEFT and TRL - jieba, rouge-chinese and nltk (used at evaluation) -- gradio and mdtex2html (used in web_demo.py) +- gradio and matplotlib (used in web_demo.py) - uvicorn, fastapi and sse-starlette (used in api_demo.py) And **powerful GPUs**! @@ -137,7 +137,8 @@ python -m transformers.models.llama.convert_llama_weights_to_hf \ ### (Continually) Pre-Training ```bash -CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \ +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage pt \ --model_name_or_path path_to_your_model \ --do_train \ --dataset wiki_demo \ @@ -158,7 +159,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \ ### Supervised Fine-Tuning ```bash -CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \ +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage sft \ --model_name_or_path path_to_your_model \ --do_train \ --dataset alpaca_gpt4_en \ @@ -179,7 +181,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \ ### Reward Model Training ```bash -CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \ +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage rm \ --model_name_or_path path_to_your_model \ --do_train \ --dataset comparison_gpt4_en \ @@ -199,7 +202,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \ ### PPO Training (RLHF) ```bash -CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \ +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage ppo \ --model_name_or_path path_to_your_model \ --do_train \ --dataset alpaca_gpt4_en \ @@ -222,7 +226,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \ ```bash accelerate config # configure the environment -accelerate launch src/train_XX.py # arguments (same as above) +accelerate launch src/train_bash.py # arguments (same as above) ```
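For example, a two-GPU run of the supervised fine-tuning stage could combine the two commands above as follows (a sketch: the `--num_processes` value and the trailing flags are illustrative assumptions; reuse the arguments from the single-GPU examples):

```bash
accelerate launch --num_processes 2 src/train_bash.py \
    --stage sft \
    --model_name_or_path path_to_your_model \
    --do_train \
    --dataset alpaca_gpt4_en \
    --output_dir path_to_sft_checkpoint
```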
Example configuration for full-tuning with DeepSpeed ZeRO-2 @@ -256,7 +260,8 @@ use_cpu: false ### Evaluation (BLEU and ROUGE_CHINESE) ```bash -CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \ +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage sft \ --model_name_or_path path_to_your_model \ --do_eval \ --dataset alpaca_gpt4_en \ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..638dd9c5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index 1fa830c0..e3b0b577 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,8 +8,9 @@ sentencepiece jieba rouge-chinese nltk -gradio -mdtex2html +gradio>=3.36.0 uvicorn +pydantic==1.10.7 fastapi sse-starlette +matplotlib diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..930dabb2 --- /dev/null +++ b/setup.py @@ -0,0 +1,55 @@ +import os +import re +from setuptools import setup, find_packages + + +def get_version(): + with open(os.path.join("src", "llmtuner", "__init__.py"), "r", encoding="utf-8") as f: + file_content = f.read() + pattern = r"{0}\W*=\W*\"([^\"]+)\"".format("__version__") + version, = re.findall(pattern, file_content) + return version + + +def get_requires(): + with open("requirements.txt", "r", encoding="utf-8") as f: + file_content = f.read() + lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] + return lines + + +def main(): + + setup( + name="llmtuner", + version=get_version(), + author="hiyouga", + author_email="hiyouga" "@" "buaa.edu.cn", + description="Easy-to-use fine-tuning framework using PEFT", + long_description=open("README.md", "r", encoding="utf-8").read(), + long_description_content_type="text/markdown", + keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"], + license="Apache 2.0 License", + url="https://github.com/hiyouga/LLaMA-Efficient-Tuning", + package_dir={"": "src"}, + packages=find_packages("src"), + python_requires=">=3.8.0", + install_requires=get_requires(), + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ] + ) + + +if __name__ == "__main__": + main() diff --git a/src/api_demo.py b/src/api_demo.py index a0a82321..f27df455 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -4,225 +4,11 @@ # Visit http://localhost:8000/docs for document.
-import time -import torch import uvicorn -from threading import Thread -from pydantic import BaseModel, Field -from fastapi import FastAPI, HTTPException -from fastapi.middleware.cors import CORSMiddleware -from contextlib import asynccontextmanager -from transformers import TextIteratorStreamer -from sse_starlette import EventSourceResponse -from typing import Any, Dict, List, Literal, Optional -from utils import ( - Template, - load_pretrained, - prepare_infer_args, - get_logits_processor -) - - -@asynccontextmanager -async def lifespan(app: FastAPI): # collects GPU memory - yield - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - - -app = FastAPI(lifespan=lifespan) - - -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -class ModelCard(BaseModel): - id: str - object: Optional[str] = "model" - created: Optional[int] = Field(default_factory=lambda: int(time.time())) - owned_by: Optional[str] = "owner" - root: Optional[str] = None - parent: Optional[str] = None - permission: Optional[list] = [] - - -class ModelList(BaseModel): - object: Optional[str] = "list" - data: Optional[List[ModelCard]] = [] - - -class ChatMessage(BaseModel): - role: Literal["user", "assistant", "system"] - content: str - - -class DeltaMessage(BaseModel): - role: Optional[Literal["user", "assistant", "system"]] = None - content: Optional[str] = None - - -class ChatCompletionRequest(BaseModel): - model: str - messages: List[ChatMessage] - temperature: Optional[float] = None - top_p: Optional[float] = None - n: Optional[int] = 1 - max_tokens: Optional[int] = None - stream: Optional[bool] = False - - -class ChatCompletionResponseChoice(BaseModel): - index: int - message: ChatMessage - finish_reason: Literal["stop", "length"] - - -class ChatCompletionResponseStreamChoice(BaseModel): - index: int - delta: DeltaMessage - finish_reason: Optional[Literal["stop", "length"]] = None - - -class ChatCompletionResponseUsage(BaseModel): - prompt_tokens: int - completion_tokens: int - total_tokens: int - - -class ChatCompletionResponse(BaseModel): - id: Optional[str] = "chatcmpl-default" - object: Literal["chat.completion"] - created: Optional[int] = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[ChatCompletionResponseChoice] - usage: ChatCompletionResponseUsage - - -class ChatCompletionStreamResponse(BaseModel): - id: Optional[str] = "chatcmpl-default" - object: Literal["chat.completion.chunk"] - created: Optional[int] = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[ChatCompletionResponseStreamChoice] - - -@app.get("/v1/models", response_model=ModelList) -async def list_models(): - global model_args - model_card = ModelCard(id="gpt-3.5-turbo") - return ModelList(data=[model_card]) - - -@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) -async def create_chat_completion(request: ChatCompletionRequest): - global model, tokenizer, source_prefix, generating_args - - if request.messages[-1].role != "user": - raise HTTPException(status_code=400, detail="Invalid request") - query = request.messages[-1].content - - prev_messages = request.messages[:-1] - if len(prev_messages) > 0 and prev_messages[0].role == "system": - prefix = prev_messages.pop(0).content - else: - prefix = source_prefix - - history = [] - if len(prev_messages) % 2 == 0: - for i in range(0, len(prev_messages), 2): - if prev_messages[i].role == "user" and 
prev_messages[i+1].role == "assistant": - history.append([prev_messages[i].content, prev_messages[i+1].content]) - - inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt") - inputs = inputs.to(model.device) - - gen_kwargs = generating_args.to_dict() - gen_kwargs.update({ - "input_ids": inputs["input_ids"], - "temperature": request.temperature if request.temperature else gen_kwargs["temperature"], - "top_p": request.top_p if request.top_p else gen_kwargs["top_p"], - "logits_processor": get_logits_processor() - }) - - if request.max_tokens: - gen_kwargs.pop("max_length", None) - gen_kwargs["max_new_tokens"] = request.max_tokens - - if request.stream: - generate = predict(gen_kwargs, request.model) - return EventSourceResponse(generate, media_type="text/event-stream") - - generation_output = model.generate(**gen_kwargs) - outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs, skip_special_tokens=True) - - usage = ChatCompletionResponseUsage( - prompt_tokens=len(inputs["input_ids"][0]), - completion_tokens=len(outputs), - total_tokens=len(inputs["input_ids"][0]) + len(outputs) - ) - - choice_data = ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop" - ) - - return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion") - - -async def predict(gen_kwargs: Dict[str, Any], model_id: str): - global model, tokenizer - - streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) - gen_kwargs["streamer"] = streamer - - thread = Thread(target=model.generate, kwargs=gen_kwargs) - thread.start() - - choice_data = ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(role="assistant"), - finish_reason=None - ) - chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") - yield chunk.json(exclude_unset=True, ensure_ascii=False) - - for new_text in streamer: - if len(new_text) == 0: - continue - - choice_data = ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(content=new_text), - finish_reason=None - ) - chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") - yield chunk.json(exclude_unset=True, ensure_ascii=False) - - choice_data = ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(), - finish_reason="stop" - ) - chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") - yield chunk.json(exclude_unset=True, ensure_ascii=False) - yield "[DONE]" +from llmtuner import create_app if __name__ == "__main__": - model_args, data_args, finetuning_args, generating_args = prepare_infer_args() - model, tokenizer = load_pretrained(model_args, finetuning_args) - - prompt_template = Template(data_args.prompt_template) - source_prefix = data_args.source_prefix if data_args.source_prefix else "" - + app = create_app() uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) diff --git a/src/cli_demo.py b/src/cli_demo.py index 752b42c8..1a32e35c 100644 --- a/src/cli_demo.py +++ b/src/cli_demo.py @@ -2,21 +2,15 @@ # Implements stream chat in command line for fine-tuned models. 
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint - -from utils import ( - Template, - load_pretrained, - prepare_infer_args, - get_logits_processor -) from threading import Thread from transformers import TextIteratorStreamer +from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor + def main(): - - model_args, data_args, finetuning_args, generating_args = prepare_infer_args() - model, tokenizer = load_pretrained(model_args, finetuning_args) + model_args, data_args, finetuning_args, generating_args = get_infer_args() + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) prompt_template = Template(data_args.prompt_template) source_prefix = data_args.source_prefix if data_args.source_prefix else "" diff --git a/src/export_model.py b/src/export_model.py index 71985180..3c1ffbbb 100644 --- a/src/export_model.py +++ b/src/export_model.py @@ -2,14 +2,12 @@ # Exports the fine-tuned model. # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model - -from utils import load_pretrained, prepare_args +from llmtuner import get_train_args, load_model_and_tokenizer def main(): - - model_args, _, training_args, finetuning_args = prepare_args(stage="sft") - model, tokenizer = load_pretrained(model_args, finetuning_args) + model_args, _, training_args, finetuning_args, _ = get_train_args() + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) model.save_pretrained(training_args.output_dir, max_shard_size="10GB") tokenizer.save_pretrained(training_args.output_dir) print("model and tokenizer have been saved at:", training_args.output_dir) diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py new file mode 100644 index 00000000..bcbac9db --- /dev/null +++ b/src/llmtuner/__init__.py @@ -0,0 +1,7 @@ +from llmtuner.api import create_app +from llmtuner.extras.misc import get_logits_processor +from llmtuner.extras.template import Template +from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo + + +__version__ = "0.0.9" diff --git a/src/llmtuner/api/__init__.py b/src/llmtuner/api/__init__.py new file mode 100644 index 00000000..b3ce183a --- /dev/null +++ b/src/llmtuner/api/__init__.py @@ -0,0 +1 @@ +from llmtuner.api.app import create_app diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py new file mode 100644 index 00000000..3f31cb9a --- /dev/null +++ b/src/llmtuner/api/app.py @@ -0,0 +1,152 @@ +import uvicorn +from threading import Thread +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from transformers import TextIteratorStreamer +from contextlib import asynccontextmanager +from sse_starlette import EventSourceResponse +from typing import Any, Dict + +from llmtuner.tuner import get_infer_args, load_model_and_tokenizer +from llmtuner.extras.misc import get_logits_processor, torch_gc +from llmtuner.extras.template import Template +from llmtuner.api.protocol import ( + ModelCard, + ModelList, + ChatMessage, + DeltaMessage, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionStreamResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionResponseUsage +) + + +@asynccontextmanager +async def lifespan(app: FastAPI): # collects GPU memory + yield + torch_gc() + + +def create_app(): + model_args, data_args, finetuning_args, generating_args = get_infer_args() + model, 
tokenizer = load_model_and_tokenizer(model_args, finetuning_args) + + prompt_template = Template(data_args.prompt_template) + source_prefix = data_args.source_prefix if data_args.source_prefix else "" + + app = FastAPI(lifespan=lifespan) + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.get("/v1/models", response_model=ModelList) + async def list_models(): + global model_args + model_card = ModelCard(id="gpt-3.5-turbo") + return ModelList(data=[model_card]) + + @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) + async def create_chat_completion(request: ChatCompletionRequest): + if request.messages[-1].role != "user": + raise HTTPException(status_code=400, detail="Invalid request") + query = request.messages[-1].content + + prev_messages = request.messages[:-1] + if len(prev_messages) > 0 and prev_messages[0].role == "system": + prefix = prev_messages.pop(0).content + else: + prefix = source_prefix + + history = [] + if len(prev_messages) % 2 == 0: + for i in range(0, len(prev_messages), 2): + if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant": + history.append([prev_messages[i].content, prev_messages[i+1].content]) + + inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt") + inputs = inputs.to(model.device) + + gen_kwargs = generating_args.to_dict() + gen_kwargs.update({ + "input_ids": inputs["input_ids"], + "temperature": request.temperature if request.temperature else gen_kwargs["temperature"], + "top_p": request.top_p if request.top_p else gen_kwargs["top_p"], + "logits_processor": get_logits_processor() + }) + + if request.max_tokens: + gen_kwargs.pop("max_length", None) + gen_kwargs["max_new_tokens"] = request.max_tokens + + if request.stream: + generate = predict(gen_kwargs, request.model) + return EventSourceResponse(generate, media_type="text/event-stream") + + generation_output = model.generate(**gen_kwargs) + outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs, skip_special_tokens=True) + + usage = ChatCompletionResponseUsage( + prompt_tokens=len(inputs["input_ids"][0]), + completion_tokens=len(outputs), + total_tokens=len(inputs["input_ids"][0]) + len(outputs) + ) + + choice_data = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop" + ) + + return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion") + + async def predict(gen_kwargs: Dict[str, Any], model_id: str): + streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) + gen_kwargs["streamer"] = streamer + + thread = Thread(target=model.generate, kwargs=gen_kwargs) + thread.start() + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role="assistant"), + finish_reason=None + ) + chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + yield chunk.json(exclude_unset=True, ensure_ascii=False) + + for new_text in streamer: + if len(new_text) == 0: + continue + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=new_text), + finish_reason=None + ) + chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + yield chunk.json(exclude_unset=True, 
ensure_ascii=False) + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(), + finish_reason="stop" + ) + chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + yield chunk.json(exclude_unset=True, ensure_ascii=False) + yield "[DONE]" + + return app + + +if __name__ == "__main__": + app = create_app() + uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py new file mode 100644 index 00000000..08aea3c3 --- /dev/null +++ b/src/llmtuner/api/protocol.py @@ -0,0 +1,73 @@ +import time +from pydantic import BaseModel, Field +from typing import List, Literal, Optional + + +class ModelCard(BaseModel): + id: str + object: Optional[str] = "model" + created: Optional[int] = Field(default_factory=lambda: int(time.time())) + owned_by: Optional[str] = "owner" + root: Optional[str] = None + parent: Optional[str] = None + permission: Optional[list] = [] + + +class ModelList(BaseModel): + object: Optional[str] = "list" + data: Optional[List[ModelCard]] = [] + + +class ChatMessage(BaseModel): + role: Literal["user", "assistant", "system"] + content: str + + +class DeltaMessage(BaseModel): + role: Optional[Literal["user", "assistant", "system"]] = None + content: Optional[str] = None + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessage] + temperature: Optional[float] = None + top_p: Optional[float] = None + n: Optional[int] = 1 + max_tokens: Optional[int] = None + stream: Optional[bool] = False + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Literal["stop", "length"] + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionResponseUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatCompletionResponse(BaseModel): + id: Optional[str] = "chatcmpl-default" + object: Literal["chat.completion"] + created: Optional[int] = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: ChatCompletionResponseUsage + + +class ChatCompletionStreamResponse(BaseModel): + id: Optional[str] = "chatcmpl-default" + object: Literal["chat.completion.chunk"] + created: Optional[int] = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] diff --git a/src/llmtuner/dsets/__init__.py b/src/llmtuner/dsets/__init__.py new file mode 100644 index 00000000..7667c89c --- /dev/null +++ b/src/llmtuner/dsets/__init__.py @@ -0,0 +1,2 @@ +from llmtuner.dsets.loader import get_dataset +from llmtuner.dsets.preprocess import preprocess_dataset diff --git a/src/llmtuner/dsets/callbacks.py b/src/llmtuner/dsets/callbacks.py new file mode 100644 index 00000000..cb013961 --- /dev/null +++ b/src/llmtuner/dsets/callbacks.py @@ -0,0 +1,63 @@ +import os +import json +import time +from datetime import timedelta + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments +) + + +class LogCallback(TrainerCallback): + + def __init__(self, runner=None): + self.runner = runner + self.start_time = time.time() + self.tracker = {} + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + r""" + Event called at the beginning of a training step. 
If using gradient accumulation, one training step + might take several inputs. + """ + if self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + r""" + Event called at the end of a substep during gradient accumulation. + """ + if self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None: + r""" + Event called after logging the last logs. + """ + if "loss" not in state.log_history[-1]: + return + cur_time = time.time() + cur_steps = state.log_history[-1].get("step") + elapsed_time = cur_time - self.start_time + avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0 + remaining_steps = state.max_steps - cur_steps + remaining_time = remaining_steps * avg_time_per_step + self.tracker = { + "current_steps": cur_steps, + "total_steps": state.max_steps, + "loss": state.log_history[-1].get("loss", None), + "reward": state.log_history[-1].get("reward", None), + "learning_rate": state.log_history[-1].get("learning_rate", None), + "epoch": state.log_history[-1].get("epoch", None), + "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100, + "elapsed_time": str(timedelta(seconds=int(elapsed_time))), + "remaining_time": str(timedelta(seconds=int(remaining_time))) + } + os.makedirs(args.output_dir, exist_ok=True) + with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: + f.write(json.dumps(self.tracker) + "\n") diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py new file mode 100644 index 00000000..005cbee5 --- /dev/null +++ b/src/llmtuner/dsets/loader.py @@ -0,0 +1,106 @@ +import os +import hashlib +from typing import List + +from datasets import Dataset, concatenate_datasets, load_dataset + +from llmtuner.extras.logging import get_logger +from llmtuner.hparams import ModelArguments, DataArguments + + +logger = get_logger(__name__) + + +def get_dataset( + model_args: ModelArguments, + data_args: DataArguments +) -> Dataset: + + def checksum(file_path, hash): + with open(file_path, "rb") as datafile: + binary_data = datafile.read() + sha1 = hashlib.sha1(binary_data).hexdigest() + if sha1 != hash: + logger.warning("Checksum failed for {}.
It may vary depending on the platform.".format(file_path)) + + ext2type = { + "csv": "csv", + "json": "json", + "jsonl": "json", + "txt": "text" + } + + max_samples = data_args.max_samples + all_datasets: List[Dataset] = [] # support multiple datasets + + for dataset_attr in data_args.dataset_list: + + logger.info("Loading dataset {}...".format(dataset_attr)) + + if dataset_attr.load_from == "hf_hub": + data_path = dataset_attr.dataset_name + data_files = None + elif dataset_attr.load_from == "script": + data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + data_files = None + elif dataset_attr.load_from == "file": + data_path = None + data_files: List[str] = [] + + if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): + for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): + data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name)) + + if data_path is None: + data_path = ext2type.get(data_files[0].split(".")[-1], None) + else: + assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match." + elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): + data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)) + data_path = ext2type.get(data_files[0].split(".")[-1], None) + else: + raise ValueError("File not found.") + + assert data_path, "File extension must be txt, csv, json or jsonl." + + if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None: + checksum(data_files[0], dataset_attr.dataset_sha1) + else: + logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.") + else: + raise NotImplementedError + + raw_datasets = load_dataset( + data_path, + data_files=data_files, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None + ) + dataset = raw_datasets[data_args.split] + + if max_samples is not None: + max_samples_temp = min(len(dataset), max_samples) + dataset = dataset.select(range(max_samples_temp)) + + dummy_data = [None] * len(dataset) + prefix_data = [dataset_attr.source_prefix] * len(dataset) + for column_name, target_name in [ + ("prompt_column", "prompt"), + ("query_column", "query"), + ("response_column", "response"), + ("history_column", "history") + ]: # every dataset will have 4 columns same as each other + if getattr(dataset_attr, column_name) != target_name: + if getattr(dataset_attr, column_name): + dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name) + else: # None or empty string + dataset = dataset.add_column(target_name, dummy_data) + dataset = dataset.add_column("prefix", prefix_data) + all_datasets.append(dataset) + + if len(data_args.dataset_list) == 1: + all_datasets = all_datasets[0] + else: + all_datasets = concatenate_datasets(all_datasets) + + return all_datasets diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py new file mode 100644 index 00000000..4eb912c1 --- /dev/null +++ b/src/llmtuner/dsets/preprocess.py @@ -0,0 +1,172 @@ +from typing import Literal +from itertools import chain +from transformers import Seq2SeqTrainingArguments +from transformers.tokenization_utils import PreTrainedTokenizer + +from datasets import Dataset + +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.template import Template +from llmtuner.hparams import DataArguments + + +def preprocess_dataset( + 
dataset: Dataset, + tokenizer: PreTrainedTokenizer, + data_args: DataArguments, + training_args: Seq2SeqTrainingArguments, + stage: Literal["pt", "sft", "rm", "ppo"] +) -> Dataset: + + column_names = list(dataset.column_names) + prompt_template = Template(data_args.prompt_template) + + # support question with a single answer or multiple answers + def get_dialog(examples): + for i in range(len(examples["prompt"])): + if examples["prompt"][i] and examples["response"][i]: + query, answer = examples["prompt"][i], examples["response"][i] + query = query + "\n" + examples["query"][i] if examples["query"][i] else query + prefix = examples["prefix"][i] if examples["prefix"][i] else "" + dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix) + yield dialog + + def preprocess_pretrain_dataset(examples): + # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>) + text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"] + concatenated_ids = list(chain(*text_ids)) + total_length = len(concatenated_ids) + block_size = data_args.max_source_length - 1 + # we drop the small remainder, and if the total_length < block_size, we exclude this batch + total_length = (total_length // block_size) * block_size + # split by chunks of max_source_length + result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size] + for i in range(0, total_length, block_size)] + return { + "input_ids": result, + "labels": result.copy() + } + + def preprocess_supervised_dataset(examples): + # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>` + # for input with history, we build multiple input-label pairs just like: + # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112 + model_inputs = {"input_ids": [], "labels": []} + max_length = data_args.max_source_length + data_args.max_target_length + + for dialog in get_dialog(examples): + input_ids, labels = [], [] + + for i in range(len(dialog) // 2): + source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0)) + target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False) + + if len(source_ids) > data_args.max_source_length: + source_ids = source_ids[:data_args.max_source_length] + if len(target_ids) > data_args.max_target_length - 1: # eos token + target_ids = target_ids[:data_args.max_target_length - 1] + + if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length: + break + + input_ids += source_ids + target_ids + [tokenizer.eos_token_id] + labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id] + + model_inputs["input_ids"].append(input_ids) + model_inputs["labels"].append(labels) + + return model_inputs + + def preprocess_unsupervised_dataset(examples): + # build inputs with format `<bos> X` and labels with format `<bos> Y` + model_inputs = {"input_ids": [], "labels": []} + + for dialog in get_dialog(examples): + prompt, answer = "".join(dialog[:-1]), dialog[-1] + + source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) + target_ids = tokenizer.encode(text=answer, add_special_tokens=True) + + if len(source_ids) > data_args.max_source_length: + source_ids = source_ids[:data_args.max_source_length] + if len(target_ids) > data_args.max_target_length: + target_ids = target_ids[:data_args.max_target_length] + + model_inputs["input_ids"].append(source_ids) + model_inputs["labels"].append(target_ids) + + return model_inputs + + def preprocess_pairwise_dataset(examples): + #
build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>` + model_inputs = {"accept_ids": [], "reject_ids": []} + for dialog in get_dialog(examples): + prompt, answer = "".join(dialog[:-1]), dialog[-1] + + source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) + accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False) + reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False) + + if len(source_ids) > data_args.max_source_length: + source_ids = source_ids[:data_args.max_source_length] + if len(accept_ids) > data_args.max_target_length - 1: # eos token + accept_ids = accept_ids[:data_args.max_target_length - 1] + if len(reject_ids) > data_args.max_target_length - 1: # eos token + reject_ids = reject_ids[:data_args.max_target_length - 1] + + accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id] + reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id] + + model_inputs["accept_ids"].append(accept_ids) + model_inputs["reject_ids"].append(reject_ids) + return model_inputs + + def print_supervised_dataset_example(example): + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print("labels:\n{}".format( + tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]], + skip_special_tokens=False) + )) + + def print_pairwise_dataset_example(example): + print("accept_ids:\n{}".format(example["accept_ids"])) + print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False))) + print("reject_ids:\n{}".format(example["reject_ids"])) + print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False))) + + def print_unsupervised_dataset_example(example): + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + + if stage == "pt": + preprocess_function = preprocess_pretrain_dataset + elif stage == "sft": + preprocess_function = preprocess_unsupervised_dataset \ + if training_args.predict_with_generate else preprocess_supervised_dataset + elif stage == "rm": + preprocess_function = preprocess_pairwise_dataset + elif stage == "ppo": + preprocess_function = preprocess_unsupervised_dataset + + with training_args.main_process_first(desc="dataset map pre-processing"): + dataset = dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on dataset" + ) + + if stage == "pt": + print_unsupervised_dataset_example(dataset[0]) + elif stage == "sft": + print_supervised_dataset_example(dataset[0]) + elif stage == "rm": + print_pairwise_dataset_example(dataset[0]) + elif stage == "ppo": + print_unsupervised_dataset_example(dataset[0]) + + return dataset diff --git a/src/__init__.py b/src/llmtuner/extras/__init__.py similarity index 100% rename from src/__init__.py rename to src/llmtuner/extras/__init__.py diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py new file mode 100644 index 00000000..69c45e17 --- /dev/null +++ b/src/llmtuner/extras/callbacks.py @@ -0,0 +1,72 @@ +import os +import json +import time +from datetime import timedelta + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments +)
+ + +class LogCallback(TrainerCallback): + + def __init__(self, runner=None): + self.runner = runner + self.tracker = {} + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + r""" + Event called at the beginning of training. + """ + self.start_time = time.time() + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + r""" + Event called at the beginning of a training step. If using gradient accumulation, one training step + might take several inputs. + """ + if self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + r""" + Event called at the end of a substep during gradient accumulation. + """ + if self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None: + r""" + Event called after logging the last logs. + """ + if "loss" not in state.log_history[-1]: + return + cur_time = time.time() + cur_steps = state.log_history[-1].get("step") + elapsed_time = cur_time - self.start_time + avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0 + remaining_steps = state.max_steps - cur_steps + remaining_time = remaining_steps * avg_time_per_step + self.tracker = { + "current_steps": cur_steps, + "total_steps": state.max_steps, + "loss": state.log_history[-1].get("loss", None), + "eval_loss": state.log_history[-1].get("eval_loss", None), + "predict_loss": state.log_history[-1].get("predict_loss", None), + "reward": state.log_history[-1].get("reward", None), + "learning_rate": state.log_history[-1].get("learning_rate", None), + "epoch": state.log_history[-1].get("epoch", None), + "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100, + "elapsed_time": str(timedelta(seconds=int(elapsed_time))), + "remaining_time": str(timedelta(seconds=int(remaining_time))) + } + os.makedirs(args.output_dir, exist_ok=True) + with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: + f.write(json.dumps(self.tracker) + "\n") diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py new file mode 100644 index 00000000..ab8971bb --- /dev/null +++ b/src/llmtuner/extras/constants.py @@ -0,0 +1,7 @@ +IGNORE_INDEX = -100 + +VALUE_HEAD_FILE_NAME = "value_head.bin" + +FINETUNING_ARGS_NAME = "finetuning_args.json" + +LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings diff --git a/src/llmtuner/extras/logging.py b/src/llmtuner/extras/logging.py new file mode 100644 index 00000000..231acf4a --- /dev/null +++ b/src/llmtuner/extras/logging.py @@ -0,0 +1,18 @@ +import sys +import logging + + +def get_logger(name: str) -> logging.Logger: + + formatter = logging.Formatter( + fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S" + ) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + logger.addHandler(handler) + + return logger diff --git
a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py new file mode 100644 index 00000000..40c8d337 --- /dev/null +++ b/src/llmtuner/extras/misc.py @@ -0,0 +1,105 @@ +import torch +from typing import List, Optional + +from transformers.modeling_utils import PreTrainedModel +from transformers.generation.utils import LogitsProcessorList +from transformers.generation.logits_process import LogitsProcessor + +from llmtuner.extras.constants import LAYERNORM_NAMES + + +class AverageMeter: + r""" + Computes and stores the average and current value. + """ + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +# Avoid runtime error in model.generate(do_sample=True). +class InvalidScoreLogitsProcessor(LogitsProcessor): + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 0] = 1.0 + return scores + + +def get_logits_processor() -> LogitsProcessorList: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + return logits_processor + + +def print_trainable_params(model: torch.nn.Module) -> None: + trainable_params, all_param = 0, 0 + for param in model.parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + all_param += num_params + if param.requires_grad: + trainable_params += num_params + print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param)) + + +# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32 +# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35 +def prepare_model_for_training( + model: PreTrainedModel, + finetuning_type: str, + output_embedding_layer_name: Optional[str] = "lm_head", + use_gradient_checkpointing: Optional[bool] = True, + layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES +) -> PreTrainedModel: + + for name, param in model.named_parameters(): + if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): + param.data = param.data.to(torch.float32) + + if use_gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model.gradient_checkpointing_enable() + model.config.use_cache = False # turn off when gradient checkpointing is enabled + + if finetuning_type != "full" and hasattr(model, output_embedding_layer_name): + output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name) + input_dtype = output_embedding_layer.weight.dtype + + class CastOutputToFloat(torch.nn.Sequential): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.to(input_dtype)).to(torch.float32) + + setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer)) + + return model + +def torch_gc() -> None: + r""" + Collects GPU memory. 
+ """ + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() diff --git a/src/llmtuner/extras/ploting.py b/src/llmtuner/extras/ploting.py new file mode 100644 index 00000000..fb11a290 --- /dev/null +++ b/src/llmtuner/extras/ploting.py @@ -0,0 +1,50 @@ +import os +import json +import matplotlib.pyplot as plt +from typing import List, Optional +from transformers.trainer import TRAINER_STATE_NAME + +from llmtuner.extras.logging import get_logger + + +logger = get_logger(__name__) + + +def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]: + r""" + EMA implementation according to TensorBoard. + """ + last = scalars[0] + smoothed = list() + for next_val in scalars: + smoothed_val = last * weight + (1 - weight) * next_val + smoothed.append(smoothed_val) + last = smoothed_val + return smoothed + + +def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None: + + with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: + data = json.load(f) + + for key in keys: + steps, metrics = [], [] + for i in range(len(data["log_history"])): + if key in data["log_history"][i]: + steps.append(data["log_history"][i]["step"]) + metrics.append(data["log_history"][i][key]) + + if len(metrics) == 0: + logger.warning(f"No metric {key} to plot.") + continue + + plt.figure() + plt.plot(steps, metrics, alpha=0.4, label="original") + plt.plot(steps, smooth(metrics), label="smoothed") + plt.title("training {} of {}".format(key, save_dictionary)) + plt.xlabel("step") + plt.ylabel(key) + plt.legend() + plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100) + print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key))) diff --git a/src/llmtuner/extras/save_and_load.py b/src/llmtuner/extras/save_and_load.py new file mode 100644 index 00000000..fd4a8165 --- /dev/null +++ b/src/llmtuner/extras/save_and_load.py @@ -0,0 +1,49 @@ +import os +import torch +from typing import Dict + +from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME +from transformers.modeling_utils import load_sharded_checkpoint + +from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME +from llmtuner.extras.logging import get_logger + + +logger = get_logger(__name__) + + +def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters + state_dict = model.state_dict() + filtered_state_dict = {} + + for k, v in model.named_parameters(): + if v.requires_grad: + filtered_state_dict[k] = state_dict[k].cpu().clone().detach() + + return filtered_state_dict + + +def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool: + weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME) + if os.path.exists(weights_file): + model_state_dict = torch.load(weights_file, map_location="cpu") + model.load_state_dict(model_state_dict, strict=False) # skip missing keys + elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)): + load_sharded_checkpoint(model, checkpoint_dir, strict=False) + else: + logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir)) + return False + return True + + +def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool: + valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME) + if not os.path.exists(valuehead_file): + logger.warning("Provided path ({}) does not contain valuehead 
weights.".format(checkpoint_dir)) + return False + valuehead_state_dict = torch.load(valuehead_file, map_location="cpu") + model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"]) + model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"]) + model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"])) + model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"])) + return True diff --git a/src/utils/template.py b/src/llmtuner/extras/template.py similarity index 100% rename from src/utils/template.py rename to src/llmtuner/extras/template.py diff --git a/src/llmtuner/hparams/__init__.py b/src/llmtuner/hparams/__init__.py new file mode 100644 index 00000000..0fabfa33 --- /dev/null +++ b/src/llmtuner/hparams/__init__.py @@ -0,0 +1,5 @@ +from .data_args import DataArguments +from .finetuning_args import FinetuningArguments +from .general_args import GeneralArguments +from .generating_args import GeneratingArguments +from .model_args import ModelArguments diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py new file mode 100644 index 00000000..df4c0557 --- /dev/null +++ b/src/llmtuner/hparams/data_args.py @@ -0,0 +1,119 @@ +import os +import json +from typing import List, Optional +from dataclasses import dataclass, field + + +@dataclass +class DatasetAttr: + + load_from: str + dataset_name: Optional[str] = None + dataset_sha1: Optional[str] = None + source_prefix: Optional[str] = None + + def __repr__(self) -> str: + return self.dataset_name + + def __post_init__(self): + self.prompt_column = "instruction" + self.query_column = "input" + self.response_column = "output" + self.history_column = None + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and evaluation. + """ + dataset: Optional[str] = field( + default="alpaca_zh", + metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."} + ) + dataset_dir: Optional[str] = field( + default="data", + metadata={"help": "The name of the folder containing datasets."} + ) + split: Optional[str] = field( + default="train", + metadata={"help": "Which dataset split to use for training and evaluation."} + ) + overwrite_cache: Optional[bool] = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets."} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."} + ) + max_source_length: Optional[int] = field( + default=512, + metadata={"help": "The maximum total input sequence length after tokenization."} + ) + max_target_length: Optional[int] = field( + default=512, + metadata={"help": "The maximum total output sequence length after tokenization."} + ) + max_samples: Optional[int] = field( + default=None, + metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} + ) + eval_num_beams: Optional[int] = field( + default=None, + metadata={"help": "Number of beams to use for evaluation. 
This argument will be passed to `model.generate`"} + ) + ignore_pad_token_for_loss: Optional[bool] = field( + default=True, + metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} + ) + source_prefix: Optional[str] = field( + default=None, + metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."} + ) + dev_ratio: Optional[float] = field( + default=0, + metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."} + ) + prompt_template: Optional[str] = field( + default="default", + metadata={"help": "Which template to use for constructing prompts in training and inference."} + ) + + def init_for_training(self): # support mixing multiple datasets + dataset_names = [ds.strip() for ds in self.dataset.split(",")] + with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: + dataset_info = json.load(f) + + if self.source_prefix is not None: + prefix_list = self.source_prefix.split("|") + prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list + assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical to the number of datasets or 1." + else: + prefix_list = [None] * len(dataset_names) + + self.dataset_list: List[DatasetAttr] = [] + for i, name in enumerate(dataset_names): + if name not in dataset_info: + raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) + + if "hf_hub_url" in dataset_info[name]: + dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) + elif "script_url" in dataset_info[name]: + dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) + else: + dataset_attr = DatasetAttr( + "file", + dataset_name=dataset_info[name]["file_name"], + dataset_sha1=dataset_info[name].get("file_sha1", None) + ) + + dataset_attr.source_prefix = prefix_list[i] + + if "columns" in dataset_info[name]: + dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None) + dataset_attr.query_column = dataset_info[name]["columns"].get("query", None) + dataset_attr.response_column = dataset_info[name]["columns"].get("response", None) + dataset_attr.history_column = dataset_info[name]["columns"].get("history", None) + + self.dataset_list.append(dataset_attr) \ No newline at end of file diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py new file mode 100644 index 00000000..6f01ef29 --- /dev/null +++ b/src/llmtuner/hparams/finetuning_args.py @@ -0,0 +1,78 @@ +import json +from typing import Literal, Optional +from dataclasses import asdict, dataclass, field + + +@dataclass +class FinetuningArguments: + """ + Arguments pertaining to which techniques we are going to fine-tune with. + """ + finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field( + default="lora", + metadata={"help": "Which fine-tuning method to use."} + ) + num_hidden_layers: Optional[int] = field( + default=32, + metadata={"help": "Number of decoder blocks in the model.
\ + LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \ + BLOOM choices: [\"24\", \"30\", \"70\"], \ + Falcon choices: [\"32\", \"60\"], \ + Baichuan choices: [\"32\"]"} + ) + num_layer_trainable: Optional[int] = field( + default=3, + metadata={"help": "Number of trainable layers for Freeze fine-tuning."} + ) + name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( + default="mlp", + metadata={"help": "Name of trainable modules for Freeze fine-tuning. \ + LLaMA choices: [\"mlp\", \"self_attn\"], \ + BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \ + Baichuan choices: [\"mlp\", \"self_attn\"]"} + ) + lora_rank: Optional[int] = field( + default=8, + metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} + ) + lora_alpha: Optional[float] = field( + default=32.0, + metadata={"help": "The scale factor for LoRA fine-tuning (similar to the learning rate)."} + ) + lora_dropout: Optional[float] = field( + default=0.1, + metadata={"help": "Dropout rate for the LoRA fine-tuning."} + ) + lora_target: Optional[str] = field( + default="q_proj,v_proj", + metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ + LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ + BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ + Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"} + ) + + def __post_init__(self): + if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA + self.lora_target = [target.strip() for target in self.lora_target.split(",")] + + if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 + trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)] + else: # fine-tuning the first n layers if num_layer_trainable < 0 + trainable_layer_ids = [k for k in range(-self.num_layer_trainable)] + + self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids] + + assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method." + + def save_to_json(self, json_path: str): + """Saves the content of this instance in JSON format inside `json_path`.""" + json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" + with open(json_path, "w", encoding="utf-8") as f: + f.write(json_string) + + @classmethod + def load_from_json(cls, json_path: str): + """Creates an instance from the content of `json_path`.""" + with open(json_path, "r", encoding="utf-8") as f: + text = f.read() + return cls(**json.loads(text)) diff --git a/src/llmtuner/hparams/general_args.py b/src/llmtuner/hparams/general_args.py new file mode 100644 index 00000000..a97a4935 --- /dev/null +++ b/src/llmtuner/hparams/general_args.py @@ -0,0 +1,13 @@ +from typing import Literal, Optional +from dataclasses import dataclass, field + + +@dataclass +class GeneralArguments: + """ + Arguments pertaining to which stage we are going to perform in training.
+ """ + stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field( + default="sft", + metadata={"help": "Which stage will be performed in training."} + ) diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py new file mode 100644 index 00000000..e25ff4b9 --- /dev/null +++ b/src/llmtuner/hparams/generating_args.py @@ -0,0 +1,51 @@ +from typing import Any, Dict, Optional +from dataclasses import asdict, dataclass, field + + +@dataclass +class GeneratingArguments: + """ + Arguments pertaining to specify the decoding parameters. + """ + do_sample: Optional[bool] = field( + default=True, + metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} + ) + temperature: Optional[float] = field( + default=0.95, + metadata={"help": "The value used to modulate the next token probabilities."} + ) + top_p: Optional[float] = field( + default=0.7, + metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."} + ) + top_k: Optional[int] = field( + default=50, + metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} + ) + num_beams: Optional[int] = field( + default=1, + metadata={"help": "Number of beams for beam search. 1 means no beam search."} + ) + max_length: Optional[int] = field( + default=None, + metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} + ) + max_new_tokens: Optional[int] = field( + default=512, + metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."} + ) + repetition_penalty: Optional[float] = field( + default=1.0, + metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."} + ) + length_penalty: Optional[float] = field( + default=1.0, + metadata={"help": "Exponential penalty to the length that is used with beam-based generation."} + ) + + def to_dict(self) -> Dict[str, Any]: + args = asdict(self) + if args.get("max_new_tokens", None): + args.pop("max_length", None) + return args diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py new file mode 100644 index 00000000..253d9839 --- /dev/null +++ b/src/llmtuner/hparams/model_args.py @@ -0,0 +1,72 @@ +import torch +from typing import Literal, Optional +from dataclasses import dataclass, field + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 
+ """ + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} + ) + use_fast_tokenizer: Optional[bool] = field( + default=False, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} + ) + use_auth_token: Optional[bool] = field( + default=False, + metadata={"help": "Will use the token generated when running `huggingface-cli login`."} + ) + model_revision: Optional[str] = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} + ) + padding_side: Optional[Literal["left", "right"]] = field( + default="left", + metadata={"help": "The side on which the model should have padding applied."} + ) + quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the model."} + ) + quantization_type: Optional[Literal["fp4", "nf4"]] = field( + default="nf4", + metadata={"help": "Quantization data type to use in int4 training."} + ) + double_quantization: Optional[bool] = field( + default=True, + metadata={"help": "Whether to use double quantization in int4 training or not."} + ) + compute_dtype: Optional[torch.dtype] = field( + default=None, + metadata={"help": "Used in quantization configs. Do not specify this argument manually."} + ) + checkpoint_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} + ) + reward_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory containing the checkpoints of the reward model."} + ) + resume_lora_training: Optional[bool] = field( + default=True, + metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} + ) + plot_loss: Optional[bool] = field( + default=False, + metadata={"help": "Whether to plot the training loss after fine-tuning or not."} + ) + + def __post_init__(self): + if self.checkpoint_dir is not None: # support merging multiple lora weights + self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] + + if self.quantization_bit is not None: + assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." 
diff --git a/src/llmtuner/tuner/__init__.py b/src/llmtuner/tuner/__init__.py new file mode 100644 index 00000000..c329f39a --- /dev/null +++ b/src/llmtuner/tuner/__init__.py @@ -0,0 +1,5 @@ +from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer +from llmtuner.tuner.pt import run_pt +from llmtuner.tuner.sft import run_sft +from llmtuner.tuner.rm import run_rm +from llmtuner.tuner.ppo import run_ppo diff --git a/src/llmtuner/tuner/core/__init__.py b/src/llmtuner/tuner/core/__init__.py new file mode 100644 index 00000000..bd1c5cf0 --- /dev/null +++ b/src/llmtuner/tuner/core/__init__.py @@ -0,0 +1,2 @@ +from llmtuner.tuner.core.parser import get_train_args, get_infer_args +from llmtuner.tuner.core.loader import load_model_and_tokenizer diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py new file mode 100644 index 00000000..5fddeb99 --- /dev/null +++ b/src/llmtuner/tuner/core/adapter.py @@ -0,0 +1,94 @@ +import os +import torch + +from transformers.modeling_utils import PreTrainedModel +from peft import ( + PeftModel, + TaskType, + LoraConfig, + get_peft_model +) +from peft.utils import CONFIG_NAME, WEIGHTS_NAME + +from llmtuner.extras.logging import get_logger +from llmtuner.extras.save_and_load import load_trainable_params +from llmtuner.hparams import ModelArguments, FinetuningArguments + + +logger = get_logger(__name__) + + +def init_adapter( + model: PreTrainedModel, + model_args: ModelArguments, + finetuning_args: FinetuningArguments, + is_trainable: bool, + is_mergeable: bool +) -> PreTrainedModel: + r""" + Initializes the adapters. + + Supports full-parameter, freeze and LoRA training. + + Note that the trainable parameters must be cast to float32. + """ + + if finetuning_args.finetuning_type == "none" and is_trainable: + raise ValueError("You cannot use finetuning_type=none while training.") + + if finetuning_args.finetuning_type == "full": + logger.info("Fine-tuning method: Full") + model = model.float() + + if finetuning_args.finetuning_type == "freeze": + logger.info("Fine-tuning method: Freeze") + + for name, param in model.named_parameters(): + if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers): + param.requires_grad_(False) + else: + param.data = param.data.to(torch.float32) + + if model_args.checkpoint_dir is not None: + assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded." + + if finetuning_args.finetuning_type == "lora": + logger.info("Fine-tuning method: LoRA") + latest_checkpoint = None + + if model_args.checkpoint_dir is not None: + assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \ + "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]) + assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \ + "The given checkpoint may not be a LoRA checkpoint; please specify `--finetuning_type full/freeze` instead."
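+ # How the checkpoints are consumed below: when training resumes on existing LoRA weights + # ((is_trainable and resume_lora_training) or not is_mergeable, the latter for quantized models), + # every checkpoint except the last is merged into the backbone and the last one is loaded as the + # active adapter; otherwise all provided checkpoints are merged.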
+ + if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights + checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] + else: + checkpoints_to_merge = model_args.checkpoint_dir + + for checkpoint in checkpoints_to_merge: + model = PeftModel.from_pretrained(model, checkpoint) + model = model.merge_and_unload() + + if len(checkpoints_to_merge) > 0: + logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) + + if latest_checkpoint is not None: # resume lora training or quantized inference + model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable) + + if is_trainable and latest_checkpoint is None: # create new lora weights while training + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=finetuning_args.lora_rank, + lora_alpha=finetuning_args.lora_alpha, + lora_dropout=finetuning_args.lora_dropout, + target_modules=finetuning_args.lora_target + ) + model = get_peft_model(model, lora_config) + + if model_args.checkpoint_dir is not None: + logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) + + return model diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py new file mode 100644 index 00000000..a6b1c0f5 --- /dev/null +++ b/src/llmtuner/tuner/core/loader.py @@ -0,0 +1,151 @@ +import os +import torch +from typing import Literal, Optional, Tuple + +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig +) +from transformers.utils import check_min_version +from transformers.utils.versions import require_version +from transformers.modeling_utils import PreTrainedModel +from transformers.tokenization_utils import PreTrainedTokenizer +from trl import AutoModelForCausalLMWithValueHead + +from llmtuner.extras.logging import get_logger +from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params +from llmtuner.extras.save_and_load import load_valuehead_params +from llmtuner.hparams import ModelArguments, FinetuningArguments +from llmtuner.tuner.core.adapter import init_adapter + + +logger = get_logger(__name__) + + +check_min_version("4.29.1") +require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") +require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0") +require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") +require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4") + + +def load_model_and_tokenizer( + model_args: ModelArguments, + finetuning_args: FinetuningArguments, + is_trainable: Optional[bool] = False, + stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft" +) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: + r""" + Loads pretrained model and tokenizer. + + Supports both training and inference. + """ + if (not is_trainable) and model_args.checkpoint_dir is None: + logger.warning("Checkpoint was not found at evaluation; loading the original model.") + finetuning_args = FinetuningArguments(finetuning_type="none") + + assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \ + "RM and PPO training can only be performed with the LoRA method."
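+ # Typical call, as used by the workflow modules later in this patch (sketch): + # model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")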
+ + config_kwargs = { + "trust_remote_code": True, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + } + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + padding_side=model_args.padding_side, + **config_kwargs + ) + if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version) + tokenizer.pad_token_id = 0 # set as the <unk> token + + config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) + is_mergeable = True + + # Quantization configurations (using bitsandbytes library). + if model_args.quantization_bit is not None: + if model_args.quantization_bit == 8: + require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + config_kwargs["load_in_8bit"] = True + config_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_8bit=True, + llm_int8_threshold=6.0 + ) + + elif model_args.quantization_bit == 4: + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1") + require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3") + require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git") + config_kwargs["load_in_4bit"] = True + config_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.compute_dtype, + bnb_4bit_use_double_quant=model_args.double_quantization, + bnb_4bit_quant_type=model_args.quantization_type + ) + + is_mergeable = False + config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} + logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) + + if not is_trainable: # `device_map=auto` should be used for inference only + config_kwargs["device_map"] = "auto" + + if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full": + model_to_load = model_args.checkpoint_dir[0] + else: + model_to_load = model_args.model_name_or_path + + # Load and prepare pretrained models (without valuehead). + model = AutoModelForCausalLM.from_pretrained( + model_to_load, + config=config, + torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16, + low_cpu_mem_usage=True, + **config_kwargs + ) + + # Register auto class to save the custom code files.
+ if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map: + config.__class__.register_for_auto_class() + if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map: + tokenizer.__class__.register_for_auto_class() + if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map: + model.__class__.register_for_auto_class() + + # Initialize adapters + model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model + model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) + + if stage == "rm" or stage == "ppo": # add value head + model = AutoModelForCausalLMWithValueHead.from_pretrained(model) + + if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model + logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.") + if load_valuehead_params(model, model_args.checkpoint_dir[-1]): + model.v_head.load_state_dict({ + "summary.weight": getattr(model, "reward_head_weight"), + "summary.bias": getattr(model, "reward_head_bias") + }) + + if stage == "ppo": # load reward model + assert is_trainable, "PPO stage cannot be performed at evaluation." + assert model_args.reward_model is not None, "Reward model is necessary for PPO training." + logger.info("Load reward model from {}".format(model_args.reward_model)) + model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False) + assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded." + + if not is_trainable: + model.requires_grad_(False) # fix all model params + model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16 + + print_trainable_params(model) + + return model, tokenizer diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py new file mode 100644 index 00000000..186efeea --- /dev/null +++ b/src/llmtuner/tuner/core/parser.py @@ -0,0 +1,134 @@ +import os +import sys +import torch +import datasets +import transformers +from typing import Any, Dict, Optional, Tuple +from transformers import HfArgumentParser, Seq2SeqTrainingArguments + +from llmtuner.extras.logging import get_logger +from llmtuner.hparams import ( + ModelArguments, + DataArguments, + FinetuningArguments, + GeneratingArguments, + GeneralArguments +) + + +logger = get_logger(__name__) + + +def get_train_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]: + + parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments)) + + if args is not None: + model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses() + + # Setup logging + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
+ transformers.utils.logging.set_verbosity_info() + + log_level = training_args.get_process_log_level() + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Check arguments (do not check finetuning_args since it may be loaded from checkpoints) + data_args.init_for_training() + + assert general_args.stage == "sft" or (not training_args.predict_with_generate), \ + "`predict_with_generate` cannot be set to True at PT, RM and PPO stages." + + assert not (training_args.do_train and training_args.predict_with_generate), \ + "`predict_with_generate` cannot be set to True while training." + + assert (not training_args.do_predict) or training_args.predict_with_generate, \ + "Please enable `predict_with_generate` to save model predictions." + + assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \ + "Quantization is only compatible with the LoRA method." + + if model_args.checkpoint_dir is not None: + if finetuning_args.finetuning_type != "lora": + assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." + else: + assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \ + "Quantized model only accepts a single checkpoint." + + if model_args.quantization_bit is not None and (not training_args.do_train): + logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") + + if training_args.do_train and (not training_args.fp16): + logger.warning("We recommend enabling fp16 mixed precision training.") + + if data_args.prompt_template == "default": + logger.warning("Please specify `prompt_template` if you are using other pre-trained models.") + + if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None: + logger.warning("`ddp_find_unused_parameters` needs to be set to False in DDP training.") + training_args.ddp_find_unused_parameters = False + + training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning + + if model_args.quantization_bit is not None: + if training_args.fp16: + model_args.compute_dtype = torch.float16 + elif training_args.bf16: + model_args.compute_dtype = torch.bfloat16 + else: + model_args.compute_dtype = torch.float32 + + # Log on each process the small summary: + logger.info( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n" + + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model.
+ transformers.set_seed(training_args.seed) + + return model_args, data_args, training_args, finetuning_args, general_args + + +def get_infer_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: + + parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments)) + + if args is not None: + model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1])) + else: + model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses() + + assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \ + "Quantization is only compatible with the LoRA method." + + if model_args.checkpoint_dir is not None: + if finetuning_args.finetuning_type != "lora": + assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." + else: + assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \ + "Quantized model only accepts a single checkpoint." + + if data_args.prompt_template == "default": + logger.warning("Please specify `prompt_template` if you are using other pre-trained models.") + + return model_args, data_args, finetuning_args, generating_args diff --git a/src/utils/peft_trainer.py b/src/llmtuner/tuner/core/trainer.py similarity index 60% rename from src/utils/peft_trainer.py rename to src/llmtuner/tuner/core/trainer.py index 94dd11a6..8b057ff3 100644 --- a/src/utils/peft_trainer.py +++ b/src/llmtuner/tuner/core/trainer.py @@ -1,76 +1,20 @@ import os -import json -import time import torch from typing import Dict, Optional -from datetime import timedelta - -from transformers import ( - Seq2SeqTrainer, - TrainerCallback, - TrainerControl, - TrainerState, - TrainingArguments -) +from transformers import Seq2SeqTrainer from transformers.trainer import TRAINING_ARGS_NAME from transformers.modeling_utils import unwrap_model -from .config import FinetuningArguments - -from .other import ( - get_logger, - get_state_dict, - load_trainable_params, - load_valuehead_params, - FINETUNING_ARGS_NAME, - VALUE_HEAD_FILE_NAME -) +from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME +from llmtuner.extras.logging import get_logger +from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params +from llmtuner.hparams import FinetuningArguments logger = get_logger(__name__) -class LogCallback(TrainerCallback): - r""" - TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class. - The on_log function primarily collects process parameters during training, such as training loss, learning rate, - and training epochs, as well as progress parameters like the current percentage progress and estimated remaining - time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization - purposes. 
- """ - - def __init__(self): - self.start_time = time.time() - - def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None: - r""" - Event called after logging the last logs. - """ - if "loss" not in state.log_history[-1]: - return - cur_time = time.time() - cur_steps = state.log_history[-1].get("step") - elapsed_time = cur_time - self.start_time - avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0 - remaining_steps = state.max_steps - cur_steps - remaining_time = remaining_steps * avg_time_per_step - log_dict = { - "current_steps": cur_steps, - "total_steps": state.max_steps, - "loss": state.log_history[-1].get("loss", None), - "reward": state.log_history[-1].get("reward", None), - "learning_rate": state.log_history[-1].get("learning_rate", None), - "epoch": state.log_history[-1].get("epoch", None), - "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100, - "elapsed_time": str(timedelta(seconds=int(elapsed_time))), - "remaining_time": str(timedelta(seconds=int(remaining_time))) - } - os.makedirs(args.output_dir, exist_ok=True) - with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f: - f.write(json.dumps(log_dict) + "\n") - - class PeftTrainer(Seq2SeqTrainer): r""" Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. diff --git a/src/llmtuner/tuner/ppo/__init__.py b/src/llmtuner/tuner/ppo/__init__.py new file mode 100644 index 00000000..11519bab --- /dev/null +++ b/src/llmtuner/tuner/ppo/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.ppo.workflow import run_ppo diff --git a/src/utils/ppo.py b/src/llmtuner/tuner/ppo/trainer.py similarity index 76% rename from src/utils/ppo.py rename to src/llmtuner/tuner/ppo/trainer.py index 477dce59..8612ecc2 100644 --- a/src/utils/ppo.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -2,77 +2,43 @@ import os import math import torch from tqdm import tqdm -from typing import Callable, Dict, List, Literal, Optional, Tuple +from typing import Callable, Dict, List, Optional -from transformers import Seq2SeqTrainingArguments, TrainerState +from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl from transformers.modeling_utils import PreTrainedModel -from trl import PPOTrainer, AutoModelForCausalLMWithValueHead +from trl import PPOTrainer from trl.core import LengthSampler -from .peft_trainer import PeftTrainer, LogCallback - -from .config import FinetuningArguments - -from .other import ( - AverageMeter, - get_logger, - get_logits_processor -) +from llmtuner.extras.callbacks import LogCallback +from llmtuner.extras.logging import get_logger +from llmtuner.extras.misc import AverageMeter, get_logits_processor +from llmtuner.hparams import FinetuningArguments +from llmtuner.tuner.core.trainer import PeftTrainer +from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model logger = get_logger(__name__) -def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None: - if target == "reward": # save default head temporarily - valuehead_state_dict = model.v_head.state_dict() - setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"]) - setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"]) - - model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active - model.v_head.load_state_dict({ - "summary.weight": getattr(model, "{}_head_weight".format(target)), - "summary.bias": getattr(model, 
"{}_head_bias".format(target)) - }) - - -def cast_layernorm_dtype( - model: AutoModelForCausalLMWithValueHead, - layer_norm_names: List[str] = ["norm", "ln_f", "ln_attn", "ln_mlp"], # for LLaMA, BLOOM and Falcon settings - layer_norm_params: Optional[Dict[str, torch.Tensor]] = None -) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: - - layer_norm_state_dict = {} - - for name, param in model.named_parameters(): - if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): - if layer_norm_params is not None: - param.data = layer_norm_params[name] # restore float32 weights - else: - layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability - param.data = param.data.to(torch.float16) - - return model, layer_norm_state_dict - - class PPOPeftTrainer(PPOTrainer, PeftTrainer): r""" Inherits PPOTrainer. """ def __init__( - self, - training_args: Seq2SeqTrainingArguments, - finetuning_args: FinetuningArguments, - callbacks: List[LogCallback], - **kwargs + self, + training_args: Seq2SeqTrainingArguments, + finetuning_args: FinetuningArguments, + callbacks: List[LogCallback], + **kwargs ): PPOTrainer.__init__(self, **kwargs) self.args = training_args self.finetuning_args = finetuning_args self.log_callback = callbacks[0] self.state = TrainerState() + self.control = TrainerControl() self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer def ppo_train(self, max_target_length: int) -> None: @@ -117,8 +83,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): steps_trained = 0 loss_meter = AverageMeter() reward_meter = AverageMeter() + self.log_callback.on_train_begin(self.args, self.state, self.control) - for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()): + for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False): for _ in range(self.config.gradient_accumulation_steps): @@ -158,6 +125,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): loss_meter.update(stats["ppo/loss/total"], n=len(rewards)) reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards)) + if self.control.should_epoch_stop or self.control.should_training_stop: + break + if steps_trained == len_dataloader: dataiter = iter(self.dataloader) steps_trained = 0 @@ -172,20 +142,23 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): print(logs) logs["step"] = step self.state.log_history.append(logs) - self.log_callback.on_log(self.args, self.state, None) + self.log_callback.on_log(self.args, self.state, self.control) loss_meter.reset() reward_meter.reset() if (step+1) % self.args.save_steps == 0: # save checkpoint self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}")) + if self.control.should_training_stop: + break + @torch.no_grad() def generate( - self, - inputs: Dict[str, torch.Tensor], - length_sampler: Optional[Callable] = None, - return_prompt: Optional[bool] = True, - **generation_kwargs, + self, + inputs: Dict[str, torch.Tensor], + length_sampler: Optional[Callable] = None, + return_prompt: Optional[bool] = True, + **generation_kwargs ) -> torch.Tensor: r""" Generates model's responses given queries. 
diff --git a/src/llmtuner/tuner/ppo/utils.py b/src/llmtuner/tuner/ppo/utils.py new file mode 100644 index 00000000..55f67be1 --- /dev/null +++ b/src/llmtuner/tuner/ppo/utils.py @@ -0,0 +1,37 @@ +import torch +from typing import Dict, List, Literal, Optional, Tuple +from trl import AutoModelForCausalLMWithValueHead + +from llmtuner.extras.constants import LAYERNORM_NAMES + + +def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None: + if target == "reward": # save default head temporarily + valuehead_state_dict = model.v_head.state_dict() + setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"]) + setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"]) + + model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active + model.v_head.load_state_dict({ + "summary.weight": getattr(model, "{}_head_weight".format(target)), + "summary.bias": getattr(model, "{}_head_bias".format(target)) + }) + + +def cast_layernorm_dtype( + model: AutoModelForCausalLMWithValueHead, + layer_norm_names: List[str] = LAYERNORM_NAMES, + layer_norm_params: Optional[Dict[str, torch.Tensor]] = None +) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: + + layer_norm_state_dict = {} + + for name, param in model.named_parameters(): + if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): + if layer_norm_params is not None: + param.data = layer_norm_params[name] # restore float32 weights + else: + layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability + param.data = param.data.to(torch.float16) + + return model, layer_norm_state_dict diff --git a/src/train_ppo.py b/src/llmtuner/tuner/ppo/workflow.py similarity index 66% rename from src/train_ppo.py rename to src/llmtuner/tuner/ppo/workflow.py index 1de2a1a5..1f63cdaa 100644 --- a/src/train_ppo.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -1,36 +1,30 @@ -# coding=utf-8 -# Implements parameter-efficient PPO training of fine-tuned models. 
-# This code is inspired by: +# Inspired by: # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py import math - -from torch.optim import AdamW -from transformers.optimization import get_scheduler from trl import PPOConfig -from transformers import DataCollatorForSeq2Seq -from utils import ( - PPOPeftTrainer, - LogCallback, - load_pretrained, - prepare_args, - prepare_data, - preprocess_data, - plot_loss -) +from torch.optim import AdamW +from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments +from transformers.optimization import get_scheduler + +from llmtuner.dsets import get_dataset, preprocess_dataset +from llmtuner.extras.callbacks import LogCallback +from llmtuner.extras.ploting import plot_loss +from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments +from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.ppo.trainer import PPOPeftTrainer -def main(): - - # Prepare pretrained model and dataset - model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo") - dataset = prepare_data(model_args, data_args) - model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo") - dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo") - data_collator = DataCollatorForSeq2Seq( - tokenizer=tokenizer, - label_pad_token_id=tokenizer.pad_token_id - ) +def run_ppo( + model_args: ModelArguments, + data_args: DataArguments, + training_args: Seq2SeqTrainingArguments, + finetuning_args: FinetuningArguments +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo") + dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo") + data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=tokenizer.pad_token_id) ppo_config = PPOConfig( model_name=model_args.model_name_or_path, @@ -72,12 +66,3 @@ def main(): ppo_trainer.save_state() # must be after save_model if ppo_trainer.is_world_process_zero() and model_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "reward"]) - - -def _mp_fn(index): - # For xla_spawn (TPUs) - main() - - -if __name__ == "__main__": - main() diff --git a/src/llmtuner/tuner/pt/__init__.py b/src/llmtuner/tuner/pt/__init__.py new file mode 100644 index 00000000..8ce509db --- /dev/null +++ b/src/llmtuner/tuner/pt/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.pt.workflow import run_pt diff --git a/src/train_pt.py b/src/llmtuner/tuner/pt/workflow.py similarity index 59% rename from src/train_pt.py rename to src/llmtuner/tuner/pt/workflow.py index 7461e251..1837e366 100644 --- a/src/train_pt.py +++ b/src/llmtuner/tuner/pt/workflow.py @@ -1,31 +1,28 @@ -# coding=utf-8 -# Implements several parameter-efficient pre-training method. 
-# This code is inspired by -# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py - +# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py import math -from transformers import DataCollatorForSeq2Seq -from utils.other import IGNORE_INDEX +from typing import Optional, List +from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback -from utils import ( - PeftTrainer, - LogCallback, - load_pretrained, - prepare_args, - prepare_data, - preprocess_data, - plot_loss -) +from llmtuner.dsets import get_dataset, preprocess_dataset +from llmtuner.extras.callbacks import LogCallback +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.ploting import plot_loss +from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments +from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.core.trainer import PeftTrainer -def main(): - - # Prepare pretrained model and dataset - model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt") - dataset = prepare_data(model_args, data_args) - model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt") - dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt") +def run_pt( + model_args: ModelArguments, + data_args: DataArguments, + training_args: Seq2SeqTrainingArguments, + finetuning_args: FinetuningArguments, + callbacks: Optional[List[TrainerCallback]] = [LogCallback()] +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt") + dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt") data_collator = DataCollatorForSeq2Seq( tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id @@ -48,7 +45,7 @@ def main(): args=training_args, tokenizer=tokenizer, data_collator=data_collator, - callbacks=[LogCallback()], + callbacks=callbacks, **trainer_kwargs ) @@ -65,21 +62,12 @@ def main(): # Evaluation if training_args.do_eval: metrics = trainer.evaluate(metric_key_prefix="eval") - try: perplexity = math.exp(metrics["eval_loss"]) except OverflowError: perplexity = float("inf") + metrics["perplexity"] = perplexity trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) - - -def _mp_fn(index): - # For xla_spawn (TPUs) - main() - - -if __name__ == "__main__": - main() diff --git a/src/llmtuner/tuner/rm/__init__.py b/src/llmtuner/tuner/rm/__init__.py new file mode 100644 index 00000000..54d3d943 --- /dev/null +++ b/src/llmtuner/tuner/rm/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.rm.workflow import run_rm diff --git a/src/llmtuner/tuner/rm/collator.py b/src/llmtuner/tuner/rm/collator.py new file mode 100644 index 00000000..57d6b54b --- /dev/null +++ b/src/llmtuner/tuner/rm/collator.py @@ -0,0 +1,19 @@ +import torch +from typing import Any, Dict, Sequence +from transformers import DataCollatorWithPadding + + +class PairwiseDataCollatorWithPadding(DataCollatorWithPadding): + r""" + Data collator for pairwise data. + """ + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + r""" + Pads batched data to the longest sequence in the batch. 
+ + We generate 2 * n examples where the first n examples represent chosen examples and + the last n examples represent rejected examples. + """ + features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features] + return super().__call__(features) diff --git a/src/llmtuner/tuner/rm/metric.py b/src/llmtuner/tuner/rm/metric.py new file mode 100644 index 00000000..db9c9243 --- /dev/null +++ b/src/llmtuner/tuner/rm/metric.py @@ -0,0 +1,7 @@ +import numpy as np +from typing import Dict, Sequence, Tuple, Union + + +def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: + preds, _ = eval_preds + return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])} diff --git a/src/llmtuner/tuner/rm/trainer.py b/src/llmtuner/tuner/rm/trainer.py new file mode 100644 index 00000000..199fecf4 --- /dev/null +++ b/src/llmtuner/tuner/rm/trainer.py @@ -0,0 +1,38 @@ +import torch +from typing import Dict, List, Optional, Tuple, Union +from transformers.modeling_utils import PreTrainedModel + +from llmtuner.tuner.core.trainer import PeftTrainer + + +class PairwisePeftTrainer(PeftTrainer): + r""" + Inherits PeftTrainer to compute pairwise loss. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.can_return_loss = True # override property to return eval_loss + + def compute_loss( + self, + model: PreTrainedModel, + inputs: Dict[str, torch.Tensor], + return_outputs: Optional[bool] = False + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: + r""" + Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. + + We use the score on the EOS token to represent the reward of the whole sentence. + + Subclass and override to inject custom behavior. It should not be directly used by external scripts. + + Note that the first element will be removed from the output tuple. + + See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509 + """ + batch_size = inputs["input_ids"].size(0) // 2 + _, _, values = model(**inputs) + r_accept, r_reject = values[:, -1].split(batch_size, dim=0) + loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean() + return (loss, [loss, r_accept, r_reject]) if return_outputs else loss diff --git a/src/train_rm.py b/src/llmtuner/tuner/rm/workflow.py similarity index 64% rename from src/train_rm.py rename to src/llmtuner/tuner/rm/workflow.py index 3d809758..db81500f 100644 --- a/src/train_rm.py +++ b/src/llmtuner/tuner/rm/workflow.py @@ -1,30 +1,28 @@ -# coding=utf-8 -# Implements parameter-efficient training of reward models.
-# This code is inspired by: +# Inspired by: # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py +from transformers import Seq2SeqTrainingArguments -from utils import ( - PairwiseDataCollatorWithPadding, - PairwisePeftTrainer, - LogCallback, - load_pretrained, - prepare_args, - prepare_data, - preprocess_data, - compute_accuracy, - plot_loss -) +from llmtuner.dsets import get_dataset, preprocess_dataset +from llmtuner.extras.callbacks import LogCallback +from llmtuner.extras.ploting import plot_loss +from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments +from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.rm.metric import compute_accuracy +from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding +from llmtuner.tuner.rm.trainer import PairwisePeftTrainer -def main(): - - # Prepare pretrained model and dataset - model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm") - dataset = prepare_data(model_args, data_args) - model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm") - dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm") +def run_rm( + model_args: ModelArguments, + data_args: DataArguments, + training_args: Seq2SeqTrainingArguments, + finetuning_args: FinetuningArguments +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm") + dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") data_collator = PairwiseDataCollatorWithPadding(tokenizer) training_args.remove_unused_columns = False # important for pairwise dataset @@ -66,12 +64,3 @@ def main(): metrics = trainer.evaluate(metric_key_prefix="eval") trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) - - -def _mp_fn(index): - # For xla_spawn (TPUs) - main() - - -if __name__ == "__main__": - main() diff --git a/src/llmtuner/tuner/sft/__init__.py b/src/llmtuner/tuner/sft/__init__.py new file mode 100644 index 00000000..493dd1a7 --- /dev/null +++ b/src/llmtuner/tuner/sft/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.sft.workflow import run_sft diff --git a/src/llmtuner/tuner/sft/metric.py b/src/llmtuner/tuner/sft/metric.py new file mode 100644 index 00000000..3f13f3c7 --- /dev/null +++ b/src/llmtuner/tuner/sft/metric.py @@ -0,0 +1,51 @@ +import numpy as np +from dataclasses import dataclass +from typing import Dict, Sequence, Tuple, Union +from transformers.tokenization_utils import PreTrainedTokenizer + +import jieba +from rouge_chinese import Rouge +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction + +from llmtuner.extras.constants import IGNORE_INDEX + + +@dataclass +class ComputeMetrics: + r""" + Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. + """ + + tokenizer: PreTrainedTokenizer + + def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: + r""" + Uses the model predictions to compute metrics. 
+ """ + preds, labels = eval_preds + score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} + + preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) + labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) + + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) + + for pred, label in zip(decoded_preds, decoded_labels): + hypothesis = list(jieba.cut(pred)) + reference = list(jieba.cut(label)) + + if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: + result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} + else: + rouge = Rouge() + scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v["f"] * 100, 4)) + + bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) + score_dict["bleu-4"].append(round(bleu_score * 100, 4)) + + return {k: float(np.mean(v)) for k, v in score_dict.items()} diff --git a/src/utils/seq2seq.py b/src/llmtuner/tuner/sft/trainer.py similarity index 51% rename from src/utils/seq2seq.py rename to src/llmtuner/tuner/sft/trainer.py index cfa637d7..88075705 100644 --- a/src/utils/seq2seq.py +++ b/src/llmtuner/tuner/sft/trainer.py @@ -3,65 +3,17 @@ import json import torch import numpy as np import torch.nn as nn -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union - +from typing import Any, Dict, List, Optional, Tuple, Union from transformers.trainer import PredictionOutput -from transformers.tokenization_utils import PreTrainedTokenizer -import jieba -from rouge_chinese import Rouge -from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction - -from .peft_trainer import PeftTrainer - -from .other import get_logger, IGNORE_INDEX +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.logging import get_logger +from llmtuner.tuner.core.trainer import PeftTrainer logger = get_logger(__name__) -@dataclass -class ComputeMetrics: - r""" - Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. - """ - - tokenizer: PreTrainedTokenizer - - def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: - r""" - Uses the model predictions to compute metrics. 
- """ - preds, labels = eval_preds - score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} - - preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) - labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) - - decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) - decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) - - for pred, label in zip(decoded_preds, decoded_labels): - hypothesis = list(jieba.cut(pred)) - reference = list(jieba.cut(label)) - - if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: - result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} - else: - rouge = Rouge() - scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) - result = scores[0] - - for k, v in result.items(): - score_dict[k].append(round(v["f"] * 100, 4)) - - bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) - score_dict["bleu-4"].append(round(bleu_score * 100, 4)) - - return {k: float(np.mean(v)) for k, v in score_dict.items()} - - class Seq2SeqPeftTrainer(PeftTrainer): r""" Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE. @@ -80,7 +32,10 @@ class Seq2SeqPeftTrainer(PeftTrainer): Subclass and override to inject custom behavior. """ prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) - inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1) + if self.tokenizer.padding_side == "right": # pads the labels to the same length as the inputs + inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1) + else: + inputs["labels"] = torch.cat((torch.zeros_like(inputs["input_ids"])[:, label_len:], inputs["labels"]), dim=-1) loss, generated_tokens, labels = super().prediction_step( model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys ) @@ -89,8 +44,8 @@ class Seq2SeqPeftTrainer(PeftTrainer): return (loss, generated_tokens, labels) def save_predictions( - self, - predict_results: PredictionOutput + self, + predict_results: PredictionOutput ) -> None: r""" Saves model predictions to `output_dir`. diff --git a/src/train_sft.py b/src/llmtuner/tuner/sft/workflow.py similarity index 69% rename from src/train_sft.py rename to src/llmtuner/tuner/sft/workflow.py index 49c53cb8..08889796 100644 --- a/src/train_sft.py +++ b/src/llmtuner/tuner/sft/workflow.py @@ -1,31 +1,29 @@ -# coding=utf-8 -# Implements several parameter-efficient supervised fine-tuning method. 
-# This code is inspired by -# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py +# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py + +from typing import Optional, List +from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback + +from llmtuner.dsets import get_dataset, preprocess_dataset +from llmtuner.extras.callbacks import LogCallback +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.misc import get_logits_processor +from llmtuner.extras.ploting import plot_loss +from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments +from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.sft.metric import ComputeMetrics +from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer -from transformers import DataCollatorForSeq2Seq -from utils.other import IGNORE_INDEX -from utils import ( - Seq2SeqPeftTrainer, - ComputeMetrics, - LogCallback, - load_pretrained, - prepare_args, - prepare_data, - preprocess_data, - get_logits_processor, - plot_loss -) - - -def main(): - - # Prepare pretrained model and dataset - model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft") - dataset = prepare_data(model_args, data_args) - model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft") - dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft") +def run_sft( + model_args: ModelArguments, + data_args: DataArguments, + training_args: Seq2SeqTrainingArguments, + finetuning_args: FinetuningArguments, + callbacks: Optional[List[TrainerCallback]] = [LogCallback()] +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") + dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft") data_collator = DataCollatorForSeq2Seq( tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id @@ -54,7 +52,7 @@ def main(): args=training_args, tokenizer=tokenizer, data_collator=data_collator, - callbacks=[LogCallback()], + callbacks=callbacks, compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, **trainer_kwargs ) @@ -94,12 +92,3 @@ def main(): trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) trainer.save_predictions(predict_results) - - -def _mp_fn(index): - # For xla_spawn (TPUs) - main() - - -if __name__ == "__main__": - main() diff --git a/src/train_bash.py b/src/train_bash.py new file mode 100644 index 00000000..291c3cf0 --- /dev/null +++ b/src/train_bash.py @@ -0,0 +1,23 @@ +from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo + + +def main(): + model_args, data_args, training_args, finetuning_args, general_args = get_train_args() + + if general_args.stage == "pt": + run_pt(model_args, data_args, training_args, finetuning_args) + elif general_args.stage == "sft": + run_sft(model_args, data_args, training_args, finetuning_args) + elif general_args.stage == "rm": + run_rm(model_args, data_args, training_args, finetuning_args) + elif general_args.stage == "ppo": + run_ppo(model_args, data_args, training_args, finetuning_args) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == 
"__main__": + main() diff --git a/src/utils/__init__.py b/src/utils/__init__.py deleted file mode 100644 index 977b58c3..00000000 --- a/src/utils/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .common import ( - load_pretrained, - prepare_args, - prepare_infer_args, - prepare_data, - preprocess_data -) - -from .peft_trainer import PeftTrainer, LogCallback - -from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer -from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer, compute_accuracy -from .ppo import PPOPeftTrainer - -from .template import Template - -from .other import get_logits_processor, plot_loss diff --git a/src/utils/common.py b/src/utils/common.py deleted file mode 100644 index 7f2663fd..00000000 --- a/src/utils/common.py +++ /dev/null @@ -1,619 +0,0 @@ -import os -import sys -import torch -import hashlib -from itertools import chain -from typing import List, Literal, Optional, Tuple - -import transformers -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - HfArgumentParser, - Seq2SeqTrainingArguments, - BitsAndBytesConfig -) -from transformers.utils import check_min_version -from transformers.utils.versions import require_version -from transformers.modeling_utils import PreTrainedModel -from transformers.tokenization_utils import PreTrainedTokenizer - -import datasets -from datasets import Dataset, concatenate_datasets, load_dataset - -from peft import ( - PeftModel, - TaskType, - LoraConfig, - get_peft_model -) - -from peft.utils import CONFIG_NAME, WEIGHTS_NAME - -from trl import AutoModelForCausalLMWithValueHead - -from .config import ( - ModelArguments, - DataTrainingArguments, - FinetuningArguments, - GeneratingArguments -) - -from .template import Template - -from .other import ( - get_logger, - load_trainable_params, - load_valuehead_params, - print_trainable_params, - prepare_model_for_training, - IGNORE_INDEX -) - -check_min_version("4.29.1") -require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") -require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0") -require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") -require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4") - - -logger = get_logger(__name__) - - -def _init_adapter( - model: PreTrainedModel, - model_args: ModelArguments, - finetuning_args: FinetuningArguments, - is_trainable: bool, - is_mergeable: bool -) -> PreTrainedModel: - r""" - Initializes the adapters. - - Support full-parameter, freeze and LoRA training. - - Note that the trainable parameters must be cast to float32. - """ - - if finetuning_args.finetuning_type == "none" and is_trainable: - raise ValueError("You cannot use finetuning_type=none while training.") - - if finetuning_args.finetuning_type == "full": - logger.info("Fine-tuning method: Full") - model = model.float() - - if finetuning_args.finetuning_type == "freeze": - logger.info("Fine-tuning method: Freeze") - - for name, param in model.named_parameters(): - if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers): - param.requires_grad_(False) - else: - param.data = param.data.to(torch.float32) - - if model_args.checkpoint_dir is not None: - assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded." 
-
-
-def load_pretrained(
-    model_args: ModelArguments,
-    finetuning_args: FinetuningArguments,
-    is_trainable: Optional[bool] = False,
-    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
-    r"""
-    Loads pretrained model and tokenizer.
-
-    Supports both training and inference.
-    """
-    if (not is_trainable) and model_args.checkpoint_dir is None:
-        logger.warning("Checkpoint is not found at evaluation, load the original model.")
-        finetuning_args = FinetuningArguments(finetuning_type="none")
-
-    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
-        "RM and PPO training can only be performed with the LoRA method."
-
-    config_kwargs = {
-        "trust_remote_code": True,
-        "cache_dir": model_args.cache_dir,
-        "revision": model_args.model_revision,
-        "use_auth_token": True if model_args.use_auth_token else None,
-    }
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_args.model_name_or_path,
-        use_fast=model_args.use_fast_tokenizer,
-        padding_side=model_args.padding_side,
-        **config_kwargs
-    )
-    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
-        tokenizer.pad_token_id = 0 # set as the <unk> token
-
-    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
-    is_mergeable = True
-
-    # Quantization configurations (using bitsandbytes library).
-    if model_args.quantization_bit is not None:
-        if model_args.quantization_bit == 8:
-            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-            config_kwargs["load_in_8bit"] = True
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_8bit=True,
-                llm_int8_threshold=6.0
-            )
-
-        elif model_args.quantization_bit == 4:
-            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
-            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
-            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
-            config_kwargs["load_in_4bit"] = True
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=model_args.compute_dtype,
-                bnb_4bit_use_double_quant=model_args.double_quantization,
-                bnb_4bit_quant_type=model_args.quantization_type
-            )
-
-        is_mergeable = False
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
-        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
-
-    if not is_trainable: # `device_map=auto` should be used for inference only
-        config_kwargs["device_map"] = "auto"
-
-    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
-        model_to_load = model_args.checkpoint_dir[0]
-    else:
-        model_to_load = model_args.model_name_or_path
-
-    # Load and prepare pretrained models (without valuehead).
-    model = AutoModelForCausalLM.from_pretrained(
-        model_to_load,
-        config=config,
-        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
-        low_cpu_mem_usage=True,
-        **config_kwargs
-    )
-
-    # Register auto class to save the custom code files.
-    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
-        config.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
-        tokenizer.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
-        model.__class__.register_for_auto_class()
-
-    # Initialize adapters
-    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
-    model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
-
-    if stage == "rm" or stage == "ppo": # add value head
-        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-
-    if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
-        logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
-        if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
-            model.v_head.load_state_dict({
-                "summary.weight": getattr(model, "reward_head_weight"),
-                "summary.bias": getattr(model, "reward_head_bias")
-            })
-
-    if stage == "ppo": # load reward model
-        assert is_trainable, "PPO stage cannot be performed at evaluation."
-        assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
-        logger.info("Load reward model from {}".format(model_args.reward_model))
-        model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
-        assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
-
-    if not is_trainable:
-        model.requires_grad_(False) # fix all model params
-        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
-
-    print_trainable_params(model)
-
-    return model, tokenizer
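For orientation, the 4-bit path above just assembles standard bitsandbytes kwargs. A roughly equivalent standalone load with the same defaults (nf4 quant type, double quantization); the model path is a placeholder:

```python
# Sketch of an equivalent 4-bit load using the kwargs assembled in load_pretrained.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "path_to_your_model",  # placeholder
    load_in_4bit=True,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    ),
    device_map={"": 0}  # pin the quantized model to one device
)
```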
-
-
-def prepare_args(
-    stage: Literal["pt", "sft", "rm", "ppo"]
-) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
-
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
-
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
-    else:
-        model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
-
-    # Setup logging
-    if training_args.should_log:
-        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
-        transformers.utils.logging.set_verbosity_info()
-
-    log_level = training_args.get_process_log_level()
-    datasets.utils.logging.set_verbosity(log_level)
-    transformers.utils.logging.set_verbosity(log_level)
-    transformers.utils.logging.enable_default_handler()
-    transformers.utils.logging.enable_explicit_format()
-
-    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
-    data_args.init_for_training()
-
-    assert stage == "sft" or (not training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
-
-    assert not (training_args.do_train and training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True while training."
-
-    assert (not training_args.do_predict) or training_args.predict_with_generate, \
-        "Please enable `predict_with_generate` to save model predictions."
-
-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
-
-    if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
-
-    if model_args.quantization_bit is not None and (not training_args.do_train):
-        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
-
-    if training_args.do_train and (not training_args.fp16):
-        logger.warning("We recommend enabling fp16 mixed precision training.")
-
-    if data_args.prompt_template == "default":
-        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
-
-    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
-        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
-        training_args.ddp_find_unused_parameters = False
-
-    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
-
-    if model_args.quantization_bit is not None:
-        if training_args.fp16:
-            model_args.compute_dtype = torch.float16
-        elif training_args.bf16:
-            model_args.compute_dtype = torch.bfloat16
-        else:
-            model_args.compute_dtype = torch.float32
-
-    # Log on each process the small summary:
-    logger.info(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
-        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
-    logger.info(f"Training/evaluation parameters {training_args}")
-
-    # Set seed before initializing model.
-    transformers.set_seed(training_args.seed)
-
-    return model_args, data_args, training_args, finetuning_args
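Because the parser falls back to `parse_json_file` when exactly one `.json` path is passed, a whole run can be configured from a single file (the old per-stage scripts and the new `train_bash.py` both behave this way; the `stage` key is only needed by the new entry point). A hypothetical file, written from Python:

```python
# Writes a hypothetical train_args.json for `python src/train_bash.py train_args.json`;
# the keys mirror the dataclass fields parsed above, values are placeholders.
import json

train_args = {
    "stage": "sft",
    "model_name_or_path": "path_to_your_model",
    "do_train": True,
    "dataset": "alpaca_gpt4_en",
    "finetuning_type": "lora",
    "output_dir": "path_to_sft_checkpoint",
    "fp16": True
}
with open("train_args.json", "w", encoding="utf-8") as f:
    json.dump(train_args, f, indent=2)
```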
-
-
-def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]:
-
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments))
-
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
-    else:
-        model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
-
-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
-
-    if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
-
-    if data_args.prompt_template == "default":
-        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
-
-    return model_args, data_args, finetuning_args, generating_args
-
-
-def prepare_data(
-    model_args: ModelArguments,
-    data_args: DataTrainingArguments
-) -> Dataset:
-
-    def checksum(file_path, hash):
-        with open(file_path, "rb") as datafile:
-            binary_data = datafile.read()
-        sha1 = hashlib.sha1(binary_data).hexdigest()
-        if sha1 != hash:
-            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
-
-    ext2type = {
-        "csv": "csv",
-        "json": "json",
-        "jsonl": "json",
-        "txt": "text"
-    }
-
-    max_samples = data_args.max_samples
-    all_datasets: List[Dataset] = [] # support multiple datasets
-
-    for dataset_attr in data_args.dataset_list:
-
-        logger.info("Loading dataset {}...".format(dataset_attr))
-
-        if dataset_attr.load_from == "hf_hub":
-            data_path = dataset_attr.dataset_name
-            data_files = None
-        elif dataset_attr.load_from == "script":
-            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
-            data_files = None
-        elif dataset_attr.load_from == "file":
-            data_path = None
-            data_files: List[str] = []
-
-            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
-                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
-                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
-
-                    if data_path is None:
-                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
-                    else:
-                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
-            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
-                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
-                data_path = ext2type.get(data_files[0].split(".")[-1], None)
-            else:
-                raise ValueError("File not found.")
-
-            assert data_path, "File extension must be txt, csv, json or jsonl."
-
-            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
-                checksum(data_files[0], dataset_attr.dataset_sha1)
-            else:
-                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
-        else:
-            raise NotImplementedError
-
-        raw_datasets = load_dataset(
-            data_path,
-            data_files=data_files,
-            cache_dir=model_args.cache_dir,
-            use_auth_token=True if model_args.use_auth_token else None
-        )
-        dataset = raw_datasets[data_args.split]
-
-        if max_samples is not None:
-            max_samples_temp = min(len(dataset), max_samples)
-            dataset = dataset.select(range(max_samples_temp))
-
-        dummy_data = [None] * len(dataset)
-        prefix_data = [dataset_attr.source_prefix] * len(dataset)
-        for column_name, target_name in [
-            ("prompt_column", "prompt"),
-            ("query_column", "query"),
-            ("response_column", "response"),
-            ("history_column", "history")
-        ]: # every dataset will have 4 columns same as each other
-            if getattr(dataset_attr, column_name) != target_name:
-                if getattr(dataset_attr, column_name):
-                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
-                else: # None or empty string
-                    dataset = dataset.add_column(target_name, dummy_data)
-        dataset = dataset.add_column("prefix", prefix_data)
-        all_datasets.append(dataset)
-
-    if len(data_args.dataset_list) == 1:
-        all_datasets = all_datasets[0]
-    else:
-        all_datasets = concatenate_datasets(all_datasets)
-
-    return all_datasets
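`prepare_data` resolves each dataset name through `data/dataset_info.json`, where an entry supplies either `hf_hub_url`, `script_url`, or `file_name`, plus an optional column mapping. A hypothetical entry, shown as a Python dict for readability (the real file is JSON; the name and hash are invented):

```python
# Hypothetical dataset_info.json entry; the keys match the lookups in init_for_training.
dataset_info = {
    "my_dataset": {                                             # invented dataset name
        "file_name": "my_dataset.json",                         # looked up under data_args.dataset_dir
        "file_sha1": "0123456789abcdef0123456789abcdef01234567",  # placeholder SHA-1
        "columns": {                                            # remap source fields onto the canonical names
            "prompt": "instruction",
            "query": "input",
            "response": "output",
            "history": "history"
        }
    }
}
```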
-
-
-def preprocess_data(
-    dataset: Dataset,
-    tokenizer: PreTrainedTokenizer,
-    data_args: DataTrainingArguments,
-    training_args: Seq2SeqTrainingArguments,
-    stage: Literal["pt", "sft", "rm", "ppo"]
-) -> Dataset:
-
-    column_names = list(dataset.column_names)
-    prompt_template = Template(data_args.prompt_template)
-
-    # support question with a single answer or multiple answers
-    def get_dialog(examples):
-        for i in range(len(examples["prompt"])):
-            if examples["prompt"][i] and examples["response"][i]:
-                query, answer = examples["prompt"][i], examples["response"][i]
-                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
-                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
-                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
-                yield dialog
-
-    def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `<bos> X1 X2 X3 ...` (without `<eos>`)
-        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
-        concatenated_ids = list(chain(*text_ids))
-        total_length = len(concatenated_ids)
-        block_size = data_args.max_source_length - 1
-        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-        total_length = (total_length // block_size) * block_size
-        # split by chunks of max_source_length
-        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
-                  for i in range(0, total_length, block_size)]
-        return {
-            "input_ids": result,
-            "labels": result.copy()
-        }
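With the default `max_source_length=512`, the grouping above reserves one slot per block for `<bos>`, so a 2049-token corpus yields four 512-id blocks and the 5-token remainder is dropped. A quick check of the arithmetic:

```python
# Toy check of the block-grouping arithmetic in preprocess_pretrain_dataset.
max_source_length = 512
block_size = max_source_length - 1         # one slot is reserved for <bos>
total_length = 2049                        # pretend corpus length in tokens
usable = (total_length // block_size) * block_size

blocks = [(i, i + block_size) for i in range(0, usable, block_size)]
print(len(blocks))                         # 4 blocks, each 512 ids after prepending <bos>
print(total_length - usable)               # 5 trailing tokens are dropped
```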
query + "\n" + examples["query"][i] if examples["query"][i] else query - prefix = examples["prefix"][i] if examples["prefix"][i] else "" - dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix) - yield dialog - - def preprocess_pretrain_dataset(examples): - # build grouped texts with format ` X1 X2 X3 ...` (without ) - text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"] - concatenated_ids = list(chain(*text_ids)) - total_length = len(concatenated_ids) - block_size = data_args.max_source_length - 1 - # we drop the small remainder, and if the total_length < block_size, we exclude this batch - total_length = (total_length // block_size) * block_size - # split by chunks of max_source_length - result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size] - for i in range(0, total_length, block_size)] - return { - "input_ids": result, - "labels": result.copy() - } - - def preprocess_supervised_dataset(examples): - # build inputs with format ` X Y ` and labels with format ` ... Y ` - # for input with history, we build multiple input-label pairs just like: - # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112 - model_inputs = {"input_ids": [], "labels": []} - max_length = data_args.max_source_length + data_args.max_target_length - - for dialog in get_dialog(examples): - input_ids, labels = [], [] - - for i in range(len(dialog) // 2): - source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0)) - target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False) - - if len(source_ids) > data_args.max_source_length: - source_ids = source_ids[:data_args.max_source_length] - if len(target_ids) > data_args.max_target_length - 1: # eos token - target_ids = target_ids[:data_args.max_target_length - 1] - - if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length: - break - - input_ids += source_ids + target_ids + [tokenizer.eos_token_id] - labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id] - - model_inputs["input_ids"].append(input_ids) - model_inputs["labels"].append(labels) - - return model_inputs - - def preprocess_unsupervised_dataset(examples): - # build inputs with format ` X` and labels with format ` Y` - model_inputs = {"input_ids": [], "labels": []} - - for dialog in get_dialog(examples): - prompt, answer = "".join(dialog[:-1]), dialog[-1] - - source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) - target_ids = tokenizer.encode(text=answer, add_special_tokens=True) - - if len(source_ids) > data_args.max_source_length: - source_ids = source_ids[:data_args.max_source_length] - if len(target_ids) > data_args.max_target_length: - target_ids = target_ids[:data_args.max_target_length] - - model_inputs["input_ids"].append(source_ids) - model_inputs["labels"].append(target_ids) - - return model_inputs - - def preprocess_pairwise_dataset(examples): - # build input pairs with format ` X Y1 ` and ` X Y2 ` - model_inputs = {"accept_ids": [], "reject_ids": []} - for dialog in get_dialog(examples): - prompt, answer = "".join(dialog[:-1]), dialog[-1] - - source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) - accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False) - reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False) - - if len(source_ids) > data_args.max_source_length: - source_ids = source_ids[:data_args.max_source_length] - if len(accept_ids) > 
-
-    def preprocess_pairwise_dataset(examples):
-        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
-        model_inputs = {"accept_ids": [], "reject_ids": []}
-        for dialog in get_dialog(examples):
-            prompt, answer = "".join(dialog[:-1]), dialog[-1]
-
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
-            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
-            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
-
-            if len(source_ids) > data_args.max_source_length:
-                source_ids = source_ids[:data_args.max_source_length]
-            if len(accept_ids) > data_args.max_target_length - 1: # eos token
-                accept_ids = accept_ids[:data_args.max_target_length - 1]
-            if len(reject_ids) > data_args.max_target_length - 1: # eos token
-                reject_ids = reject_ids[:data_args.max_target_length - 1]
-
-            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
-            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
-
-            model_inputs["accept_ids"].append(accept_ids)
-            model_inputs["reject_ids"].append(reject_ids)
-        return model_inputs
-
-    def print_supervised_dataset_example(example):
-        print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
-        print("label_ids:\n{}".format(example["labels"]))
-        print("labels:\n{}".format(
-            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
-                             skip_special_tokens=False)
-        ))
-
-    def print_pairwise_dataset_example(example):
-        print("accept_ids:\n{}".format(example["accept_ids"]))
-        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
-        print("reject_ids:\n{}".format(example["reject_ids"]))
-        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
-
-    def print_unsupervised_dataset_example(example):
-        print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
-
-    if stage == "pt":
-        preprocess_function = preprocess_pretrain_dataset
-    elif stage == "sft":
-        preprocess_function = preprocess_unsupervised_dataset \
-            if training_args.predict_with_generate else preprocess_supervised_dataset
-    elif stage == "rm":
-        preprocess_function = preprocess_pairwise_dataset
-    elif stage == "ppo":
-        preprocess_function = preprocess_unsupervised_dataset
-
-    with training_args.main_process_first(desc="dataset map pre-processing"):
-        dataset = dataset.map(
-            preprocess_function,
-            batched=True,
-            num_proc=data_args.preprocessing_num_workers,
-            remove_columns=column_names,
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on dataset"
-        )
-
-    if stage == "pt":
-        print_unsupervised_dataset_example(dataset[0])
-    elif stage == "sft":
-        print_supervised_dataset_example(dataset[0])
-    elif stage == "rm":
-        print_pairwise_dataset_example(dataset[0])
-    elif stage == "ppo":
-        print_unsupervised_dataset_example(dataset[0])
-
-    return dataset
- """ - model_name_or_path: str = field( - metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} - ) - cache_dir: Optional[str] = field( - default=None, - metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} - ) - use_fast_tokenizer: Optional[bool] = field( - default=False, - metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} - ) - use_auth_token: Optional[bool] = field( - default=False, - metadata={"help": "Will use the token generated when running `huggingface-cli login`."} - ) - model_revision: Optional[str] = field( - default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} - ) - padding_side: Optional[Literal["left", "right"]] = field( - default="left", - metadata={"help": "The side on which the model should have padding applied."} - ) - quantization_bit: Optional[int] = field( - default=None, - metadata={"help": "The number of bits to quantize the model."} - ) - quantization_type: Optional[Literal["fp4", "nf4"]] = field( - default="nf4", - metadata={"help": "Quantization data type to use in int4 training."} - ) - double_quantization: Optional[bool] = field( - default=True, - metadata={"help": "Whether to use double quantization in int4 training or not."} - ) - compute_dtype: Optional[torch.dtype] = field( - default=None, - metadata={"help": "Used in quantization configs. Do not specify this argument manually."} - ) - checkpoint_dir: Optional[str] = field( - default=None, - metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} - ) - reward_model: Optional[str] = field( - default=None, - metadata={"help": "Path to the directory containing the checkpoints of the reward model."} - ) - resume_lora_training: Optional[bool] = field( - default=True, - metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} - ) - plot_loss: Optional[bool] = field( - default=False, - metadata={"help": "Whether to plot the training loss after fine-tuning or not."} - ) - - def __post_init__(self): - if self.checkpoint_dir is not None: # support merging multiple lora weights - self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] - - if self.quantization_bit is not None: - assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and evaluation. - """ - dataset: Optional[str] = field( - default="alpaca_zh", - metadata={"help": "The name of provided dataset(s) to use. 
-
-
-@dataclass
-class DataTrainingArguments:
-    """
-    Arguments pertaining to what data we are going to input our model for training and evaluation.
-    """
-    dataset: Optional[str] = field(
-        default="alpaca_zh",
-        metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
-    )
-    dataset_dir: Optional[str] = field(
-        default="data",
-        metadata={"help": "The name of the folder containing datasets."}
-    )
-    split: Optional[str] = field(
-        default="train",
-        metadata={"help": "Which dataset split to use for training and evaluation."}
-    )
-    overwrite_cache: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Overwrite the cached training and evaluation sets."}
-    )
-    preprocessing_num_workers: Optional[int] = field(
-        default=None,
-        metadata={"help": "The number of processes to use for the preprocessing."}
-    )
-    max_source_length: Optional[int] = field(
-        default=512,
-        metadata={"help": "The maximum total input sequence length after tokenization."}
-    )
-    max_target_length: Optional[int] = field(
-        default=512,
-        metadata={"help": "The maximum total output sequence length after tokenization."}
-    )
-    max_samples: Optional[int] = field(
-        default=None,
-        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
-    )
-    eval_num_beams: Optional[int] = field(
-        default=None,
-        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
-    )
-    ignore_pad_token_for_loss: Optional[bool] = field(
-        default=True,
-        metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
-    )
-    source_prefix: Optional[str] = field(
-        default=None,
-        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
-    )
-    dev_ratio: Optional[float] = field(
-        default=0,
-        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
-    )
-    prompt_template: Optional[str] = field(
-        default="default",
-        metadata={"help": "Which template to use for constructing prompts in training and inference."}
-    )
-
-    def init_for_training(self): # support mixing multiple datasets
-        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
-        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
-            dataset_info = json.load(f)
-
-        if self.source_prefix is not None:
-            prefix_list = self.source_prefix.split("|")
-            prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
-            assert len(prefix_list) == len(dataset_names), "The number of prefixes should be identical to the number of datasets or 1."
-        else:
-            prefix_list = [None] * len(dataset_names)
-
-        self.dataset_list: List[DatasetAttr] = []
-        for i, name in enumerate(dataset_names):
-            if name not in dataset_info:
-                raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
-
-            if "hf_hub_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
-            elif "script_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
-            else:
-                dataset_attr = DatasetAttr(
-                    "file",
-                    dataset_name=dataset_info[name]["file_name"],
-                    dataset_sha1=dataset_info[name].get("file_sha1", None)
-                )
-
-            dataset_attr.source_prefix = prefix_list[i]
-
-            if "columns" in dataset_info[name]:
-                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
-                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
-                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
-                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
-
-            self.dataset_list.append(dataset_attr)
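A single `--source_prefix` is broadcast to every dataset, while `|`-separated prefixes are matched one-to-one. A quick check of the broadcasting logic above (names and prefix are illustrative):

```python
# Toy check of prefix broadcasting in init_for_training.
dataset_names = ["alpaca_gpt4_en", "wiki_demo"]
source_prefix = "You are a helpful assistant."       # single prefix, no "|"

prefix_list = source_prefix.split("|")
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
print(prefix_list)  # the same prefix is paired with both datasets
```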
-
-
-@dataclass
-class FinetuningArguments:
-    """
-    Arguments pertaining to which techniques we are going to fine-tune with.
-    """
-    finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
-        default="lora",
-        metadata={"help": "Which fine-tuning method to use."}
-    )
-    num_hidden_layers: Optional[int] = field(
-        default=32,
-        metadata={"help": "Number of decoder blocks in the model. \
-                  LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
-                  BLOOM choices: [\"24\", \"30\", \"70\"], \
-                  Falcon choices: [\"32\", \"60\"], \
-                  Baichuan choices: [\"32\"]"}
-    )
-    num_layer_trainable: Optional[int] = field(
-        default=3,
-        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
-    )
-    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
-        default="mlp",
-        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
-                  LLaMA choices: [\"mlp\", \"self_attn\"], \
-                  BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
-                  Baichuan choices: [\"mlp\", \"self_attn\"]"}
-    )
-    lora_rank: Optional[int] = field(
-        default=8,
-        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
-    )
-    lora_alpha: Optional[float] = field(
-        default=32.0,
-        metadata={"help": "The scale factor for LoRA fine-tuning (similar to the learning rate)."}
-    )
-    lora_dropout: Optional[float] = field(
-        default=0.1,
-        metadata={"help": "Dropout rate for the LoRA fine-tuning."}
-    )
-    lora_target: Optional[str] = field(
-        default="q_proj,v_proj",
-        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
-                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
-                  BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
-                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
-    )
-
-    def __post_init__(self):
-        if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
-            self.lora_target = [target.strip() for target in self.lora_target.split(",")]
-
-        if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
-            trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
-        else: # fine-tuning the first n layers if num_layer_trainable < 0
-            trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
-
-        self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
-
-        assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
-
-    def save_to_json(self, json_path: str):
-        """Saves the content of this instance in JSON format inside `json_path`."""
-        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
-        with open(json_path, "w", encoding="utf-8") as f:
-            f.write(json_string)
-
-    @classmethod
-    def load_from_json(cls, json_path: str):
-        """Creates an instance from the content of `json_path`."""
-        with open(json_path, "r", encoding="utf-8") as f:
-            text = f.read()
-        return cls(**json.loads(text))
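With the defaults (`num_hidden_layers=32`, `num_layer_trainable=3`, `name_module_trainable="mlp"`), the freeze method trains only the MLPs of the last three decoder blocks. Reproducing the selection from `__post_init__`:

```python
# Reproduces the trainable-layer selection performed in FinetuningArguments.__post_init__.
num_hidden_layers, num_layer_trainable, name_module_trainable = 32, 3, "mlp"

trainable_layer_ids = [num_hidden_layers - k - 1 for k in range(num_layer_trainable)]
trainable_layers = ["{:d}.{}".format(idx, name_module_trainable) for idx in trainable_layer_ids]
print(trainable_layers)  # ['31.mlp', '30.mlp', '29.mlp']
```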
-
-
-@dataclass
-class GeneratingArguments:
-    """
-    Arguments pertaining to the decoding parameters.
-    """
-    do_sample: Optional[bool] = field(
-        default=True,
-        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
-    )
-    temperature: Optional[float] = field(
-        default=0.95,
-        metadata={"help": "The value used to modulate the next token probabilities."}
-    )
-    top_p: Optional[float] = field(
-        default=0.7,
-        metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
-    )
-    top_k: Optional[int] = field(
-        default=50,
-        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
-    )
-    num_beams: Optional[int] = field(
-        default=1,
-        metadata={"help": "Number of beams for beam search. 1 means no beam search."}
-    )
-    max_length: Optional[int] = field(
-        default=None,
-        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
-    )
-    max_new_tokens: Optional[int] = field(
-        default=512,
-        metadata={"help": "The maximum number of tokens to generate, ignoring the number of tokens in the prompt."}
-    )
-    repetition_penalty: Optional[float] = field(
-        default=1.0,
-        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
-    )
-    length_penalty: Optional[float] = field(
-        default=1.0,
-        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
-    )
-
-    def to_dict(self) -> Dict[str, Any]:
-        args = asdict(self)
-        if args.get("max_new_tokens", None):
-            args.pop("max_length", None)
-        return args
diff --git a/src/utils/other.py b/src/utils/other.py
deleted file mode 100644
index ce780ab9..00000000
--- a/src/utils/other.py
+++ /dev/null
@@ -1,197 +0,0 @@
-import os
-import sys
-import json
-import torch
-import logging
-from typing import Dict, List, Optional
-
-from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
-from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint
-from transformers.generation.utils import LogitsProcessorList
-from transformers.generation.logits_process import LogitsProcessor
-
-
-IGNORE_INDEX = -100
-VALUE_HEAD_FILE_NAME = "value_head.bin"
-FINETUNING_ARGS_NAME = "finetuning_args.json"
-
-
-def get_logger(name: str) -> logging.Logger:
-    return logging.getLogger(name)
-
-
-logging.basicConfig(
-    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-    datefmt="%m/%d/%Y %H:%M:%S",
-    level=logging.INFO,
-    handlers=[logging.StreamHandler(sys.stdout)]
-)
-
-
-logger = get_logger(__name__)
-
-
-class AverageMeter:
-    r"""
-    Computes and stores the average and current value.
-    """
-    def __init__(self):
-        self.reset()
-
-    def reset(self):
-        self.val = 0
-        self.avg = 0
-        self.sum = 0
-        self.count = 0
-
-    def update(self, val, n=1):
-        self.val = val
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
-
-# Avoid runtime error in model.generate(do_sample=True).
-class InvalidScoreLogitsProcessor(LogitsProcessor):
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if torch.isnan(scores).any() or torch.isinf(scores).any():
-            scores.zero_()
-            scores[..., 0] = 1.0
-        return scores
-
-
-def get_logits_processor() -> LogitsProcessorList:
-    logits_processor = LogitsProcessorList()
-    logits_processor.append(InvalidScoreLogitsProcessor())
-    return logits_processor
-
-
-# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
-# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
-def prepare_model_for_training(
-    model: PreTrainedModel,
-    finetuning_type: str,
-    output_embedding_layer_name: Optional[str] = "lm_head",
-    use_gradient_checkpointing: Optional[bool] = True,
-    layer_norm_names: Optional[List[str]] = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
-) -> PreTrainedModel:
-
-    for name, param in model.named_parameters():
-        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            param.data = param.data.to(torch.float32)
-
-    if use_gradient_checkpointing:
-        if hasattr(model, "enable_input_require_grads"):
-            model.enable_input_require_grads()
-        else:
-            def make_inputs_require_grad(module, input, output):
-                output.requires_grad_(True)
-            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
-        model.gradient_checkpointing_enable()
-        model.config.use_cache = False # turn off when gradient checkpointing is enabled
-
-    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
-        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
-        input_dtype = output_embedding_layer.weight.dtype
-
-        class CastOutputToFloat(torch.nn.Sequential):
-
-            def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return super().forward(x.to(input_dtype)).to(torch.float32)
-
-        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
-
-    return model
-
-
-def print_trainable_params(model: torch.nn.Module) -> None:
-    trainable_params, all_param = 0, 0
-    for param in model.parameters():
-        num_params = param.numel()
-        # if using DS Zero 3 and the weights are initialized empty
-        if num_params == 0 and hasattr(param, "ds_numel"):
-            num_params = param.ds_numel
-        all_param += num_params
-        if param.requires_grad:
-            trainable_params += num_params
-    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
-                trainable_params, all_param, 100 * trainable_params / all_param))
-
-
-def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
-    state_dict = model.state_dict()
-    filtered_state_dict = {}
-
-    for k, v in model.named_parameters():
-        if v.requires_grad:
-            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
-
-    return filtered_state_dict
-
-
-def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
-    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
-    if os.path.exists(weights_file):
-        model_state_dict = torch.load(weights_file, map_location="cpu")
-        model.load_state_dict(model_state_dict, strict=False) # skip missing keys
-    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
-        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
-    else:
-        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
-        return False
-    return True
-
-
-def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
-    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
-    if not os.path.exists(valuehead_file):
-        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
-        return False
-    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
-    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
-    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
-    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
-    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
-    return True
-
-
-def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
-    r"""
-    EMA implementation according to TensorBoard.
-    """
-    last = scalars[0]
-    smoothed = list()
-    for next_val in scalars:
-        smoothed_val = last * weight + (1 - weight) * next_val
-        smoothed.append(smoothed_val)
-        last = smoothed_val
-    return smoothed
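A quick sanity check of `smooth`: the first point is returned unchanged, and each later point moves only `1 - weight` of the way toward the new value:

```python
# Worked example for the EMA above with weight=0.9.
def smooth(scalars, weight=0.9):
    last = scalars[0]
    smoothed = []
    for next_val in scalars:
        smoothed_val = last * weight + (1 - weight) * next_val
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed

print(smooth([1.0, 2.0, 2.0]))  # [1.0, 1.1, 1.19]
```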
- """ - last = scalars[0] - smoothed = list() - for next_val in scalars: - smoothed_val = last * weight + (1 - weight) * next_val - smoothed.append(smoothed_val) - last = smoothed_val - return smoothed - - -def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None: - import matplotlib.pyplot as plt - with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: - data = json.load(f) - - for key in keys: - steps, metrics = [], [] - for i in range(len(data["log_history"])): - if key in data["log_history"][i]: - steps.append(data["log_history"][i]["step"]) - metrics.append(data["log_history"][i][key]) - - if len(metrics) == 0: - logger.warning(f"No metric {key} to plot.") - continue - - plt.figure() - plt.plot(steps, metrics, alpha=0.4, label="original") - plt.plot(steps, smooth(metrics), label="smoothed") - plt.title("training {} of {}".format(key, save_dictionary)) - plt.xlabel("step") - plt.ylabel(key) - plt.legend() - plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100) - print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key))) diff --git a/src/utils/pairwise.py b/src/utils/pairwise.py deleted file mode 100644 index bdffc749..00000000 --- a/src/utils/pairwise.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -import numpy as np -from typing import Dict, Sequence, Tuple, Union - -from transformers import DataCollatorWithPadding - -from .peft_trainer import PeftTrainer - -from .other import get_logger - -logger = get_logger(__name__) - - -def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: - preds, _ = eval_preds - return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])} - - -class PairwiseDataCollatorWithPadding(DataCollatorWithPadding): - r""" - Data collator for pairwise data. - """ - - def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]: - r""" - Pads batched data to the longest sequence in the batch. - - We generate 2 * n examples where the first n examples represent chosen examples and - the last n examples represent rejected examples. - """ - features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features] - return super().__call__(features) - - -class PairwisePeftTrainer(PeftTrainer): - r""" - Inherits PeftTrainer to compute pairwise loss. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.can_return_loss = True # override property to return eval_loss - - def compute_loss(self, model, inputs, return_outputs=False): - r""" - Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. - - We use score on the EOS token to represent reward of the whole sentence. - - Subclass and override to inject custom behavior. It should not be directly used by external scripts. - - Note that the first element will be removed from the output tuple. 
-
-
-class PairwisePeftTrainer(PeftTrainer):
-    r"""
-    Inherits PeftTrainer to compute pairwise loss.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.can_return_loss = True # override property to return eval_loss
-
-    def compute_loss(self, model, inputs, return_outputs=False):
-        r"""
-        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
-
-        We use score on the EOS token to represent reward of the whole sentence.
-
-        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
-
-        Note that the first element will be removed from the output tuple.
-
-        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
-        """
-        batch_size = inputs["input_ids"].size(0) // 2
-        _, _, values = model(**inputs)
-        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
-        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
-        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
diff --git a/src/web_demo.py b/src/web_demo.py
index 9fcd906d..c60e4138 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -2,83 +2,26 @@
 # Implements user interface in browser for fine-tuned models.
 # Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 
-
-import mdtex2html
 import gradio as gr
-
 from threading import Thread
 
-from utils import (
-    Template,
-    load_pretrained,
-    prepare_infer_args,
-    get_logits_processor
-)
-
 from transformers import TextIteratorStreamer
 from transformers.utils.versions import require_version
 
+from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
+
 require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
 
-model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
-model, tokenizer = load_pretrained(model_args, finetuning_args)
+model_args, data_args, finetuning_args, generating_args = get_infer_args()
+model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
 prompt_template = Template(data_args.prompt_template)
 source_prefix = data_args.source_prefix if data_args.source_prefix else ""
 
-
-def postprocess(self, y):
-    r"""
-    Overrides Chatbot.postprocess
-    """
-    if y is None:
-        return []
-    for i, (message, response) in enumerate(y):
-        y[i] = (
-            None if message is None else mdtex2html.convert((message)),
-            None if response is None else mdtex2html.convert(response),
-        )
-    return y
-
-
-gr.Chatbot.postprocess = postprocess
-
-
-def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
-    lines = text.split("\n")
-    lines = [line for line in lines if line != ""]
-    count = 0
-    for i, line in enumerate(lines):
-        if "```" in line:
-            count += 1
-            items = line.split("`")
-            if count % 2 == 1:
-                lines[i] = "<pre><code class=\"language-{}\">".format(items[-1])
-            else:
-                lines[i] = "<br /></code></pre>"
-        else:
-            if i > 0:
-                if count % 2 == 1:
-                    line = line.replace("`", "\`")
-                    line = line.replace("<", "&lt;")
-                    line = line.replace(">", "&gt;")
-                    line = line.replace(" ", "&nbsp;")
-                    line = line.replace("*", "&ast;")
-                    line = line.replace("_", "&lowbar;")
-                    line = line.replace("-", "&#45;")
-                    line = line.replace(".", "&#46;")
-                    line = line.replace("!", "&#33;")
-                    line = line.replace("(", "&#40;")
-                    line = line.replace(")", "&#41;")
-                    line = line.replace("$", "&#36;")
-                lines[i] = "<br />" + line
-    text = "".join(lines)
-    return text
-
-
 def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
-    chatbot.append((parse_text(query), ""))
+    chatbot.append((query, ""))
 
     input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
     input_ids = input_ids.to(model.device)
@@ -102,7 +45,7 @@ def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
     for new_text in streamer:
         response += new_text
         new_history = history + [(query, response)]
-        chatbot[-1] = (parse_text(query), parse_text(response))
+        chatbot[-1] = (query, response)
     yield chatbot, new_history