modity code structure
This commit is contained in:
parent
2a0f1f8398
commit
f751376613
19
README.md
19
README.md
|
@ -95,7 +95,7 @@ huggingface-cli login
|
|||
- Python 3.8+ and PyTorch 1.13.1+
|
||||
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
|
||||
- jieba, rouge-chinese and nltk (used at evaluation)
|
||||
- gradio and mdtex2html (used in web_demo.py)
|
||||
- gradio and matplotlib (used in web_demo.py)
|
||||
- uvicorn, fastapi and sse-starlette (used in api_demo.py)
|
||||
|
||||
And **powerful GPUs**!
|
||||
|
@ -137,7 +137,8 @@ python -m transformers.models.llama.convert_llama_weights_to_hf \
|
|||
### (Continually) Pre-Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage pt \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset wiki_demo \
|
||||
|
@ -158,7 +159,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \
|
|||
### Supervised Fine-Tuning
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset alpaca_gpt4_en \
|
||||
|
@ -179,7 +181,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
|
|||
### Reward Model Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage rm \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset comparison_gpt4_en \
|
||||
|
@ -199,7 +202,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
|
|||
### PPO Training (RLHF)
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage ppo \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset alpaca_gpt4_en \
|
||||
|
@ -222,7 +226,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
|
|||
|
||||
```bash
|
||||
accelerate config # configure the environment
|
||||
accelerate launch src/train_XX.py # arguments (same as above)
|
||||
accelerate launch src/train_bash.py # arguments (same as above)
|
||||
```
|
||||
|
||||
<details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary>
|
||||
|
@ -256,7 +260,8 @@ use_cpu: false
|
|||
### Evaluation (BLEU and ROUGE_CHINESE)
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage pt \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_eval \
|
||||
--dataset alpaca_gpt4_en \
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
|
@ -8,8 +8,9 @@ sentencepiece
|
|||
jieba
|
||||
rouge-chinese
|
||||
nltk
|
||||
gradio
|
||||
mdtex2html
|
||||
gradio>=3.36.0
|
||||
uvicorn
|
||||
pydantic==1.10.7
|
||||
fastapi
|
||||
sse-starlette
|
||||
matplotlib
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
import os
|
||||
import re
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
|
||||
def get_version():
|
||||
with open(os.path.join("src", "llmtuner", "__init__.py"), "r", encoding="utf-8") as f:
|
||||
file_content = f.read()
|
||||
pattern = r"{0}\W*=\W*\"([^\"]+)\"".format("__version__")
|
||||
version, = re.findall(pattern, file_content)
|
||||
return version
|
||||
|
||||
|
||||
def get_requires():
|
||||
with open("requirements.txt", "r", encoding="utf-8") as f:
|
||||
file_content = f.read()
|
||||
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
|
||||
return lines
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
setup(
|
||||
name="llmtuner",
|
||||
version=get_version(),
|
||||
author="hiyouga",
|
||||
author_email="hiyouga" "@" "buaa.edu.cn",
|
||||
description="Easy-to-use fine-tuning framework using PEFT",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
|
||||
license="Apache 2.0 License",
|
||||
url="https://github.com/hiyouga/LLaMA-Efficient-Tuning",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
python_requires=">=3.8.0",
|
||||
install_requires=get_requires(),
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
218
src/api_demo.py
218
src/api_demo.py
|
@ -4,225 +4,11 @@
|
|||
# Visit http://localhost:8000/docs for document.
|
||||
|
||||
|
||||
import time
|
||||
import torch
|
||||
import uvicorn
|
||||
from threading import Thread
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from transformers import TextIteratorStreamer
|
||||
from sse_starlette import EventSourceResponse
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from utils import (
|
||||
Template,
|
||||
load_pretrained,
|
||||
prepare_infer_args,
|
||||
get_logits_processor
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI): # collects GPU memory
|
||||
yield
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
id: str
|
||||
object: Optional[str] = "model"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: Optional[str] = "owner"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
permission: Optional[list] = []
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: Optional[str] = "list"
|
||||
data: Optional[List[ModelCard]] = []
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Literal["user", "assistant", "system"]
|
||||
content: str
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[Literal["user", "assistant", "system"]] = None
|
||||
content: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = 1
|
||||
max_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Literal["stop", "length"]
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
finish_reason: Optional[Literal["stop", "length"]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseUsage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Literal["chat.completion"]
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseChoice]
|
||||
usage: ChatCompletionResponseUsage
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Literal["chat.completion.chunk"]
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
async def list_models():
|
||||
global model_args
|
||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
global model, tokenizer, source_prefix, generating_args
|
||||
|
||||
if request.messages[-1].role != "user":
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
query = request.messages[-1].content
|
||||
|
||||
prev_messages = request.messages[:-1]
|
||||
if len(prev_messages) > 0 and prev_messages[0].role == "system":
|
||||
prefix = prev_messages.pop(0).content
|
||||
else:
|
||||
prefix = source_prefix
|
||||
|
||||
history = []
|
||||
if len(prev_messages) % 2 == 0:
|
||||
for i in range(0, len(prev_messages), 2):
|
||||
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
|
||||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
||||
|
||||
inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
gen_kwargs = generating_args.to_dict()
|
||||
gen_kwargs.update({
|
||||
"input_ids": inputs["input_ids"],
|
||||
"temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
|
||||
"top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
|
||||
"logits_processor": get_logits_processor()
|
||||
})
|
||||
|
||||
if request.max_tokens:
|
||||
gen_kwargs.pop("max_length", None)
|
||||
gen_kwargs["max_new_tokens"] = request.max_tokens
|
||||
|
||||
if request.stream:
|
||||
generate = predict(gen_kwargs, request.model)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
generation_output = model.generate(**gen_kwargs)
|
||||
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
|
||||
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=len(inputs["input_ids"][0]),
|
||||
completion_tokens=len(outputs),
|
||||
total_tokens=len(inputs["input_ids"][0]) + len(outputs)
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content=response),
|
||||
finish_reason="stop"
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
|
||||
|
||||
|
||||
async def predict(gen_kwargs: Dict[str, Any], model_id: str):
|
||||
global model, tokenizer
|
||||
|
||||
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=gen_kwargs)
|
||||
thread.start()
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
for new_text in streamer:
|
||||
if len(new_text) == 0:
|
||||
continue
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(content=new_text),
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(),
|
||||
finish_reason="stop"
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield "[DONE]"
|
||||
from llmtuner import create_app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
|
||||
|
||||
app = create_app()
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
|
|
|
@ -2,21 +2,15 @@
|
|||
# Implements stream chat in command line for fine-tuned models.
|
||||
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
|
||||
|
||||
|
||||
from utils import (
|
||||
Template,
|
||||
load_pretrained,
|
||||
prepare_infer_args,
|
||||
get_logits_processor
|
||||
)
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
model_args, data_args, finetuning_args, generating_args = get_infer_args()
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
|
||||
|
|
|
@ -2,14 +2,12 @@
|
|||
# Exports the fine-tuned model.
|
||||
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
|
||||
|
||||
|
||||
from utils import load_pretrained, prepare_args
|
||||
from llmtuner import get_train_args, load_model_and_tokenizer
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
model_args, _, training_args, finetuning_args = prepare_args(stage="sft")
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
model_args, _, training_args, finetuning_args, _ = get_train_args()
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
print("model and tokenizer have been saved at:", training_args.output_dir)
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
from llmtuner.api import create_app
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.template import Template
|
||||
from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo
|
||||
|
||||
|
||||
__version__ = "0.0.9"
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.api.app import create_app
|
|
@ -0,0 +1,152 @@
|
|||
import uvicorn
|
||||
from threading import Thread
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from transformers import TextIteratorStreamer
|
||||
from contextlib import asynccontextmanager
|
||||
from sse_starlette import EventSourceResponse
|
||||
from typing import Any, Dict
|
||||
|
||||
from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.extras.misc import get_logits_processor, torch_gc
|
||||
from llmtuner.extras.template import Template
|
||||
from llmtuner.api.protocol import (
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ChatMessage,
|
||||
DeltaMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionResponseUsage
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI): # collects GPU memory
|
||||
yield
|
||||
torch_gc()
|
||||
|
||||
|
||||
def create_app():
|
||||
model_args, data_args, finetuning_args, generating_args = get_infer_args()
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
async def list_models():
|
||||
global model_args
|
||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if request.messages[-1].role != "user":
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
query = request.messages[-1].content
|
||||
|
||||
prev_messages = request.messages[:-1]
|
||||
if len(prev_messages) > 0 and prev_messages[0].role == "system":
|
||||
prefix = prev_messages.pop(0).content
|
||||
else:
|
||||
prefix = source_prefix
|
||||
|
||||
history = []
|
||||
if len(prev_messages) % 2 == 0:
|
||||
for i in range(0, len(prev_messages), 2):
|
||||
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
|
||||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
||||
|
||||
inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
gen_kwargs = generating_args.to_dict()
|
||||
gen_kwargs.update({
|
||||
"input_ids": inputs["input_ids"],
|
||||
"temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
|
||||
"top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
|
||||
"logits_processor": get_logits_processor()
|
||||
})
|
||||
|
||||
if request.max_tokens:
|
||||
gen_kwargs.pop("max_length", None)
|
||||
gen_kwargs["max_new_tokens"] = request.max_tokens
|
||||
|
||||
if request.stream:
|
||||
generate = predict(gen_kwargs, request.model)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
generation_output = model.generate(**gen_kwargs)
|
||||
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
|
||||
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=len(inputs["input_ids"][0]),
|
||||
completion_tokens=len(outputs),
|
||||
total_tokens=len(inputs["input_ids"][0]) + len(outputs)
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content=response),
|
||||
finish_reason="stop"
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
|
||||
|
||||
async def predict(gen_kwargs: Dict[str, Any], model_id: str):
|
||||
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=gen_kwargs)
|
||||
thread.start()
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
for new_text in streamer:
|
||||
if len(new_text) == 0:
|
||||
continue
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(content=new_text),
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(),
|
||||
finish_reason="stop"
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield "[DONE]"
|
||||
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = create_app()
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
|
@ -0,0 +1,73 @@
|
|||
import time
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
id: str
|
||||
object: Optional[str] = "model"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: Optional[str] = "owner"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
permission: Optional[list] = []
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: Optional[str] = "list"
|
||||
data: Optional[List[ModelCard]] = []
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Literal["user", "assistant", "system"]
|
||||
content: str
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[Literal["user", "assistant", "system"]] = None
|
||||
content: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = 1
|
||||
max_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Literal["stop", "length"]
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
finish_reason: Optional[Literal["stop", "length"]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseUsage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Literal["chat.completion"]
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseChoice]
|
||||
usage: ChatCompletionResponseUsage
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Literal["chat.completion.chunk"]
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
|
@ -0,0 +1,2 @@
|
|||
from llmtuner.dsets.loader import get_dataset
|
||||
from llmtuner.dsets.preprocess import preprocess_dataset
|
|
@ -0,0 +1,63 @@
|
|||
import os
|
||||
import json
|
||||
import time
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments
|
||||
)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
def __init__(self, runner=None):
|
||||
self.runner = runner
|
||||
self.start_time = time.time()
|
||||
self.tracker = {}
|
||||
|
||||
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of a training step. If using gradient accumulation, one training step
|
||||
might take several inputs.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if "loss" not in state.log_history[-1]:
|
||||
return
|
||||
cur_time = time.time()
|
||||
cur_steps = state.log_history[-1].get("step")
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_steps = state.max_steps - cur_steps
|
||||
remaining_time = remaining_steps * avg_time_per_step
|
||||
self.tracker = {
|
||||
"current_steps": cur_steps,
|
||||
"total_steps": state.max_steps,
|
||||
"loss": state.log_history[-1].get("loss", None),
|
||||
"reward": state.log_history[-1].get("reward", None),
|
||||
"learning_rate": state.log_history[-1].get("learning_rate", None),
|
||||
"epoch": state.log_history[-1].get("epoch", None),
|
||||
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
|
||||
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
|
||||
"remaining_time": str(timedelta(seconds=int(remaining_time)))
|
||||
}
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.tracker) + "\n")
|
|
@ -0,0 +1,106 @@
|
|||
import os
|
||||
import hashlib
|
||||
from typing import List
|
||||
|
||||
from datasets import Dataset, concatenate_datasets, load_dataset
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.hparams import ModelArguments, DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_dataset(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments
|
||||
) -> Dataset:
|
||||
|
||||
def checksum(file_path, hash):
|
||||
with open(file_path, "rb") as datafile:
|
||||
binary_data = datafile.read()
|
||||
sha1 = hashlib.sha1(binary_data).hexdigest()
|
||||
if sha1 != hash:
|
||||
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
|
||||
|
||||
ext2type = {
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json",
|
||||
"txt": "text"
|
||||
}
|
||||
|
||||
max_samples = data_args.max_samples
|
||||
all_datasets: List[Dataset] = [] # support multiple datasets
|
||||
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
|
||||
if dataset_attr.load_from == "hf_hub":
|
||||
data_path = dataset_attr.dataset_name
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "script":
|
||||
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_path = None
|
||||
data_files: List[str] = []
|
||||
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
|
||||
|
||||
if data_path is None:
|
||||
data_path = ext2type.get(data_files[0].split(".")[-1], None)
|
||||
else:
|
||||
assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
|
||||
data_path = ext2type.get(data_files[0].split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
|
||||
assert data_path, "File extension must be txt, csv, json or jsonl."
|
||||
|
||||
if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
|
||||
checksum(data_files[0], dataset_attr.dataset_sha1)
|
||||
else:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
raw_datasets = load_dataset(
|
||||
data_path,
|
||||
data_files=data_files,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None
|
||||
)
|
||||
dataset = raw_datasets[data_args.split]
|
||||
|
||||
if max_samples is not None:
|
||||
max_samples_temp = min(len(dataset), max_samples)
|
||||
dataset = dataset.select(range(max_samples_temp))
|
||||
|
||||
dummy_data = [None] * len(dataset)
|
||||
prefix_data = [dataset_attr.source_prefix] * len(dataset)
|
||||
for column_name, target_name in [
|
||||
("prompt_column", "prompt"),
|
||||
("query_column", "query"),
|
||||
("response_column", "response"),
|
||||
("history_column", "history")
|
||||
]: # every dataset will have 4 columns same as each other
|
||||
if getattr(dataset_attr, column_name) != target_name:
|
||||
if getattr(dataset_attr, column_name):
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
|
||||
else: # None or empty string
|
||||
dataset = dataset.add_column(target_name, dummy_data)
|
||||
dataset = dataset.add_column("prefix", prefix_data)
|
||||
all_datasets.append(dataset)
|
||||
|
||||
if len(data_args.dataset_list) == 1:
|
||||
all_datasets = all_datasets[0]
|
||||
else:
|
||||
all_datasets = concatenate_datasets(all_datasets)
|
||||
|
||||
return all_datasets
|
|
@ -0,0 +1,172 @@
|
|||
from typing import Literal
|
||||
from itertools import chain
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.template import Template
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
def preprocess_dataset(
|
||||
dataset: Dataset,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> Dataset:
|
||||
|
||||
column_names = list(dataset.column_names)
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
|
||||
# support question with a single answer or multiple answers
|
||||
def get_dialog(examples):
|
||||
for i in range(len(examples["prompt"])):
|
||||
if examples["prompt"][i] and examples["response"][i]:
|
||||
query, answer = examples["prompt"][i], examples["response"][i]
|
||||
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
|
||||
prefix = examples["prefix"][i] if examples["prefix"][i] else ""
|
||||
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
|
||||
yield dialog
|
||||
|
||||
def preprocess_pretrain_dataset(examples):
|
||||
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
|
||||
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
|
||||
concatenated_ids = list(chain(*text_ids))
|
||||
total_length = len(concatenated_ids)
|
||||
block_size = data_args.max_source_length - 1
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of max_source_length
|
||||
result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
|
||||
for i in range(0, total_length, block_size)]
|
||||
return {
|
||||
"input_ids": result,
|
||||
"labels": result.copy()
|
||||
}
|
||||
|
||||
def preprocess_supervised_dataset(examples):
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for input with history, we build multiple input-label pairs just like:
|
||||
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
max_length = data_args.max_source_length + data_args.max_target_length
|
||||
|
||||
for dialog in get_dialog(examples):
|
||||
input_ids, labels = [], []
|
||||
|
||||
for i in range(len(dialog) // 2):
|
||||
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
|
||||
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(target_ids) > data_args.max_target_length - 1: # eos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
|
||||
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
|
||||
break
|
||||
|
||||
input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
|
||||
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples):
|
||||
# build inputs with format `<bos> X` and labels with format `<bos> Y`
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
||||
target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(target_ids) > data_args.max_target_length:
|
||||
target_ids = target_ids[:data_args.max_target_length]
|
||||
|
||||
model_inputs["input_ids"].append(source_ids)
|
||||
model_inputs["labels"].append(target_ids)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_pairwise_dataset(examples):
|
||||
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
|
||||
model_inputs = {"accept_ids": [], "reject_ids": []}
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
||||
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
|
||||
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(accept_ids) > data_args.max_target_length - 1: # eos token
|
||||
accept_ids = accept_ids[:data_args.max_target_length - 1]
|
||||
if len(reject_ids) > data_args.max_target_length - 1: # eos token
|
||||
reject_ids = reject_ids[:data_args.max_target_length - 1]
|
||||
|
||||
accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
|
||||
reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["accept_ids"].append(accept_ids)
|
||||
model_inputs["reject_ids"].append(reject_ids)
|
||||
return model_inputs
|
||||
|
||||
def print_supervised_dataset_example(example):
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(
|
||||
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
|
||||
skip_special_tokens=False)
|
||||
))
|
||||
|
||||
def print_pairwise_dataset_example(example):
|
||||
print("accept_ids:\n{}".format(example["accept_ids"]))
|
||||
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
|
||||
print("reject_ids:\n{}".format(example["reject_ids"]))
|
||||
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
|
||||
|
||||
def print_unsupervised_dataset_example(example):
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
if stage == "pt":
|
||||
preprocess_function = preprocess_pretrain_dataset
|
||||
elif stage == "sft":
|
||||
preprocess_function = preprocess_unsupervised_dataset \
|
||||
if training_args.predict_with_generate else preprocess_supervised_dataset
|
||||
elif stage == "rm":
|
||||
preprocess_function = preprocess_pairwise_dataset
|
||||
elif stage == "ppo":
|
||||
preprocess_function = preprocess_unsupervised_dataset
|
||||
|
||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||
dataset = dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset"
|
||||
)
|
||||
|
||||
if stage == "pt":
|
||||
print_unsupervised_dataset_example(dataset[0])
|
||||
elif stage == "sft":
|
||||
print_supervised_dataset_example(dataset[0])
|
||||
elif stage == "rm":
|
||||
print_pairwise_dataset_example(dataset[0])
|
||||
elif stage == "ppo":
|
||||
print_unsupervised_dataset_example(dataset[0])
|
||||
|
||||
return dataset
|
|
@ -0,0 +1,72 @@
|
|||
import os
|
||||
import json
|
||||
import time
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments
|
||||
)
|
||||
from transformers.trainer_callback import TrainerControl, TrainerState
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
def __init__(self, runner=None):
|
||||
self.runner = runner
|
||||
self.tracker = {}
|
||||
|
||||
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
"""
|
||||
self.start_time = time.time()
|
||||
|
||||
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of a training step. If using gradient accumulation, one training step
|
||||
might take several inputs.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if "current_steps" not in state.log_history[-1]:
|
||||
return
|
||||
cur_time = time.time()
|
||||
cur_steps = state.log_history[-1].get("step")
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_steps = state.max_steps - cur_steps
|
||||
remaining_time = remaining_steps * avg_time_per_step
|
||||
self.tracker = {
|
||||
"current_steps": cur_steps,
|
||||
"total_steps": state.max_steps,
|
||||
"loss": state.log_history[-1].get("loss", None),
|
||||
"eval_loss": state.log_history[-1].get("eval_loss", None),
|
||||
"predict_loss": state.log_history[-1].get("predict_loss", None),
|
||||
"reward": state.log_history[-1].get("reward", None),
|
||||
"learning_rate": state.log_history[-1].get("learning_rate", None),
|
||||
"epoch": state.log_history[-1].get("epoch", None),
|
||||
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
|
||||
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
|
||||
"remaining_time": str(timedelta(seconds=int(remaining_time)))
|
||||
}
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.tracker) + "\n")
|
|
@ -0,0 +1,7 @@
|
|||
IGNORE_INDEX = -100
|
||||
|
||||
VALUE_HEAD_FILE_NAME = "value_head.bin"
|
||||
|
||||
FINETUNING_ARGS_NAME = "finetuning_args.json"
|
||||
|
||||
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
|
|
@ -0,0 +1,18 @@
|
|||
import sys
|
||||
import logging
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S"
|
||||
)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
|
@ -0,0 +1,105 @@
|
|||
import torch
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.generation.utils import LogitsProcessorList
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
r"""
|
||||
Computes and stores the average and current value.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
# Avoid runtime error in model.generate(do_sample=True).
|
||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
||||
scores.zero_()
|
||||
scores[..., 0] = 1.0
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_processor() -> LogitsProcessorList:
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InvalidScoreLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def print_trainable_params(model: torch.nn.Module) -> None:
|
||||
trainable_params, all_param = 0, 0
|
||||
for param in model.parameters():
|
||||
num_params = param.numel()
|
||||
# if using DS Zero 3 and the weights are initialized empty
|
||||
if num_params == 0 and hasattr(param, "ds_numel"):
|
||||
num_params = param.ds_numel
|
||||
all_param += num_params
|
||||
if param.requires_grad:
|
||||
trainable_params += num_params
|
||||
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param))
|
||||
|
||||
|
||||
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
|
||||
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
|
||||
def prepare_model_for_training(
|
||||
model: PreTrainedModel,
|
||||
finetuning_type: str,
|
||||
output_embedding_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||
) -> PreTrainedModel:
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if use_gradient_checkpointing:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
|
||||
if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
|
||||
output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
|
||||
input_dtype = output_embedding_layer.weight.dtype
|
||||
|
||||
class CastOutputToFloat(torch.nn.Sequential):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(x.to(input_dtype)).to(torch.float32)
|
||||
|
||||
setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
|
||||
|
||||
return model
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
"""
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
|
@ -0,0 +1,50 @@
|
|||
import os
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import List, Optional
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
|
||||
r"""
|
||||
EMA implementation according to TensorBoard.
|
||||
"""
|
||||
last = scalars[0]
|
||||
smoothed = list()
|
||||
for next_val in scalars:
|
||||
smoothed_val = last * weight + (1 - weight) * next_val
|
||||
smoothed.append(smoothed_val)
|
||||
last = smoothed_val
|
||||
return smoothed
|
||||
|
||||
|
||||
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
|
||||
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for key in keys:
|
||||
steps, metrics = [], []
|
||||
for i in range(len(data["log_history"])):
|
||||
if key in data["log_history"][i]:
|
||||
steps.append(data["log_history"][i]["step"])
|
||||
metrics.append(data["log_history"][i][key])
|
||||
|
||||
if len(metrics) == 0:
|
||||
logger.warning(f"No metric {key} to plot.")
|
||||
continue
|
||||
|
||||
plt.figure()
|
||||
plt.plot(steps, metrics, alpha=0.4, label="original")
|
||||
plt.plot(steps, smooth(metrics), label="smoothed")
|
||||
plt.title("training {} of {}".format(key, save_dictionary))
|
||||
plt.xlabel("step")
|
||||
plt.ylabel(key)
|
||||
plt.legend()
|
||||
plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
|
||||
print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
|
|
@ -0,0 +1,49 @@
|
|||
import os
|
||||
import torch
|
||||
from typing import Dict
|
||||
|
||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from transformers.modeling_utils import load_sharded_checkpoint
|
||||
|
||||
from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
|
||||
state_dict = model.state_dict()
|
||||
filtered_state_dict = {}
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
|
||||
|
||||
return filtered_state_dict
|
||||
|
||||
|
||||
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
|
||||
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
|
||||
if os.path.exists(weights_file):
|
||||
model_state_dict = torch.load(weights_file, map_location="cpu")
|
||||
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
|
||||
elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
|
||||
load_sharded_checkpoint(model, checkpoint_dir, strict=False)
|
||||
else:
|
||||
logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
|
||||
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
|
||||
if not os.path.exists(valuehead_file):
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
|
||||
return False
|
||||
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
|
||||
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
|
||||
return True
|
|
@ -0,0 +1,5 @@
|
|||
from .data_args import DataArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
from .general_args import GeneralArguments
|
||||
from .generating_args import GeneratingArguments
|
||||
from .model_args import ModelArguments
|
|
@ -0,0 +1,119 @@
|
|||
import os
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetAttr:
|
||||
|
||||
load_from: str
|
||||
dataset_name: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
source_prefix: Optional[str] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
||||
def __post_init__(self):
|
||||
self.prompt_column = "instruction"
|
||||
self.query_column = "input"
|
||||
self.response_column = "output"
|
||||
self.history_column = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||
"""
|
||||
dataset: Optional[str] = field(
|
||||
default="alpaca_zh",
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
|
||||
)
|
||||
dataset_dir: Optional[str] = field(
|
||||
default="data",
|
||||
metadata={"help": "The name of the folder containing datasets."}
|
||||
)
|
||||
split: Optional[str] = field(
|
||||
default="train",
|
||||
metadata={"help": "Which dataset split to use for training and evaluation."}
|
||||
)
|
||||
overwrite_cache: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Overwrite the cached training and evaluation sets."}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."}
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum total input sequence length after tokenization."}
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum total output sequence length after tokenization."}
|
||||
)
|
||||
max_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
||||
)
|
||||
eval_num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
|
||||
)
|
||||
ignore_pad_token_for_loss: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
|
||||
)
|
||||
dev_ratio: Optional[float] = field(
|
||||
default=0,
|
||||
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
|
||||
)
|
||||
prompt_template: Optional[str] = field(
|
||||
default="default",
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
)
|
||||
|
||||
def init_for_training(self): # support mixing multiple datasets
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
if self.source_prefix is not None:
|
||||
prefix_list = self.source_prefix.split("|")
|
||||
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
|
||||
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
|
||||
else:
|
||||
prefix_list = [None] * len(dataset_names)
|
||||
|
||||
self.dataset_list: List[DatasetAttr] = []
|
||||
for i, name in enumerate(dataset_names):
|
||||
if name not in dataset_info:
|
||||
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
|
||||
|
||||
if "hf_hub_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||
elif "script_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"file",
|
||||
dataset_name=dataset_info[name]["file_name"],
|
||||
dataset_sha1=dataset_info[name].get("file_sha1", None)
|
||||
)
|
||||
|
||||
dataset_attr.source_prefix = prefix_list[i]
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
|
||||
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
|
||||
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
|
||||
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
|
||||
|
||||
self.dataset_list.append(dataset_attr)
|
|
@ -0,0 +1,78 @@
|
|||
import json
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments:
|
||||
"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
|
||||
default="lora",
|
||||
metadata={"help": "Which fine-tuning method to use."}
|
||||
)
|
||||
num_hidden_layers: Optional[int] = field(
|
||||
default=32,
|
||||
metadata={"help": "Number of decoder blocks in the model. \
|
||||
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
|
||||
BLOOM choices: [\"24\", \"30\", \"70\"], \
|
||||
Falcon choices: [\"32\", \"60\"], \
|
||||
Baichuan choices: [\"32\"]"}
|
||||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
|
||||
)
|
||||
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
|
||||
default="mlp",
|
||||
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
|
||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
|
||||
Baichuan choices: [\"mlp\", \"self_attn\"]"}
|
||||
)
|
||||
lora_rank: Optional[int] = field(
|
||||
default=8,
|
||||
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
|
||||
)
|
||||
lora_alpha: Optional[float] = field(
|
||||
default=32.0,
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
|
||||
)
|
||||
lora_dropout: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
|
||||
)
|
||||
lora_target: Optional[str] = field(
|
||||
default="q_proj,v_proj",
|
||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
|
||||
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
|
||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
|
||||
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
|
||||
|
||||
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||
trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
|
||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
|
||||
|
||||
self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
|
||||
|
||||
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
f.write(json_string)
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_path: str):
|
||||
"""Creates an instance from the content of `json_path`."""
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
return cls(**json.loads(text))
|
|
@ -0,0 +1,13 @@
|
|||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralArguments:
|
||||
"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."}
|
||||
)
|
|
@ -0,0 +1,51 @@
|
|||
from typing import Any, Dict, Optional
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratingArguments:
|
||||
"""
|
||||
Arguments pertaining to specify the decoding parameters.
|
||||
"""
|
||||
do_sample: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
|
||||
)
|
||||
temperature: Optional[float] = field(
|
||||
default=0.95,
|
||||
metadata={"help": "The value used to modulate the next token probabilities."}
|
||||
)
|
||||
top_p: Optional[float] = field(
|
||||
default=0.7,
|
||||
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
|
||||
)
|
||||
top_k: Optional[int] = field(
|
||||
default=50,
|
||||
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
||||
)
|
||||
max_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
|
||||
)
|
||||
repetition_penalty: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
|
||||
)
|
||||
length_penalty: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
args = asdict(self)
|
||||
if args.get("max_new_tokens", None):
|
||||
args.pop("max_length", None)
|
||||
return args
|
|
@ -0,0 +1,72 @@
|
|||
import torch
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
||||
"""
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
|
||||
)
|
||||
use_fast_tokenizer: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
|
||||
)
|
||||
use_auth_token: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
|
||||
)
|
||||
model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
|
||||
)
|
||||
padding_side: Optional[Literal["left", "right"]] = field(
|
||||
default="left",
|
||||
metadata={"help": "The side on which the model should have padding applied."}
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model."}
|
||||
)
|
||||
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
|
||||
default="nf4",
|
||||
metadata={"help": "Quantization data type to use in int4 training."}
|
||||
)
|
||||
double_quantization: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use double quantization in int4 training or not."}
|
||||
)
|
||||
compute_dtype: Optional[torch.dtype] = field(
|
||||
default=None,
|
||||
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
|
||||
)
|
||||
checkpoint_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||
)
|
||||
resume_lora_training: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
||||
)
|
||||
plot_loss: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
|
||||
if self.quantization_bit is not None:
|
||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
|
@ -0,0 +1,5 @@
|
|||
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.tuner.pt import run_pt
|
||||
from llmtuner.tuner.sft import run_sft
|
||||
from llmtuner.tuner.rm import run_rm
|
||||
from llmtuner.tuner.ppo import run_ppo
|
|
@ -0,0 +1,2 @@
|
|||
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
|
||||
from llmtuner.tuner.core.loader import load_model_and_tokenizer
|
|
@ -0,0 +1,94 @@
|
|||
import os
|
||||
import torch
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from peft import (
|
||||
PeftModel,
|
||||
TaskType,
|
||||
LoraConfig,
|
||||
get_peft_model
|
||||
)
|
||||
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import load_trainable_params
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def init_adapter(
|
||||
model: PreTrainedModel,
|
||||
model_args: ModelArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
is_trainable: bool,
|
||||
is_mergeable: bool
|
||||
) -> PreTrainedModel:
|
||||
r"""
|
||||
Initializes the adapters.
|
||||
|
||||
Support full-parameter, freeze and LoRA training.
|
||||
|
||||
Note that the trainable parameters must be cast to float32.
|
||||
"""
|
||||
|
||||
if finetuning_args.finetuning_type == "none" and is_trainable:
|
||||
raise ValueError("You cannot use finetuning_type=none while training.")
|
||||
|
||||
if finetuning_args.finetuning_type == "full":
|
||||
logger.info("Fine-tuning method: Full")
|
||||
model = model.float()
|
||||
|
||||
if finetuning_args.finetuning_type == "freeze":
|
||||
logger.info("Fine-tuning method: Freeze")
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
|
||||
param.requires_grad_(False)
|
||||
else:
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
latest_checkpoint = None
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
|
||||
"Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
|
||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
|
||||
"The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
|
||||
|
||||
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
|
||||
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
else:
|
||||
checkpoints_to_merge = model_args.checkpoint_dir
|
||||
|
||||
for checkpoint in checkpoints_to_merge:
|
||||
model = PeftModel.from_pretrained(model, checkpoint)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if len(checkpoints_to_merge) > 0:
|
||||
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
|
||||
|
||||
if latest_checkpoint is not None: # resume lora training or quantized inference
|
||||
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
|
||||
|
||||
if is_trainable and latest_checkpoint is None: # create new lora weights while training
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
r=finetuning_args.lora_rank,
|
||||
lora_alpha=finetuning_args.lora_alpha,
|
||||
lora_dropout=finetuning_args.lora_dropout,
|
||||
target_modules=finetuning_args.lora_target
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
|
||||
|
||||
return model
|
|
@ -0,0 +1,151 @@
|
|||
import os
|
||||
import torch
|
||||
from typing import Literal, Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig
|
||||
)
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
|
||||
from llmtuner.extras.save_and_load import load_valuehead_params
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core.adapter import init_adapter
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
check_min_version("4.29.1")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
|
||||
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
|
||||
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
model_args: ModelArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
is_trainable: Optional[bool] = False,
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
|
||||
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
r"""
|
||||
Loads pretrained model and tokenizer.
|
||||
|
||||
Support both training and inference.
|
||||
"""
|
||||
if (not is_trainable) and model_args.checkpoint_dir is None:
|
||||
logger.warning("Checkpoint is not found at evaluation, load the original model.")
|
||||
finetuning_args = FinetuningArguments(finetuning_type="none")
|
||||
|
||||
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
|
||||
"RM and PPO training can only be performed with the LoRA method."
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
padding_side=model_args.padding_side,
|
||||
**config_kwargs
|
||||
)
|
||||
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
|
||||
tokenizer.pad_token_id = 0 # set as the <unk> token
|
||||
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
is_mergeable = True
|
||||
|
||||
# Quantization configurations (using bitsandbytes library).
|
||||
if model_args.quantization_bit is not None:
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
config_kwargs["load_in_8bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_threshold=6.0
|
||||
)
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
|
||||
require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
|
||||
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
|
||||
config_kwargs["load_in_4bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
is_mergeable = False
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
if not is_trainable: # `device_map=auto` should be used for inference only
|
||||
config_kwargs["device_map"] = "auto"
|
||||
|
||||
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
else:
|
||||
model_to_load = model_args.model_name_or_path
|
||||
|
||||
# Load and prepare pretrained models (without valuehead).
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_to_load,
|
||||
config=config,
|
||||
torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
**config_kwargs
|
||||
)
|
||||
|
||||
# Register auto class to save the custom code files.
|
||||
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
|
||||
config.__class__.register_for_auto_class()
|
||||
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
|
||||
tokenizer.__class__.register_for_auto_class()
|
||||
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
|
||||
model.__class__.register_for_auto_class()
|
||||
|
||||
# Initialize adapters
|
||||
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
||||
|
||||
if stage == "rm" or stage == "ppo": # add value head
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
|
||||
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
"summary.bias": getattr(model, "reward_head_bias")
|
||||
})
|
||||
|
||||
if stage == "ppo": # load reward model
|
||||
assert is_trainable, "PPO stage cannot be performed at evaluation."
|
||||
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
|
||||
logger.info("Load reward model from {}".format(model_args.reward_model))
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
|
||||
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
|
||||
|
||||
print_trainable_params(model)
|
||||
|
||||
return model, tokenizer
|
|
@ -0,0 +1,134 @@
|
|||
import os
|
||||
import sys
|
||||
import torch
|
||||
import datasets
|
||||
import transformers
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.hparams import (
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
|
||||
|
||||
if args is not None:
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
|
||||
data_args.init_for_training()
|
||||
|
||||
assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
|
||||
"`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
|
||||
|
||||
assert not (training_args.do_train and training_args.predict_with_generate), \
|
||||
"`predict_with_generate` cannot be set as True while training."
|
||||
|
||||
assert (not training_args.do_predict) or training_args.predict_with_generate, \
|
||||
"Please enable `predict_with_generate` to save model predictions."
|
||||
|
||||
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
|
||||
"Quantization is only compatible with the LoRA method."
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
else:
|
||||
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
|
||||
"Quantized model only accepts a single checkpoint."
|
||||
|
||||
if model_args.quantization_bit is not None and (not training_args.do_train):
|
||||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
if training_args.do_train and (not training_args.fp16):
|
||||
logger.warning("We recommend enable fp16 mixed precision training.")
|
||||
|
||||
if data_args.prompt_template == "default":
|
||||
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
|
||||
|
||||
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
|
||||
training_args.ddp_find_unused_parameters = False
|
||||
|
||||
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
if training_args.fp16:
|
||||
model_args.compute_dtype = torch.float16
|
||||
elif training_args.bf16:
|
||||
model_args.compute_dtype = torch.bfloat16
|
||||
else:
|
||||
model_args.compute_dtype = torch.float32
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
|
||||
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set seed before initializing model.
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, general_args
|
||||
|
||||
|
||||
def get_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
|
||||
|
||||
if args is not None:
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
|
||||
"Quantization is only compatible with the LoRA method."
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
else:
|
||||
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
|
||||
"Quantized model only accepts a single checkpoint."
|
||||
|
||||
if data_args.prompt_template == "default":
|
||||
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
|
@ -1,76 +1,20 @@
|
|||
import os
|
||||
import json
|
||||
import time
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import (
|
||||
Seq2SeqTrainer,
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments
|
||||
)
|
||||
|
||||
from transformers import Seq2SeqTrainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
|
||||
from .config import FinetuningArguments
|
||||
|
||||
from .other import (
|
||||
get_logger,
|
||||
get_state_dict,
|
||||
load_trainable_params,
|
||||
load_valuehead_params,
|
||||
FINETUNING_ARGS_NAME,
|
||||
VALUE_HEAD_FILE_NAME
|
||||
)
|
||||
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
r"""
|
||||
TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class.
|
||||
The on_log function primarily collects process parameters during training, such as training loss, learning rate,
|
||||
and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
|
||||
time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization
|
||||
purposes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.start_time = time.time()
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if "loss" not in state.log_history[-1]:
|
||||
return
|
||||
cur_time = time.time()
|
||||
cur_steps = state.log_history[-1].get("step")
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_steps = state.max_steps - cur_steps
|
||||
remaining_time = remaining_steps * avg_time_per_step
|
||||
log_dict = {
|
||||
"current_steps": cur_steps,
|
||||
"total_steps": state.max_steps,
|
||||
"loss": state.log_history[-1].get("loss", None),
|
||||
"reward": state.log_history[-1].get("reward", None),
|
||||
"learning_rate": state.log_history[-1].get("learning_rate", None),
|
||||
"epoch": state.log_history[-1].get("epoch", None),
|
||||
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
|
||||
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
|
||||
"remaining_time": str(timedelta(seconds=int(remaining_time)))
|
||||
}
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
|
||||
f.write(json.dumps(log_dict) + "\n")
|
||||
|
||||
|
||||
class PeftTrainer(Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.tuner.ppo.workflow import run_ppo
|
|
@ -2,77 +2,43 @@ import os
|
|||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import Callable, Dict, List, Literal, Optional, Tuple
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerState
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
|
||||
from trl import PPOTrainer
|
||||
from trl.core import LengthSampler
|
||||
|
||||
from .peft_trainer import PeftTrainer, LogCallback
|
||||
|
||||
from .config import FinetuningArguments
|
||||
|
||||
from .other import (
|
||||
AverageMeter,
|
||||
get_logger,
|
||||
get_logits_processor
|
||||
)
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, get_logits_processor
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
|
||||
if target == "reward": # save default head temporarily
|
||||
valuehead_state_dict = model.v_head.state_dict()
|
||||
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
|
||||
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
|
||||
|
||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "{}_head_weight".format(target)),
|
||||
"summary.bias": getattr(model, "{}_head_bias".format(target))
|
||||
})
|
||||
|
||||
|
||||
def cast_layernorm_dtype(
|
||||
model: AutoModelForCausalLMWithValueHead,
|
||||
layer_norm_names: List[str] = ["norm", "ln_f", "ln_attn", "ln_mlp"], # for LLaMA, BLOOM and Falcon settings
|
||||
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
|
||||
|
||||
layer_norm_state_dict = {}
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
if layer_norm_params is not None:
|
||||
param.data = layer_norm_params[name] # restore float32 weights
|
||||
else:
|
||||
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
|
||||
param.data = param.data.to(torch.float16)
|
||||
|
||||
return model, layer_norm_state_dict
|
||||
|
||||
|
||||
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
r"""
|
||||
Inherits PPOTrainer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: List[LogCallback],
|
||||
**kwargs
|
||||
self,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: List[LogCallback],
|
||||
**kwargs
|
||||
):
|
||||
PPOTrainer.__init__(self, **kwargs)
|
||||
self.args = training_args
|
||||
self.finetuning_args = finetuning_args
|
||||
self.log_callback = callbacks[0]
|
||||
self.state = TrainerState()
|
||||
self.control = TrainerControl()
|
||||
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
|
||||
|
||||
def ppo_train(self, max_target_length: int) -> None:
|
||||
|
@ -117,8 +83,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
steps_trained = 0
|
||||
loss_meter = AverageMeter()
|
||||
reward_meter = AverageMeter()
|
||||
self.log_callback.on_train_begin(self.args, self.state, self.control)
|
||||
|
||||
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()):
|
||||
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
|
||||
|
||||
for _ in range(self.config.gradient_accumulation_steps):
|
||||
|
||||
|
@ -158,6 +125,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
|
||||
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
||||
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
break
|
||||
|
||||
if steps_trained == len_dataloader:
|
||||
dataiter = iter(self.dataloader)
|
||||
steps_trained = 0
|
||||
|
@ -172,20 +142,23 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
print(logs)
|
||||
logs["step"] = step
|
||||
self.state.log_history.append(logs)
|
||||
self.log_callback.on_log(self.args, self.state, None)
|
||||
self.log_callback.on_log(self.args, self.state, self.control)
|
||||
loss_meter.reset()
|
||||
reward_meter.reset()
|
||||
|
||||
if (step+1) % self.args.save_steps == 0: # save checkpoint
|
||||
self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
|
||||
|
||||
if self.control.should_training_stop:
|
||||
break
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
inputs: Dict[str, torch.Tensor],
|
||||
length_sampler: Optional[Callable] = None,
|
||||
return_prompt: Optional[bool] = True,
|
||||
**generation_kwargs,
|
||||
self,
|
||||
inputs: Dict[str, torch.Tensor],
|
||||
length_sampler: Optional[Callable] = None,
|
||||
return_prompt: Optional[bool] = True,
|
||||
**generation_kwargs
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Generates model's responses given queries.
|
|
@ -0,0 +1,37 @@
|
|||
import torch
|
||||
from typing import Dict, List, Literal, Optional, Tuple
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
|
||||
|
||||
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
|
||||
if target == "reward": # save default head temporarily
|
||||
valuehead_state_dict = model.v_head.state_dict()
|
||||
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
|
||||
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
|
||||
|
||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "{}_head_weight".format(target)),
|
||||
"summary.bias": getattr(model, "{}_head_bias".format(target))
|
||||
})
|
||||
|
||||
|
||||
def cast_layernorm_dtype(
|
||||
model: AutoModelForCausalLMWithValueHead,
|
||||
layer_norm_names: List[str] = LAYERNORM_NAMES,
|
||||
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
|
||||
|
||||
layer_norm_state_dict = {}
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
if layer_norm_params is not None:
|
||||
param.data = layer_norm_params[name] # restore float32 weights
|
||||
else:
|
||||
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
|
||||
param.data = param.data.to(torch.float16)
|
||||
|
||||
return model, layer_norm_state_dict
|
|
@ -1,36 +1,30 @@
|
|||
# coding=utf-8
|
||||
# Implements parameter-efficient PPO training of fine-tuned models.
|
||||
# This code is inspired by:
|
||||
# Inspired by:
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
|
||||
|
||||
import math
|
||||
|
||||
from torch.optim import AdamW
|
||||
from transformers.optimization import get_scheduler
|
||||
from trl import PPOConfig
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from utils import (
|
||||
PPOPeftTrainer,
|
||||
LogCallback,
|
||||
load_pretrained,
|
||||
prepare_args,
|
||||
prepare_data,
|
||||
preprocess_data,
|
||||
plot_loss
|
||||
)
|
||||
from torch.optim import AdamW
|
||||
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
|
||||
from transformers.optimization import get_scheduler
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Prepare pretrained model and dataset
|
||||
model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo")
|
||||
dataset = prepare_data(model_args, data_args)
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
|
||||
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
label_pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
def run_ppo(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=tokenizer.pad_token_id)
|
||||
|
||||
ppo_config = PPOConfig(
|
||||
model_name=model_args.model_name_or_path,
|
||||
|
@ -72,12 +66,3 @@ def main():
|
|||
ppo_trainer.save_state() # must be after save_model
|
||||
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.tuner.pt.workflow import run_pt
|
|
@ -1,31 +1,28 @@
|
|||
# coding=utf-8
|
||||
# Implements several parameter-efficient pre-training method.
|
||||
# This code is inspired by
|
||||
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
|
||||
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
|
||||
|
||||
import math
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from utils.other import IGNORE_INDEX
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
|
||||
|
||||
from utils import (
|
||||
PeftTrainer,
|
||||
LogCallback,
|
||||
load_pretrained,
|
||||
prepare_args,
|
||||
prepare_data,
|
||||
preprocess_data,
|
||||
plot_loss
|
||||
)
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Prepare pretrained model and dataset
|
||||
model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")
|
||||
dataset = prepare_data(model_args, data_args)
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
|
||||
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
|
||||
def run_pt(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
|
@ -48,7 +45,7 @@ def main():
|
|||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=[LogCallback()],
|
||||
callbacks=callbacks,
|
||||
**trainer_kwargs
|
||||
)
|
||||
|
||||
|
@ -65,21 +62,12 @@ def main():
|
|||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
|
||||
try:
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
|
||||
metrics["perplexity"] = perplexity
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.tuner.rm.workflow import run_rm
|
|
@ -0,0 +1,19 @@
|
|||
import torch
|
||||
from typing import Any, Dict, Sequence
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
|
||||
return super().__call__(features)
|
|
@ -0,0 +1,7 @@
|
|||
import numpy as np
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
|
||||
|
||||
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||
preds, _ = eval_preds
|
||||
return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
|
|
@ -0,0 +1,38 @@
|
|||
import torch
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
|
||||
class PairwisePeftTrainer(PeftTrainer):
|
||||
r"""
|
||||
Inherits PeftTrainer to compute pairwise loss.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.can_return_loss = True # override property to return eval_loss
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
inputs: Dict[str, torch.Tensor],
|
||||
return_outputs: Optional[bool] = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
r"""
|
||||
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
|
||||
We use score on the EOS token to represent reward of the whole sentence.
|
||||
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
|
||||
Note that the first element will be removed from the output tuple.
|
||||
|
||||
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
|
||||
"""
|
||||
batch_size = inputs["input_ids"].size(0) // 2
|
||||
_, _, values = model(**inputs)
|
||||
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
|
||||
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
|
||||
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
|
|
@ -1,30 +1,28 @@
|
|||
# coding=utf-8
|
||||
# Implements parameter-efficient training of reward models.
|
||||
# This code is inspired by:
|
||||
# Inspired by:
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
|
||||
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from utils import (
|
||||
PairwiseDataCollatorWithPadding,
|
||||
PairwisePeftTrainer,
|
||||
LogCallback,
|
||||
load_pretrained,
|
||||
prepare_args,
|
||||
prepare_data,
|
||||
preprocess_data,
|
||||
compute_accuracy,
|
||||
plot_loss
|
||||
)
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.rm.metric import compute_accuracy
|
||||
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
|
||||
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Prepare pretrained model and dataset
|
||||
model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
|
||||
dataset = prepare_data(model_args, data_args)
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
|
||||
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
def run_rm(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer)
|
||||
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
|
@ -66,12 +64,3 @@ def main():
|
|||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.tuner.sft.workflow import run_sft
|
|
@ -0,0 +1,51 @@
|
|||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
import jieba
|
||||
from rouge_chinese import Rouge
|
||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputeMetrics:
|
||||
r"""
|
||||
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizer
|
||||
|
||||
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||
r"""
|
||||
Uses the model predictions to compute metrics.
|
||||
"""
|
||||
preds, labels = eval_preds
|
||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
|
||||
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
|
||||
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
|
||||
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
|
||||
for pred, label in zip(decoded_preds, decoded_labels):
|
||||
hypothesis = list(jieba.cut(pred))
|
||||
reference = list(jieba.cut(label))
|
||||
|
||||
if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
|
||||
result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
|
||||
else:
|
||||
rouge = Rouge()
|
||||
scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
|
||||
result = scores[0]
|
||||
|
||||
for k, v in result.items():
|
||||
score_dict[k].append(round(v["f"] * 100, 4))
|
||||
|
||||
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
|
||||
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
|
||||
|
||||
return {k: float(np.mean(v)) for k, v in score_dict.items()}
|
|
@ -3,65 +3,17 @@ import json
|
|||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from transformers.trainer import PredictionOutput
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
import jieba
|
||||
from rouge_chinese import Rouge
|
||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||
|
||||
from .peft_trainer import PeftTrainer
|
||||
|
||||
from .other import get_logger, IGNORE_INDEX
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputeMetrics:
|
||||
r"""
|
||||
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizer
|
||||
|
||||
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||
r"""
|
||||
Uses the model predictions to compute metrics.
|
||||
"""
|
||||
preds, labels = eval_preds
|
||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
|
||||
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
|
||||
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
|
||||
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
|
||||
for pred, label in zip(decoded_preds, decoded_labels):
|
||||
hypothesis = list(jieba.cut(pred))
|
||||
reference = list(jieba.cut(label))
|
||||
|
||||
if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
|
||||
result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
|
||||
else:
|
||||
rouge = Rouge()
|
||||
scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
|
||||
result = scores[0]
|
||||
|
||||
for k, v in result.items():
|
||||
score_dict[k].append(round(v["f"] * 100, 4))
|
||||
|
||||
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
|
||||
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
|
||||
|
||||
return {k: float(np.mean(v)) for k, v in score_dict.items()}
|
||||
|
||||
|
||||
class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
r"""
|
||||
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
|
||||
|
@ -80,7 +32,10 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
|||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
|
||||
inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
|
||||
if self.tokenizer.padding_side == "right": # pads the labels to the same length as the inputs
|
||||
inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
|
||||
else:
|
||||
inputs["labels"] = torch.cat((torch.zeros_like(inputs["input_ids"])[:, label_len:], inputs["labels"]), dim=-1)
|
||||
loss, generated_tokens, labels = super().prediction_step(
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
||||
)
|
||||
|
@ -89,8 +44,8 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
|||
return (loss, generated_tokens, labels)
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
predict_results: PredictionOutput
|
||||
self,
|
||||
predict_results: PredictionOutput
|
||||
) -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
|
@ -1,31 +1,29 @@
|
|||
# coding=utf-8
|
||||
# Implements several parameter-efficient supervised fine-tuning method.
|
||||
# This code is inspired by
|
||||
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
|
||||
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
|
||||
|
||||
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from utils.other import IGNORE_INDEX
|
||||
from utils import (
|
||||
Seq2SeqPeftTrainer,
|
||||
ComputeMetrics,
|
||||
LogCallback,
|
||||
load_pretrained,
|
||||
prepare_args,
|
||||
prepare_data,
|
||||
preprocess_data,
|
||||
get_logits_processor,
|
||||
plot_loss
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Prepare pretrained model and dataset
|
||||
model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")
|
||||
dataset = prepare_data(model_args, data_args)
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
|
||||
def run_sft(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
|
@ -54,7 +52,7 @@ def main():
|
|||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=[LogCallback()],
|
||||
callbacks=callbacks,
|
||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
|
||||
**trainer_kwargs
|
||||
)
|
||||
|
@ -94,12 +92,3 @@ def main():
|
|||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(predict_results)
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,23 @@
|
|||
from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
|
||||
|
||||
|
||||
def main():
|
||||
model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
|
||||
|
||||
if general_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args)
|
||||
elif general_args.stage == "sft":
|
||||
run_sft(model_args, data_args, training_args, finetuning_args)
|
||||
elif general_args.stage == "rm":
|
||||
run_rm(model_args, data_args, training_args, finetuning_args)
|
||||
elif general_args.stage == "ppo":
|
||||
run_ppo(model_args, data_args, training_args, finetuning_args)
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,17 +0,0 @@
|
|||
from .common import (
|
||||
load_pretrained,
|
||||
prepare_args,
|
||||
prepare_infer_args,
|
||||
prepare_data,
|
||||
preprocess_data
|
||||
)
|
||||
|
||||
from .peft_trainer import PeftTrainer, LogCallback
|
||||
|
||||
from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
|
||||
from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer, compute_accuracy
|
||||
from .ppo import PPOPeftTrainer
|
||||
|
||||
from .template import Template
|
||||
|
||||
from .other import get_logits_processor, plot_loss
|
|
@ -1,619 +0,0 @@
|
|||
import os
|
||||
import sys
|
||||
import torch
|
||||
import hashlib
|
||||
from itertools import chain
|
||||
from typing import List, Literal, Optional, Tuple
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
Seq2SeqTrainingArguments,
|
||||
BitsAndBytesConfig
|
||||
)
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
import datasets
|
||||
from datasets import Dataset, concatenate_datasets, load_dataset
|
||||
|
||||
from peft import (
|
||||
PeftModel,
|
||||
TaskType,
|
||||
LoraConfig,
|
||||
get_peft_model
|
||||
)
|
||||
|
||||
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from .config import (
|
||||
ModelArguments,
|
||||
DataTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
)
|
||||
|
||||
from .template import Template
|
||||
|
||||
from .other import (
|
||||
get_logger,
|
||||
load_trainable_params,
|
||||
load_valuehead_params,
|
||||
print_trainable_params,
|
||||
prepare_model_for_training,
|
||||
IGNORE_INDEX
|
||||
)
|
||||
|
||||
check_min_version("4.29.1")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
|
||||
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
|
||||
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _init_adapter(
|
||||
model: PreTrainedModel,
|
||||
model_args: ModelArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
is_trainable: bool,
|
||||
is_mergeable: bool
|
||||
) -> PreTrainedModel:
|
||||
r"""
|
||||
Initializes the adapters.
|
||||
|
||||
Support full-parameter, freeze and LoRA training.
|
||||
|
||||
Note that the trainable parameters must be cast to float32.
|
||||
"""
|
||||
|
||||
if finetuning_args.finetuning_type == "none" and is_trainable:
|
||||
raise ValueError("You cannot use finetuning_type=none while training.")
|
||||
|
||||
if finetuning_args.finetuning_type == "full":
|
||||
logger.info("Fine-tuning method: Full")
|
||||
model = model.float()
|
||||
|
||||
if finetuning_args.finetuning_type == "freeze":
|
||||
logger.info("Fine-tuning method: Freeze")
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
|
||||
param.requires_grad_(False)
|
||||
else:
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
latest_checkpoint = None
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
|
||||
"Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
|
||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
|
||||
"The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
|
||||
|
||||
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
|
||||
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
else:
|
||||
checkpoints_to_merge = model_args.checkpoint_dir
|
||||
|
||||
for checkpoint in checkpoints_to_merge:
|
||||
model = PeftModel.from_pretrained(model, checkpoint)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if len(checkpoints_to_merge) > 0:
|
||||
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
|
||||
|
||||
if latest_checkpoint is not None: # resume lora training or quantized inference
|
||||
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
|
||||
|
||||
if is_trainable and latest_checkpoint is None: # create new lora weights while training
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
r=finetuning_args.lora_rank,
|
||||
lora_alpha=finetuning_args.lora_alpha,
|
||||
lora_dropout=finetuning_args.lora_dropout,
|
||||
target_modules=finetuning_args.lora_target
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_pretrained(
|
||||
model_args: ModelArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
is_trainable: Optional[bool] = False,
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
|
||||
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
r"""
|
||||
Loads pretrained model and tokenizer.
|
||||
|
||||
Support both training and inference.
|
||||
"""
|
||||
if (not is_trainable) and model_args.checkpoint_dir is None:
|
||||
logger.warning("Checkpoint is not found at evaluation, load the original model.")
|
||||
finetuning_args = FinetuningArguments(finetuning_type="none")
|
||||
|
||||
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
|
||||
"RM and PPO training can only be performed with the LoRA method."
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
padding_side=model_args.padding_side,
|
||||
**config_kwargs
|
||||
)
|
||||
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
|
||||
tokenizer.pad_token_id = 0 # set as the <unk> token
|
||||
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
is_mergeable = True
|
||||
|
||||
# Quantization configurations (using bitsandbytes library).
|
||||
if model_args.quantization_bit is not None:
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
config_kwargs["load_in_8bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_threshold=6.0
|
||||
)
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
|
||||
require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
|
||||
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
|
||||
config_kwargs["load_in_4bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
is_mergeable = False
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
if not is_trainable: # `device_map=auto` should be used for inference only
|
||||
config_kwargs["device_map"] = "auto"
|
||||
|
||||
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
else:
|
||||
model_to_load = model_args.model_name_or_path
|
||||
|
||||
# Load and prepare pretrained models (without valuehead).
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_to_load,
|
||||
config=config,
|
||||
torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
**config_kwargs
|
||||
)
|
||||
|
||||
# Register auto class to save the custom code files.
|
||||
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
|
||||
config.__class__.register_for_auto_class()
|
||||
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
|
||||
tokenizer.__class__.register_for_auto_class()
|
||||
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
|
||||
model.__class__.register_for_auto_class()
|
||||
|
||||
# Initialize adapters
|
||||
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
|
||||
model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
||||
|
||||
if stage == "rm" or stage == "ppo": # add value head
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
|
||||
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
"summary.bias": getattr(model, "reward_head_bias")
|
||||
})
|
||||
|
||||
if stage == "ppo": # load reward model
|
||||
assert is_trainable, "PPO stage cannot be performed at evaluation."
|
||||
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
|
||||
logger.info("Load reward model from {}".format(model_args.reward_model))
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
|
||||
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
|
||||
|
||||
print_trainable_params(model)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def prepare_args(
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
|
||||
model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
|
||||
data_args.init_for_training()
|
||||
|
||||
assert stage == "sft" or (not training_args.predict_with_generate), \
|
||||
"`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
|
||||
|
||||
assert not (training_args.do_train and training_args.predict_with_generate), \
|
||||
"`predict_with_generate` cannot be set as True while training."
|
||||
|
||||
assert (not training_args.do_predict) or training_args.predict_with_generate, \
|
||||
"Please enable `predict_with_generate` to save model predictions."
|
||||
|
||||
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
|
||||
"Quantization is only compatible with the LoRA method."
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
else:
|
||||
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
|
||||
"Quantized model only accepts a single checkpoint."
|
||||
|
||||
if model_args.quantization_bit is not None and (not training_args.do_train):
|
||||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
if training_args.do_train and (not training_args.fp16):
|
||||
logger.warning("We recommend enable fp16 mixed precision training.")
|
||||
|
||||
if data_args.prompt_template == "default":
|
||||
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
|
||||
|
||||
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
|
||||
training_args.ddp_find_unused_parameters = False
|
||||
|
||||
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
if training_args.fp16:
|
||||
model_args.compute_dtype = torch.float16
|
||||
elif training_args.bf16:
|
||||
model_args.compute_dtype = torch.bfloat16
|
||||
else:
|
||||
model_args.compute_dtype = torch.float32
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
|
||||
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set seed before initializing model.
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args
|
||||
|
||||
|
||||
def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]:
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments))
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
|
||||
"Quantization is only compatible with the LoRA method."
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
else:
|
||||
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
|
||||
"Quantized model only accepts a single checkpoint."
|
||||
|
||||
if data_args.prompt_template == "default":
|
||||
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def prepare_data(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataTrainingArguments
|
||||
) -> Dataset:
|
||||
|
||||
def checksum(file_path, hash):
|
||||
with open(file_path, "rb") as datafile:
|
||||
binary_data = datafile.read()
|
||||
sha1 = hashlib.sha1(binary_data).hexdigest()
|
||||
if sha1 != hash:
|
||||
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
|
||||
|
||||
ext2type = {
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json",
|
||||
"txt": "text"
|
||||
}
|
||||
|
||||
max_samples = data_args.max_samples
|
||||
all_datasets: List[Dataset] = [] # support multiple datasets
|
||||
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
|
||||
if dataset_attr.load_from == "hf_hub":
|
||||
data_path = dataset_attr.dataset_name
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "script":
|
||||
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_path = None
|
||||
data_files: List[str] = []
|
||||
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
|
||||
|
||||
if data_path is None:
|
||||
data_path = ext2type.get(data_files[0].split(".")[-1], None)
|
||||
else:
|
||||
assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
|
||||
data_path = ext2type.get(data_files[0].split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
|
||||
assert data_path, "File extension must be txt, csv, json or jsonl."
|
||||
|
||||
if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
|
||||
checksum(data_files[0], dataset_attr.dataset_sha1)
|
||||
else:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
raw_datasets = load_dataset(
|
||||
data_path,
|
||||
data_files=data_files,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None
|
||||
)
|
||||
dataset = raw_datasets[data_args.split]
|
||||
|
||||
if max_samples is not None:
|
||||
max_samples_temp = min(len(dataset), max_samples)
|
||||
dataset = dataset.select(range(max_samples_temp))
|
||||
|
||||
dummy_data = [None] * len(dataset)
|
||||
prefix_data = [dataset_attr.source_prefix] * len(dataset)
|
||||
for column_name, target_name in [
|
||||
("prompt_column", "prompt"),
|
||||
("query_column", "query"),
|
||||
("response_column", "response"),
|
||||
("history_column", "history")
|
||||
]: # every dataset will have 4 columns same as each other
|
||||
if getattr(dataset_attr, column_name) != target_name:
|
||||
if getattr(dataset_attr, column_name):
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
|
||||
else: # None or empty string
|
||||
dataset = dataset.add_column(target_name, dummy_data)
|
||||
dataset = dataset.add_column("prefix", prefix_data)
|
||||
all_datasets.append(dataset)
|
||||
|
||||
if len(data_args.dataset_list) == 1:
|
||||
all_datasets = all_datasets[0]
|
||||
else:
|
||||
all_datasets = concatenate_datasets(all_datasets)
|
||||
|
||||
return all_datasets
|
||||
|
||||
|
||||
def preprocess_data(
|
||||
dataset: Dataset,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
data_args: DataTrainingArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> Dataset:
|
||||
|
||||
column_names = list(dataset.column_names)
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
|
||||
# support question with a single answer or multiple answers
|
||||
def get_dialog(examples):
|
||||
for i in range(len(examples["prompt"])):
|
||||
if examples["prompt"][i] and examples["response"][i]:
|
||||
query, answer = examples["prompt"][i], examples["response"][i]
|
||||
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
|
||||
prefix = examples["prefix"][i] if examples["prefix"][i] else ""
|
||||
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
|
||||
yield dialog
|
||||
|
||||
def preprocess_pretrain_dataset(examples):
|
||||
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
|
||||
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
|
||||
concatenated_ids = list(chain(*text_ids))
|
||||
total_length = len(concatenated_ids)
|
||||
block_size = data_args.max_source_length - 1
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of max_source_length
|
||||
result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
|
||||
for i in range(0, total_length, block_size)]
|
||||
return {
|
||||
"input_ids": result,
|
||||
"labels": result.copy()
|
||||
}
|
||||
|
||||
def preprocess_supervised_dataset(examples):
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for input with history, we build multiple input-label pairs just like:
|
||||
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
max_length = data_args.max_source_length + data_args.max_target_length
|
||||
|
||||
for dialog in get_dialog(examples):
|
||||
input_ids, labels = [], []
|
||||
|
||||
for i in range(len(dialog) // 2):
|
||||
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
|
||||
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(target_ids) > data_args.max_target_length - 1: # eos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
|
||||
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
|
||||
break
|
||||
|
||||
input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
|
||||
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples):
|
||||
# build inputs with format `<bos> X` and labels with format `<bos> Y`
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
||||
target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(target_ids) > data_args.max_target_length:
|
||||
target_ids = target_ids[:data_args.max_target_length]
|
||||
|
||||
model_inputs["input_ids"].append(source_ids)
|
||||
model_inputs["labels"].append(target_ids)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_pairwise_dataset(examples):
|
||||
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
|
||||
model_inputs = {"accept_ids": [], "reject_ids": []}
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
||||
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
|
||||
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(accept_ids) > data_args.max_target_length - 1: # eos token
|
||||
accept_ids = accept_ids[:data_args.max_target_length - 1]
|
||||
if len(reject_ids) > data_args.max_target_length - 1: # eos token
|
||||
reject_ids = reject_ids[:data_args.max_target_length - 1]
|
||||
|
||||
accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
|
||||
reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["accept_ids"].append(accept_ids)
|
||||
model_inputs["reject_ids"].append(reject_ids)
|
||||
return model_inputs
|
||||
|
||||
def print_supervised_dataset_example(example):
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(
|
||||
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
|
||||
skip_special_tokens=False)
|
||||
))
|
||||
|
||||
def print_pairwise_dataset_example(example):
|
||||
print("accept_ids:\n{}".format(example["accept_ids"]))
|
||||
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
|
||||
print("reject_ids:\n{}".format(example["reject_ids"]))
|
||||
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
|
||||
|
||||
def print_unsupervised_dataset_example(example):
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
if stage == "pt":
|
||||
preprocess_function = preprocess_pretrain_dataset
|
||||
elif stage == "sft":
|
||||
preprocess_function = preprocess_unsupervised_dataset \
|
||||
if training_args.predict_with_generate else preprocess_supervised_dataset
|
||||
elif stage == "rm":
|
||||
preprocess_function = preprocess_pairwise_dataset
|
||||
elif stage == "ppo":
|
||||
preprocess_function = preprocess_unsupervised_dataset
|
||||
|
||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||
dataset = dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset"
|
||||
)
|
||||
|
||||
if stage == "pt":
|
||||
print_unsupervised_dataset_example(dataset[0])
|
||||
elif stage == "sft":
|
||||
print_supervised_dataset_example(dataset[0])
|
||||
elif stage == "rm":
|
||||
print_pairwise_dataset_example(dataset[0])
|
||||
elif stage == "ppo":
|
||||
print_unsupervised_dataset_example(dataset[0])
|
||||
|
||||
return dataset
|
|
@ -1,312 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetAttr:
|
||||
|
||||
load_from: str
|
||||
dataset_name: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
source_prefix: Optional[str] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
||||
def __post_init__(self):
|
||||
self.prompt_column = "instruction"
|
||||
self.query_column = "input"
|
||||
self.response_column = "output"
|
||||
self.history_column = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
||||
"""
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
|
||||
)
|
||||
use_fast_tokenizer: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
|
||||
)
|
||||
use_auth_token: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
|
||||
)
|
||||
model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
|
||||
)
|
||||
padding_side: Optional[Literal["left", "right"]] = field(
|
||||
default="left",
|
||||
metadata={"help": "The side on which the model should have padding applied."}
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model."}
|
||||
)
|
||||
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
|
||||
default="nf4",
|
||||
metadata={"help": "Quantization data type to use in int4 training."}
|
||||
)
|
||||
double_quantization: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use double quantization in int4 training or not."}
|
||||
)
|
||||
compute_dtype: Optional[torch.dtype] = field(
|
||||
default=None,
|
||||
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
|
||||
)
|
||||
checkpoint_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||
)
|
||||
resume_lora_training: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
||||
)
|
||||
plot_loss: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
|
||||
if self.quantization_bit is not None:
|
||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||
"""
|
||||
dataset: Optional[str] = field(
|
||||
default="alpaca_zh",
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
|
||||
)
|
||||
dataset_dir: Optional[str] = field(
|
||||
default="data",
|
||||
metadata={"help": "The name of the folder containing datasets."}
|
||||
)
|
||||
split: Optional[str] = field(
|
||||
default="train",
|
||||
metadata={"help": "Which dataset split to use for training and evaluation."}
|
||||
)
|
||||
overwrite_cache: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Overwrite the cached training and evaluation sets."}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."}
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum total input sequence length after tokenization."}
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum total output sequence length after tokenization."}
|
||||
)
|
||||
max_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
||||
)
|
||||
eval_num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
|
||||
)
|
||||
ignore_pad_token_for_loss: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
|
||||
)
|
||||
dev_ratio: Optional[float] = field(
|
||||
default=0,
|
||||
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
|
||||
)
|
||||
prompt_template: Optional[str] = field(
|
||||
default="default",
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
)
|
||||
|
||||
def init_for_training(self): # support mixing multiple datasets
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
if self.source_prefix is not None:
|
||||
prefix_list = self.source_prefix.split("|")
|
||||
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
|
||||
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
|
||||
else:
|
||||
prefix_list = [None] * len(dataset_names)
|
||||
|
||||
self.dataset_list: List[DatasetAttr] = []
|
||||
for i, name in enumerate(dataset_names):
|
||||
if name not in dataset_info:
|
||||
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
|
||||
|
||||
if "hf_hub_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||
elif "script_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"file",
|
||||
dataset_name=dataset_info[name]["file_name"],
|
||||
dataset_sha1=dataset_info[name].get("file_sha1", None)
|
||||
)
|
||||
|
||||
dataset_attr.source_prefix = prefix_list[i]
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
|
||||
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
|
||||
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
|
||||
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
|
||||
|
||||
self.dataset_list.append(dataset_attr)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments:
|
||||
"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
|
||||
default="lora",
|
||||
metadata={"help": "Which fine-tuning method to use."}
|
||||
)
|
||||
num_hidden_layers: Optional[int] = field(
|
||||
default=32,
|
||||
metadata={"help": "Number of decoder blocks in the model. \
|
||||
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
|
||||
BLOOM choices: [\"24\", \"30\", \"70\"], \
|
||||
Falcon choices: [\"32\", \"60\"], \
|
||||
Baichuan choices: [\"32\"]"}
|
||||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
|
||||
)
|
||||
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
|
||||
default="mlp",
|
||||
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
|
||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
|
||||
Baichuan choices: [\"mlp\", \"self_attn\"]"}
|
||||
)
|
||||
lora_rank: Optional[int] = field(
|
||||
default=8,
|
||||
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
|
||||
)
|
||||
lora_alpha: Optional[float] = field(
|
||||
default=32.0,
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
|
||||
)
|
||||
lora_dropout: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
|
||||
)
|
||||
lora_target: Optional[str] = field(
|
||||
default="q_proj,v_proj",
|
||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
|
||||
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
|
||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
|
||||
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
|
||||
|
||||
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||
trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
|
||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
|
||||
|
||||
self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
|
||||
|
||||
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
f.write(json_string)
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_path: str):
|
||||
"""Creates an instance from the content of `json_path`."""
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
return cls(**json.loads(text))
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratingArguments:
|
||||
"""
|
||||
Arguments pertaining to specify the decoding parameters.
|
||||
"""
|
||||
do_sample: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
|
||||
)
|
||||
temperature: Optional[float] = field(
|
||||
default=0.95,
|
||||
metadata={"help": "The value used to modulate the next token probabilities."}
|
||||
)
|
||||
top_p: Optional[float] = field(
|
||||
default=0.7,
|
||||
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
|
||||
)
|
||||
top_k: Optional[int] = field(
|
||||
default=50,
|
||||
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
||||
)
|
||||
max_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
|
||||
)
|
||||
repetition_penalty: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
|
||||
)
|
||||
length_penalty: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
args = asdict(self)
|
||||
if args.get("max_new_tokens", None):
|
||||
args.pop("max_length", None)
|
||||
return args
|
|
@ -1,197 +0,0 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
import torch
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint
|
||||
from transformers.generation.utils import LogitsProcessorList
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
VALUE_HEAD_FILE_NAME = "value_head.bin"
|
||||
FINETUNING_ARGS_NAME = "finetuning_args.json"
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
handlers=[logging.StreamHandler(sys.stdout)]
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
r"""
|
||||
Computes and stores the average and current value.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
# Avoid runtime error in model.generate(do_sample=True).
|
||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
||||
scores.zero_()
|
||||
scores[..., 0] = 1.0
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_processor() -> LogitsProcessorList:
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InvalidScoreLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
|
||||
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
|
||||
def prepare_model_for_training(
|
||||
model: PreTrainedModel,
|
||||
finetuning_type: str,
|
||||
output_embedding_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layer_norm_names: Optional[List[str]] = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
|
||||
) -> PreTrainedModel:
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if use_gradient_checkpointing:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
|
||||
if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
|
||||
output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
|
||||
input_dtype = output_embedding_layer.weight.dtype
|
||||
|
||||
class CastOutputToFloat(torch.nn.Sequential):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(x.to(input_dtype)).to(torch.float32)
|
||||
|
||||
setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def print_trainable_params(model: torch.nn.Module) -> None:
|
||||
trainable_params, all_param = 0, 0
|
||||
for param in model.parameters():
|
||||
num_params = param.numel()
|
||||
# if using DS Zero 3 and the weights are initialized empty
|
||||
if num_params == 0 and hasattr(param, "ds_numel"):
|
||||
num_params = param.ds_numel
|
||||
all_param += num_params
|
||||
if param.requires_grad:
|
||||
trainable_params += num_params
|
||||
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param))
|
||||
|
||||
|
||||
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
|
||||
state_dict = model.state_dict()
|
||||
filtered_state_dict = {}
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
|
||||
|
||||
return filtered_state_dict
|
||||
|
||||
|
||||
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
|
||||
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
|
||||
if os.path.exists(weights_file):
|
||||
model_state_dict = torch.load(weights_file, map_location="cpu")
|
||||
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
|
||||
elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
|
||||
load_sharded_checkpoint(model, checkpoint_dir, strict=False)
|
||||
else:
|
||||
logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
|
||||
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
|
||||
if not os.path.exists(valuehead_file):
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
|
||||
return False
|
||||
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
|
||||
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
|
||||
return True
|
||||
|
||||
|
||||
def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
|
||||
r"""
|
||||
EMA implementation according to TensorBoard.
|
||||
"""
|
||||
last = scalars[0]
|
||||
smoothed = list()
|
||||
for next_val in scalars:
|
||||
smoothed_val = last * weight + (1 - weight) * next_val
|
||||
smoothed.append(smoothed_val)
|
||||
last = smoothed_val
|
||||
return smoothed
|
||||
|
||||
|
||||
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
|
||||
import matplotlib.pyplot as plt
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for key in keys:
|
||||
steps, metrics = [], []
|
||||
for i in range(len(data["log_history"])):
|
||||
if key in data["log_history"][i]:
|
||||
steps.append(data["log_history"][i]["step"])
|
||||
metrics.append(data["log_history"][i][key])
|
||||
|
||||
if len(metrics) == 0:
|
||||
logger.warning(f"No metric {key} to plot.")
|
||||
continue
|
||||
|
||||
plt.figure()
|
||||
plt.plot(steps, metrics, alpha=0.4, label="original")
|
||||
plt.plot(steps, smooth(metrics), label="smoothed")
|
||||
plt.title("training {} of {}".format(key, save_dictionary))
|
||||
plt.xlabel("step")
|
||||
plt.ylabel(key)
|
||||
plt.legend()
|
||||
plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
|
||||
print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
|
|
@ -1,60 +0,0 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
from .peft_trainer import PeftTrainer
|
||||
|
||||
from .other import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||
preds, _ = eval_preds
|
||||
return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
|
||||
|
||||
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
|
||||
return super().__call__(features)
|
||||
|
||||
|
||||
class PairwisePeftTrainer(PeftTrainer):
|
||||
r"""
|
||||
Inherits PeftTrainer to compute pairwise loss.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.can_return_loss = True # override property to return eval_loss
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
r"""
|
||||
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
|
||||
We use score on the EOS token to represent reward of the whole sentence.
|
||||
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
|
||||
Note that the first element will be removed from the output tuple.
|
||||
|
||||
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
|
||||
"""
|
||||
batch_size = inputs["input_ids"].size(0) // 2
|
||||
_, _, values = model(**inputs)
|
||||
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
|
||||
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
|
||||
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
|
|
@ -2,83 +2,26 @@
|
|||
# Implements user interface in browser for fine-tuned models.
|
||||
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
|
||||
|
||||
|
||||
import mdtex2html
|
||||
import gradio as gr
|
||||
|
||||
from threading import Thread
|
||||
from utils import (
|
||||
Template,
|
||||
load_pretrained,
|
||||
prepare_infer_args,
|
||||
get_logits_processor
|
||||
)
|
||||
|
||||
from transformers import TextIteratorStreamer
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
|
||||
|
||||
|
||||
require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
|
||||
|
||||
|
||||
model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
model_args, data_args, finetuning_args, generating_args = get_infer_args()
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
|
||||
|
||||
|
||||
def postprocess(self, y):
|
||||
r"""
|
||||
Overrides Chatbot.postprocess
|
||||
"""
|
||||
if y is None:
|
||||
return []
|
||||
for i, (message, response) in enumerate(y):
|
||||
y[i] = (
|
||||
None if message is None else mdtex2html.convert((message)),
|
||||
None if response is None else mdtex2html.convert(response),
|
||||
)
|
||||
return y
|
||||
|
||||
|
||||
gr.Chatbot.postprocess = postprocess
|
||||
|
||||
|
||||
def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
|
||||
lines = text.split("\n")
|
||||
lines = [line for line in lines if line != ""]
|
||||
count = 0
|
||||
for i, line in enumerate(lines):
|
||||
if "```" in line:
|
||||
count += 1
|
||||
items = line.split("`")
|
||||
if count % 2 == 1:
|
||||
lines[i] = "<pre><code class=\"language-{}\">".format(items[-1])
|
||||
else:
|
||||
lines[i] = "<br /></code></pre>"
|
||||
else:
|
||||
if i > 0:
|
||||
if count % 2 == 1:
|
||||
line = line.replace("`", "\`")
|
||||
line = line.replace("<", "<")
|
||||
line = line.replace(">", ">")
|
||||
line = line.replace(" ", " ")
|
||||
line = line.replace("*", "*")
|
||||
line = line.replace("_", "_")
|
||||
line = line.replace("-", "-")
|
||||
line = line.replace(".", ".")
|
||||
line = line.replace("!", "!")
|
||||
line = line.replace("(", "(")
|
||||
line = line.replace(")", ")")
|
||||
line = line.replace("$", "$")
|
||||
lines[i] = "<br />" + line
|
||||
text = "".join(lines)
|
||||
return text
|
||||
|
||||
|
||||
def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
|
||||
chatbot.append((parse_text(query), ""))
|
||||
chatbot.append((query, ""))
|
||||
|
||||
input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
|
||||
input_ids = input_ids.to(model.device)
|
||||
|
@ -102,7 +45,7 @@ def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
|
|||
for new_text in streamer:
|
||||
response += new_text
|
||||
new_history = history + [(query, response)]
|
||||
chatbot[-1] = (parse_text(query), parse_text(response))
|
||||
chatbot[-1] = (query, response)
|
||||
yield chatbot, new_history
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue