simplify code

commit 67a2773074
parent d1d8e8bae1
@@ -5,9 +5,13 @@
 
 import uvicorn
 
-from llmtuner import create_app
+from llmtuner.api.app import create_app
 
 
+def main():
+    app = create_app()
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+
+
 if __name__ == "__main__":
-    app = create_app()
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    main()
@@ -2,7 +2,8 @@
 # Implements stream chat in command line for fine-tuned models.
 # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 
-from llmtuner import ChatModel, get_infer_args
+from llmtuner import ChatModel
+from llmtuner.tuner import get_infer_args
 
 
 def main():
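
Note: a minimal sketch of how the CLI demo can wire these two imports together after the split. The return value of get_infer_args() and the ChatModel argument order are assumptions inferred from the ChatModel hunk later in this diff, not something this hunk shows directly.

# Sketch only: assumes get_infer_args() yields the four argument objects
# that the ChatModel constructor (shown further down) expects, in this order.
from llmtuner import ChatModel
from llmtuner.tuner import get_infer_args

def main():
    model_args, data_args, finetuning_args, generating_args = get_infer_args()
    chat_model = ChatModel(model_args, data_args, finetuning_args, generating_args)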
@@ -2,7 +2,7 @@
 # Exports the fine-tuned model.
 # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
 
-from llmtuner import get_train_args, load_model_and_tokenizer
+from llmtuner.tuner import get_train_args, load_model_and_tokenizer
 
 
 def main():
@@ -1,7 +1,4 @@
-from llmtuner.api import create_app
 from llmtuner.chat import ChatModel
-from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo
-from llmtuner.webui import create_ui
 
 
 __version__ = "0.1.1"
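
Note: after this hunk only ChatModel is re-exported at the package root; the other entry points are imported from their subpackages, as the remaining hunks in this commit show. A summary sketch of the resulting import surface (paths collected from this diff):

# Import paths after this commit, collected from the hunks in this diff.
from llmtuner import ChatModel                                            # still at the package root
from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.api.app import create_app                                   # llmtuner.api no longer re-exports it
from llmtuner.webui.interface import create_ui                            # llmtuner.webui no longer re-exports it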
@@ -1 +0,0 @@
-from llmtuner.api.app import create_app
@@ -4,7 +4,7 @@ from threading import Thread
 from transformers import TextIteratorStreamer
 
 from llmtuner.extras.misc import get_logits_processor
-from llmtuner.extras.template import Template
+from llmtuner.extras.template import get_template
 from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 from llmtuner.tuner import load_model_and_tokenizer
 
@@ -19,7 +19,7 @@ class ChatModel:
         generating_args: GeneratingArguments
     ) -> None:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-        self.template = Template(data_args.prompt_template)
+        self.template = get_template(data_args.prompt_template)
         self.source_prefix = data_args.source_prefix if data_args.source_prefix else ""
         self.generating_args = generating_args
 
@@ -1,2 +1,3 @@
 from llmtuner.dsets.loader import get_dataset
 from llmtuner.dsets.preprocess import preprocess_dataset
+from llmtuner.dsets.utils import split_dataset
@@ -1,63 +0,0 @@
-import os
-import json
-import time
-from datetime import timedelta
-
-from transformers import (
-    TrainerCallback,
-    TrainerControl,
-    TrainerState,
-    TrainingArguments
-)
-
-
-class LogCallback(TrainerCallback):
-
-    def __init__(self, runner=None):
-        self.runner = runner
-        self.start_time = time.time()
-        self.tracker = {}
-
-    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
-        r"""
-        Event called at the beginning of a training step. If using gradient accumulation, one training step
-        might take several inputs.
-        """
-        if self.runner is not None and self.runner.aborted:
-            control.should_epoch_stop = True
-            control.should_training_stop = True
-
-    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
-        r"""
-        Event called at the end of an substep during gradient accumulation.
-        """
-        if self.runner is not None and self.runner.aborted:
-            control.should_epoch_stop = True
-            control.should_training_stop = True
-
-    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
-        r"""
-        Event called after logging the last logs.
-        """
-        if "loss" not in state.log_history[-1]:
-            return
-        cur_time = time.time()
-        cur_steps = state.log_history[-1].get("step")
-        elapsed_time = cur_time - self.start_time
-        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
-        remaining_steps = state.max_steps - cur_steps
-        remaining_time = remaining_steps * avg_time_per_step
-        self.tracker = {
-            "current_steps": cur_steps,
-            "total_steps": state.max_steps,
-            "loss": state.log_history[-1].get("loss", None),
-            "reward": state.log_history[-1].get("reward", None),
-            "learning_rate": state.log_history[-1].get("learning_rate", None),
-            "epoch": state.log_history[-1].get("epoch", None),
-            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
-            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
-            "remaining_time": str(timedelta(seconds=int(remaining_time)))
-        }
-        os.makedirs(args.output_dir, exist_ok=True)
-        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
-            f.write(json.dumps(self.tracker) + "\n")
@@ -6,7 +6,7 @@ from transformers.tokenization_utils import PreTrainedTokenizer
 from datasets import Dataset
 
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.template import Template
+from llmtuner.extras.template import get_template
 from llmtuner.hparams import DataArguments
 
 
@@ -19,7 +19,7 @@ def preprocess_dataset(
 ) -> Dataset:
 
     column_names = list(dataset.column_names)
-    prompt_template = Template(data_args.prompt_template)
+    prompt_template = get_template(data_args.prompt_template)
 
     # support question with a single answer or multiple answers
     def get_dialog(examples):
@@ -0,0 +1,16 @@
+from typing import Dict
+from datasets import Dataset
+
+
+def split_dataset(
+    dataset: Dataset, dev_ratio: float, do_train: bool
+) -> Dict[str, Dataset]:
+    # Split the dataset
+    if do_train:
+        if dev_ratio > 1e-6:
+            dataset = dataset.train_test_split(test_size=dev_ratio)
+            return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
+        else:
+            return {"train_dataset": dataset}
+    else: # do_eval or do_predict
+        return {"eval_dataset": dataset}
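
Note: a minimal sketch of how the new split_dataset helper behaves, using a toy dataset for illustration; the training workflows later in this diff splat its return value directly into their Trainer constructors.

# Sketch: maps (dataset, dev_ratio, do_train) to Trainer keyword arguments.
from datasets import Dataset
from llmtuner.dsets.utils import split_dataset

toy = Dataset.from_dict({"text": ["a", "b", "c", "d"]})

print(split_dataset(toy, dev_ratio=0.25, do_train=True))   # {"train_dataset": ..., "eval_dataset": ...}
print(split_dataset(toy, dev_ratio=0.0, do_train=True))    # {"train_dataset": toy}
print(split_dataset(toy, dev_ratio=0.25, do_train=False))  # {"eval_dataset": toy}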
@@ -3,30 +3,13 @@ from dataclasses import dataclass
 
 
 @dataclass
-class Format:
+class Template:
 
     prefix: str
     prompt: str
     sep: str
     use_history: bool
 
-
-templates: Dict[str, Format] = {}
-
-
-@dataclass
-class Template:
-
-    name: str
-
-    def __post_init__(self):
-        if self.name in templates:
-            self.prefix = templates[self.name].prefix
-            self.prompt = templates[self.name].prompt
-            self.sep = templates[self.name].sep
-            self.use_history = templates[self.name].use_history
-        else:
-            raise ValueError("Template {} does not exist.".format(self.name))
-
     def get_prompt(
         self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
     ) -> str:
@@ -61,8 +44,11 @@ class Template:
         return convs[:-1] # drop last
 
 
+templates: Dict[str, Template] = {}
+
+
 def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
-    templates[name] = Format(
+    templates[name] = Template(
         prefix=prefix,
         prompt=prompt,
         sep=sep,
@@ -70,6 +56,12 @@ def register_template(name: str, prefix: str, prompt: str, sep: str, use_history
     )
 
 
+def get_template(name: str) -> Template:
+    template = templates.get(name, None)
+    assert template is not None, "Template {} does not exist.".format(name)
+    return template
+
+
 r"""
 Supports language model inference without histories.
 """
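
Note: a short sketch of how template registration and lookup fit together after this refactor. The template name and field values below are placeholders for illustration, not entries registered by the library; the get_prompt call follows the signature kept in the hunk above.

# Sketch only: "example" is a hypothetical template name.
from llmtuner.extras.template import register_template, get_template

register_template(
    name="example",
    prefix="",
    prompt="Human: {query}\nAssistant: ",
    sep="\n",
    use_history=True
)

template = get_template("example")    # returns the Template dataclass registered above
text = template.get_prompt(query="Hello", history=None, prefix="")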
@@ -4,7 +4,7 @@ import math
 from typing import Optional, List
 from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
 
-from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
@@ -28,16 +28,6 @@ def run_pt(
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
 
-    # Split the dataset
-    if training_args.do_train:
-        if data_args.dev_ratio > 1e-6:
-            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
-            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
-        else:
-            trainer_kwargs = {"train_dataset": dataset}
-    else: # do_eval or do_predict
-        trainer_kwargs = {"eval_dataset": dataset}
-
     # Initialize our Trainer
     trainer = PeftTrainer(
         finetuning_args=finetuning_args,
@@ -46,7 +36,7 @@ def run_pt(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        **trainer_kwargs
+        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
     )
 
     # Training
@@ -5,7 +5,7 @@
 from typing import Optional, List
 from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
-from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
@@ -29,16 +29,6 @@ def run_rm(
 
     training_args.remove_unused_columns = False # important for pairwise dataset
 
-    # Split the dataset
-    if training_args.do_train:
-        if data_args.dev_ratio > 1e-6:
-            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
-            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
-        else:
-            trainer_kwargs = {"train_dataset": dataset}
-    else: # do_eval or do_predict
-        trainer_kwargs = {"eval_dataset": dataset}
-
     # Initialize our Trainer
     trainer = PairwisePeftTrainer(
         finetuning_args=finetuning_args,
@@ -48,7 +38,7 @@ def run_rm(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=compute_accuracy,
-        **trainer_kwargs
+        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
    )
 
     # Training
@@ -3,7 +3,7 @@
 from typing import Optional, List
 from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
 
-from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
@@ -35,16 +35,6 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams if \
         data_args.eval_num_beams is not None else training_args.generation_num_beams
 
-    # Split the dataset
-    if training_args.do_train:
-        if data_args.dev_ratio > 1e-6:
-            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
-            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
-        else:
-            trainer_kwargs = {"train_dataset": dataset}
-    else: # do_eval or do_predict
-        trainer_kwargs = {"eval_dataset": dataset}
-
     # Initialize our Trainer
     trainer = Seq2SeqPeftTrainer(
         finetuning_args=finetuning_args,
@@ -54,7 +44,7 @@ def run_sft(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
-        **trainer_kwargs
+        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
     )
 
     # Keyword arguments for `model.generate`
@@ -1 +0,0 @@
-from llmtuner.webui.interface import create_ui
@@ -1,4 +1,4 @@
-from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
+from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
 
 
 def main():
@@ -1,4 +1,4 @@
-from llmtuner import create_ui
+from llmtuner.webui.interface import create_ui
 
 
 def main():
@@ -5,7 +5,7 @@
 import gradio as gr
 from transformers.utils.versions import require_version
 
-from llmtuner import get_infer_args
+from llmtuner.tuner import get_infer_args
 from llmtuner.webui.chat import WebChatModel
 from llmtuner.webui.components.chatbot import create_chat_box
 from llmtuner.webui.manager import Manager