simplify code

This commit is contained in:
hiyouga 2023-07-20 15:08:57 +08:00
parent d1d8e8bae1
commit 67a2773074
18 changed files with 52 additions and 136 deletions

View File

@ -5,9 +5,13 @@
import uvicorn import uvicorn
from llmtuner import create_app from llmtuner.api.app import create_app
def main():
app = create_app()
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
if __name__ == "__main__": if __name__ == "__main__":
app = create_app() main()
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)

View File

@ -2,7 +2,8 @@
# Implements stream chat in command line for fine-tuned models. # Implements stream chat in command line for fine-tuned models.
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
from llmtuner import ChatModel, get_infer_args from llmtuner import ChatModel
from llmtuner.tuner import get_infer_args
def main(): def main():

View File

@ -2,7 +2,7 @@
# Exports the fine-tuned model. # Exports the fine-tuned model.
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
from llmtuner import get_train_args, load_model_and_tokenizer from llmtuner.tuner import get_train_args, load_model_and_tokenizer
def main(): def main():

View File

@ -1,7 +1,4 @@
from llmtuner.api import create_app
from llmtuner.chat import ChatModel from llmtuner.chat import ChatModel
from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo
from llmtuner.webui import create_ui
__version__ = "0.1.1" __version__ = "0.1.1"

View File

@ -1 +0,0 @@
from llmtuner.api.app import create_app

View File

@ -4,7 +4,7 @@ from threading import Thread
from transformers import TextIteratorStreamer from transformers import TextIteratorStreamer
from llmtuner.extras.misc import get_logits_processor from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.template import Template from llmtuner.extras.template import get_template
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from llmtuner.tuner import load_model_and_tokenizer from llmtuner.tuner import load_model_and_tokenizer
@ -19,7 +19,7 @@ class ChatModel:
generating_args: GeneratingArguments generating_args: GeneratingArguments
) -> None: ) -> None:
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.template = Template(data_args.prompt_template) self.template = get_template(data_args.prompt_template)
self.source_prefix = data_args.source_prefix if data_args.source_prefix else "" self.source_prefix = data_args.source_prefix if data_args.source_prefix else ""
self.generating_args = generating_args self.generating_args = generating_args

View File

@ -1,2 +1,3 @@
from llmtuner.dsets.loader import get_dataset from llmtuner.dsets.loader import get_dataset
from llmtuner.dsets.preprocess import preprocess_dataset from llmtuner.dsets.preprocess import preprocess_dataset
from llmtuner.dsets.utils import split_dataset

View File

@ -1,63 +0,0 @@
import os
import json
import time
from datetime import timedelta
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments
)
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.start_time = time.time()
self.tracker = {}
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
r"""
Event called at the beginning of a training step. If using gradient accumulation, one training step
might take several inputs.
"""
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
r"""
Event called after logging the last logs.
"""
if "loss" not in state.log_history[-1]:
return
cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_steps = state.max_steps - cur_steps
remaining_time = remaining_steps * avg_time_per_step
self.tracker = {
"current_steps": cur_steps,
"total_steps": state.max_steps,
"loss": state.log_history[-1].get("loss", None),
"reward": state.log_history[-1].get("reward", None),
"learning_rate": state.log_history[-1].get("learning_rate", None),
"epoch": state.log_history[-1].get("epoch", None),
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
"remaining_time": str(timedelta(seconds=int(remaining_time)))
}
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(self.tracker) + "\n")

View File

@ -6,7 +6,7 @@ from transformers.tokenization_utils import PreTrainedTokenizer
from datasets import Dataset from datasets import Dataset
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import Template from llmtuner.extras.template import get_template
from llmtuner.hparams import DataArguments from llmtuner.hparams import DataArguments
@ -19,7 +19,7 @@ def preprocess_dataset(
) -> Dataset: ) -> Dataset:
column_names = list(dataset.column_names) column_names = list(dataset.column_names)
prompt_template = Template(data_args.prompt_template) prompt_template = get_template(data_args.prompt_template)
# support question with a single answer or multiple answers # support question with a single answer or multiple answers
def get_dialog(examples): def get_dialog(examples):

View File

@ -0,0 +1,16 @@
from typing import Dict
from datasets import Dataset
def split_dataset(
dataset: Dataset, dev_ratio: float, do_train: bool
) -> Dict[str, Dataset]:
# Split the dataset
if do_train:
if dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=dev_ratio)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}

View File

@ -3,30 +3,13 @@ from dataclasses import dataclass
@dataclass @dataclass
class Format: class Template:
prefix: str prefix: str
prompt: str prompt: str
sep: str sep: str
use_history: bool use_history: bool
templates: Dict[str, Format] = {}
@dataclass
class Template:
name: str
def __post_init__(self):
if self.name in templates:
self.prefix = templates[self.name].prefix
self.prompt = templates[self.name].prompt
self.sep = templates[self.name].sep
self.use_history = templates[self.name].use_history
else:
raise ValueError("Template {} does not exist.".format(self.name))
def get_prompt( def get_prompt(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = "" self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> str: ) -> str:
@ -61,8 +44,11 @@ class Template:
return convs[:-1] # drop last return convs[:-1] # drop last
templates: Dict[str, Template] = {}
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None: def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
templates[name] = Format( templates[name] = Template(
prefix=prefix, prefix=prefix,
prompt=prompt, prompt=prompt,
sep=sep, sep=sep,
@ -70,6 +56,12 @@ def register_template(name: str, prefix: str, prompt: str, sep: str, use_history
) )
def get_template(name: str) -> Template:
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
return template
r""" r"""
Supports language model inference without histories. Supports language model inference without histories.
""" """

View File

@ -4,7 +4,7 @@ import math
from typing import Optional, List from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss from llmtuner.extras.ploting import plot_loss
@ -28,16 +28,6 @@ def run_pt(
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
) )
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer # Initialize our Trainer
trainer = PeftTrainer( trainer = PeftTrainer(
finetuning_args=finetuning_args, finetuning_args=finetuning_args,
@ -46,7 +36,7 @@ def run_pt(
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
**trainer_kwargs **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
) )
# Training # Training

View File

@ -5,7 +5,7 @@
from typing import Optional, List from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
@ -29,16 +29,6 @@ def run_rm(
training_args.remove_unused_columns = False # important for pairwise dataset training_args.remove_unused_columns = False # important for pairwise dataset
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer # Initialize our Trainer
trainer = PairwisePeftTrainer( trainer = PairwisePeftTrainer(
finetuning_args=finetuning_args, finetuning_args=finetuning_args,
@ -48,7 +38,7 @@ def run_rm(
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
compute_metrics=compute_accuracy, compute_metrics=compute_accuracy,
**trainer_kwargs **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
) )
# Training # Training

View File

@ -3,7 +3,7 @@
from typing import Optional, List from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor from llmtuner.extras.misc import get_logits_processor
@ -35,16 +35,6 @@ def run_sft(
training_args.generation_num_beams = data_args.eval_num_beams if \ training_args.generation_num_beams = data_args.eval_num_beams if \
data_args.eval_num_beams is not None else training_args.generation_num_beams data_args.eval_num_beams is not None else training_args.generation_num_beams
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer # Initialize our Trainer
trainer = Seq2SeqPeftTrainer( trainer = Seq2SeqPeftTrainer(
finetuning_args=finetuning_args, finetuning_args=finetuning_args,
@ -54,7 +44,7 @@ def run_sft(
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**trainer_kwargs **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
) )
# Keyword arguments for `model.generate` # Keyword arguments for `model.generate`

View File

@ -1 +0,0 @@
from llmtuner.webui.interface import create_ui

View File

@ -1,4 +1,4 @@
from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
def main(): def main():

View File

@ -1,4 +1,4 @@
from llmtuner import create_ui from llmtuner.webui.interface import create_ui
def main(): def main():

View File

@ -5,7 +5,7 @@
import gradio as gr import gradio as gr
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from llmtuner import get_infer_args from llmtuner.tuner import get_infer_args
from llmtuner.webui.chat import WebChatModel from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box from llmtuner.webui.components.chatbot import create_chat_box
from llmtuner.webui.manager import Manager from llmtuner.webui.manager import Manager