diff --git a/README.md b/README.md index 75781f1b..ddbd148d 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,9 @@ ## Changelog -[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base`, `--padding_side right` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model. +[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development. + +[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Please replace the Baichuan-13B model file with `tests/modeling_baichuan.py` and try `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model. [23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested. @@ -125,14 +127,10 @@ cd LLaMA-Efficient-Tuning pip install -r requirements.txt ``` -### LLaMA Weights Preparation (optional) - -1. Download the weights of the LLaMA models. -2. Convert them to HF format using the following command. +### All-in-one Web UI ```bash -python -m transformers.models.llama.convert_llama_weights_to_hf \ - --input_dir path_to_llama_weights --model_size 7B --output_dir path_to_llama_model +python src/train_web.py ``` ### (Continually) Pre-Training @@ -275,10 +273,20 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation. -### API / CLI / Web Demo +### API Demo ```bash -python src/xxx_demo.py \ +python src/api_demo.py \ + --model_name_or_path path_to_your_model \ + --checkpoint_dir path_to_checkpoint +``` + +See `http://localhost:8000/docs` for API documentation. 
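As a quick sanity check that the API server is running, here is a minimal sketch of querying it from Python. It assumes the server exposes the OpenAI-style chat-completion schema that `api_demo.py` emulates (e.g. a `/v1/chat/completions` route with `messages`, `temperature`, `top_p` and `max_tokens` fields) — treat the exact path and fields as assumptions and confirm them at `http://localhost:8000/docs`.

```python
# Minimal sketch, assuming an OpenAI-style chat completion route.
# Verify the actual path and request schema at http://localhost:8000/docs.
import requests

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.95,
    "top_p": 0.7,
    "max_tokens": 128,
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
print(resp.json()["choices"][0]["message"]["content"])
```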
+ +### CLI Demo + +```bash +python src/cli_demo.py \ --model_name_or_path path_to_your_model \ --checkpoint_dir path_to_checkpoint ``` diff --git a/requirements.txt b/requirements.txt index 3b6a62c6..e7f5bf16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,14 +3,14 @@ transformers>=4.29.1 datasets>=2.12.0 accelerate>=0.19.0 peft>=0.3.0 -trl==0.4.4 +trl>=0.4.7 sentencepiece jieba rouge-chinese nltk gradio>=3.36.0 uvicorn -pydantic==1.10.7 +pydantic fastapi sse-starlette matplotlib diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index 9785981a..b53af4de 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -1,6 +1,7 @@ from llmtuner.api import create_app from llmtuner.chat import ChatModel from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo +from llmtuner.webui import create_ui -__version__ = "0.0.9" +__version__ = "0.1.0" diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 12a4d95c..027a2ec8 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -1,3 +1,4 @@ +import json import uvicorn from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware @@ -93,7 +94,7 @@ def create_app(): finish_reason=None ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk") - yield chunk.json(exclude_unset=True, ensure_ascii=False) + yield json.dumps(chunk, ensure_ascii=False) for new_text in chat_model.stream_chat( query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens @@ -107,7 +108,7 @@ def create_app(): finish_reason=None ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk") - yield chunk.json(exclude_unset=True, ensure_ascii=False) + yield json.dumps(chunk, ensure_ascii=False) choice_data = ChatCompletionResponseStreamChoice( index=0, @@ -115,7 +116,7 @@ def create_app(): finish_reason="stop" ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk") - yield chunk.json(exclude_unset=True, ensure_ascii=False) + yield json.dumps(chunk, ensure_ascii=False) yield "[DONE]" return app diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index ab8971bb..be7d119d 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -5,3 +5,27 @@ VALUE_HEAD_FILE_NAME = "value_head.bin" FINETUNING_ARGS_NAME = "finetuning_args.json" LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings + +METHODS = ["full", "freeze", "lora"] + +SUPPORTED_MODELS = { + "LLaMA-7B": "huggyllama/llama-7b", + "LLaMA-13B": "huggyllama/llama-13b", + "LLaMA-30B": "huggyllama/llama-30b", + "LLaMA-65B": "huggyllama/llama-65b", + "BLOOM-560M": "bigscience/bloom-560m", + "BLOOM-3B": "bigscience/bloom-3b", + "BLOOM-7B1": "bigscience/bloom-7b1", + "BLOOMZ-560M": "bigscience/bloomz-560m", + "BLOOMZ-3B": "bigscience/bloomz-3b", + "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt", + "Falcon-7B-Base": "tiiuae/falcon-7b", + "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct", + "Falcon-40B-Base": "tiiuae/falcon-40b", + "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct", + "Baichuan-7B": "baichuan-inc/Baichuan-7B", + "Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base", + "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat", + "InternLM-7B-Base": "internlm/internlm-7b", + 
"InternLM-7B-Chat": "internlm/internlm-chat-7b" +} diff --git a/src/llmtuner/extras/logging.py b/src/llmtuner/extras/logging.py index 231acf4a..4b4f647e 100644 --- a/src/llmtuner/extras/logging.py +++ b/src/llmtuner/extras/logging.py @@ -2,6 +2,20 @@ import sys import logging +class LoggerHandler(logging.Handler): + + def __init__(self): + super().__init__() + self.log = "" + + def emit(self, record): + if record.name == "httpx": + return + log_entry = self.format(record) + self.log += log_entry + self.log += "\n\n" + + def get_logger(name: str) -> logging.Logger: formatter = logging.Formatter( diff --git a/src/llmtuner/extras/ploting.py b/src/llmtuner/extras/ploting.py index fb11a290..82530e45 100644 --- a/src/llmtuner/extras/ploting.py +++ b/src/llmtuner/extras/ploting.py @@ -1,4 +1,5 @@ import os +import math import json import matplotlib.pyplot as plt from typing import List, Optional @@ -10,12 +11,13 @@ from llmtuner.extras.logging import get_logger logger = get_logger(__name__) -def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]: +def smooth(scalars: List[float]) -> List[float]: r""" EMA implementation according to TensorBoard. """ last = scalars[0] smoothed = list() + weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function for next_val in scalars: smoothed_val = last * weight + (1 - weight) * next_val smoothed.append(smoothed_val) diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 88469d5c..49b20893 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -1,141 +1,29 @@ -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple from dataclasses import dataclass +@dataclass +class Format: + prefix: str + prompt: str + sep: str + use_history: bool + + +templates: Dict[str, Format] = {} + + @dataclass class Template: name: str def __post_init__(self): - - if self.name == "vanilla": - r""" - Supports language model inference without histories. - """ - self._register_template( - prefix="", - prompt="{query}", - sep="", - use_history=False - ) - - elif self.name == "default": - r""" - Default template. - """ - self._register_template( - prefix="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", - prompt="Human: {query}\nAssistant: ", - sep="\n", - use_history=True - ) - - elif self.name == "alpaca": - r""" - Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff - https://github.com/ymcui/Chinese-LLaMA-Alpaca - """ - self._register_template( - prefix="Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.", - prompt="### Instruction:\n{query}\n\n### Response:\n", - sep="\n\n", - use_history=True - ) - - elif self.name == "vicuna": - r""" - Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 - https://huggingface.co/lmsys/vicuna-13b-delta-v1.1 - """ - self._register_template( - prefix="A chat between a curious user and an artificial intelligence assistant. 
" - "The assistant gives helpful, detailed, and polite answers to the user's questions.", - prompt="USER: {query} ASSISTANT: ", - sep="", - use_history=True - ) - - elif self.name == "belle": - r""" - Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B - """ - self._register_template( - prefix="", - prompt="Human: {query}\n\nBelle: ", - sep="\n\n", - use_history=True - ) - - elif self.name == "linly": - r""" - Supports: https://github.com/CVI-SZU/Linly - """ - self._register_template( - prefix="", - prompt="User: {query}\nBot: ", - sep="\n", - use_history=True - ) - - elif self.name == "billa": - r""" - Supports: https://github.com/Neutralzz/BiLLa - """ - self._register_template( - prefix="", - prompt="Human: {query}\nAssistant: ", - sep="\n", - use_history=True - ) - - elif self.name == "ziya": - r""" - Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 - """ - self._register_template( - prefix="", - prompt=":{query}\n:", - sep="\n", - use_history=True - ) - - elif self.name == "aquila": - r""" - Supports: https://huggingface.co/qhduan/aquilachat-7b - """ - self._register_template( - prefix="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", - prompt="Human: {query}###Assistant: ", - sep="###", - use_history=True - ) - - elif self.name == "intern": - r""" - Supports: https://huggingface.co/internlm/internlm-chat-7b - """ - self._register_template( - prefix="", - prompt="<|User|>:{query}\n<|Bot|>:", - sep="\n", - use_history=True - ) - - elif self.name == "baichuan": - r""" - Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat - """ - self._register_template( - prefix="", - prompt="{query}", - sep="", - use_history=True - ) - + if self.name in templates: + self.prefix = templates[self.name].prefix + self.prompt = templates[self.name].prompt + self.sep = templates[self.name].sep + self.use_history = templates[self.name].use_history else: raise ValueError("Template {} does not exist.".format(self.name)) @@ -155,14 +43,6 @@ class Template: """ return self._format_example(query, history, prefix) + [resp] - def _register_template( - self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True - ) -> None: - self.prefix = prefix - self.prompt = prompt - self.sep = sep - self.use_history = use_history - def _format_example( self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = "" ) -> List[str]: @@ -179,3 +59,150 @@ class Template: convs.append(self.sep + self.prompt.format(query=user_query)) convs.append(bot_resp) return convs[:-1] # drop last + + +def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None: + templates[name] = Format( + prefix=prefix, + prompt=prompt, + sep=sep, + use_history=use_history + ) + + +r""" +Supports language model inference without histories. +""" +register_template( + name="vanilla", + prefix="", + prompt="{query}", + sep="", + use_history=False +) + + +r""" +Default template. +""" +register_template( + name="default", + prefix="A chat between a curious user and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + prompt="Human: {query}\nAssistant: ", + sep="\n", + use_history=True +) + + +r""" +Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff + https://github.com/ymcui/Chinese-LLaMA-Alpaca +""" +register_template( + name="alpaca", + prefix="Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.", + prompt="### Instruction:\n{query}\n\n### Response:\n", + sep="\n\n", + use_history=True +) + + +r""" +Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 + https://huggingface.co/lmsys/vicuna-13b-delta-v1.1 +""" +register_template( + name="vicuna", + prefix="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + prompt="USER: {query} ASSISTANT: ", + sep="", + use_history=True +) + + +r""" +Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B +""" +register_template( + name="belle", + prefix="", + prompt="Human: {query}\n\nBelle: ", + sep="\n\n", + use_history=True +) + + +r""" +Supports: https://github.com/CVI-SZU/Linly +""" +register_template( + name="linly", + prefix="", + prompt="User: {query}\nBot: ", + sep="\n", + use_history=True +) + + +r""" +Supports: https://github.com/Neutralzz/BiLLa +""" +register_template( + name="billa", + prefix="", + prompt="Human: {query}\nAssistant: ", + sep="\n", + use_history=True +) + + +r""" +Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 +""" +register_template( + name="ziya", + prefix="", + prompt=":{query}\n:", + sep="\n", + use_history=True +) + + +r""" +Supports: https://huggingface.co/qhduan/aquilachat-7b +""" +register_template( + name="aquila", + prefix="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + prompt="Human: {query}###Assistant: ", + sep="###", + use_history=True +) + + +r""" +Supports: https://huggingface.co/internlm/internlm-chat-7b +""" +register_template( + name="intern", + prefix="", + prompt="<|User|>:{query}\n<|Bot|>:", + sep="\n", + use_history=True +) + + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat +""" +register_template( + name="baichuan", + prefix="", + prompt="{query}", + sep="", + use_history=True +) diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index c4f3ac33..a111a8c5 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -28,7 +28,7 @@ check_min_version("4.29.1") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0") require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") -require_version("trl==0.4.4", "To fix: pip install trl==0.4.4") +require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7") def load_model_and_tokenizer( diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 30afe556..7a05dadb 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -25,7 +25,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): r""" Inherits PPOTrainer. 
""" - def __init__( self, training_args: Seq2SeqTrainingArguments, @@ -46,12 +45,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): r""" Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. """ - total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size + total_train_batch_size = ( + self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size + ) len_dataloader = len(self.dataloader) - num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1) num_examples = len(self.dataset) num_train_epochs = self.args.num_train_epochs - max_steps = math.ceil(num_train_epochs * num_steps_per_epoch) + max_steps = math.ceil(num_train_epochs * len_dataloader) self.state.max_steps = max_steps self.state.num_train_epochs = num_train_epochs @@ -62,9 +62,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples}") logger.info(f" Num Epochs = {num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {self.config.batch_size}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") - logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}") + logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_steps}") logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}") @@ -77,7 +77,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): "eos_token_id": self.tokenizer.eos_token_id, "logits_processor": get_logits_processor() } - output_length_sampler = LengthSampler(max_target_length // 2, max_target_length) + length_sampler = LengthSampler(max_target_length // 2, max_target_length) unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model) dataiter = iter(self.dataloader) @@ -87,59 +87,45 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): self.log_callback.on_train_begin(self.args, self.state, self.control) for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False): + batch = next(dataiter) + steps_trained += 1 - for _ in range(self.config.gradient_accumulation_steps): + unwrapped_model.gradient_checkpointing_disable() + unwrapped_model.config.use_cache = True - batch = next(dataiter) - steps_trained += 1 + # Get responses + query_tensors = batch["input_ids"] + response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs) - unwrapped_model.gradient_checkpointing_disable() - unwrapped_model.config.use_cache = True + queries, responses = [], [] + for i in range(len(query_tensors)): + query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0] + response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1 + queries.append(query_tensors[i, query_length:]) # remove padding from left + responses.append(response_tensors[i, :response_length]) # remove padding from right - # Get response from model - query_tensors: torch.Tensor = batch["input_ids"] - response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs) - - queries: List[torch.Tensor] = [] - responses: List[torch.Tensor] = 
[] - for i in range(len(query_tensors)): - query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0] - response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1 - queries.append(query_tensors[i, query_length:]) # remove padding from left - if response_length < 2: # make response have at least 2 tokens - responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id)) - else: - responses.append(response_tensors[i, :response_length]) # remove padding from right - - # Compute rewards - replace_model(unwrapped_model, target="reward") + # Compute rewards + replace_model(unwrapped_model, target="reward") + with torch.no_grad(): _, _, values = self.model(**self.prepare_model_inputs(queries, responses)) - rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type - replace_model(unwrapped_model, target="default") # make sure the model is default at the end + rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type + replace_model(unwrapped_model, target="default") - # Run PPO step - unwrapped_model.gradient_checkpointing_enable() - unwrapped_model.config.use_cache = False + # Run PPO step + unwrapped_model.gradient_checkpointing_enable() + unwrapped_model.config.use_cache = False + stats = self.step(queries, responses, rewards) - stats = self.step(queries, responses, rewards) - - loss_meter.update(stats["ppo/loss/total"], n=len(rewards)) - reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards)) - - if self.control.should_epoch_stop or self.control.should_training_stop: - break - - if steps_trained == len_dataloader: - dataiter = iter(self.dataloader) - steps_trained = 0 + loss_meter.update(stats["ppo/loss/total"], n=len(rewards)) + reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards)) if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0: - logs = { - "loss": round(loss_meter.avg, 4), - "reward": round(reward_meter.avg, 4), - "learning_rate": stats["ppo/learning_rate"], - "epoch": round(step / num_steps_per_epoch, 2) - } + logs = dict( + loss=round(loss_meter.avg, 4), + reward=round(reward_meter.avg, 4), + learning_rate=stats["ppo/learning_rate"], + epoch=round(step / len_dataloader, 2) + ) print(logs) logs["step"] = step self.state.log_history.append(logs) @@ -150,10 +136,14 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): if (step+1) % self.args.save_steps == 0: # save checkpoint self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}")) - if self.control.should_training_stop: + if self.control.should_epoch_stop or self.control.should_training_stop: break - @torch.inference_mode() + if steps_trained == len_dataloader: + dataiter = iter(self.dataloader) + steps_trained = 0 + + @torch.no_grad() def generate( self, inputs: Dict[str, torch.Tensor], diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py index 1f63cdaa..1257fd76 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -4,7 +4,8 @@ import math from trl import PPOConfig from torch.optim import AdamW -from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments +from typing import Optional, List +from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback from transformers.optimization import get_scheduler from llmtuner.dsets import get_dataset, preprocess_dataset @@ -19,7 +20,8 @@ def run_ppo( model_args: ModelArguments, data_args: 
DataArguments, training_args: Seq2SeqTrainingArguments, - finetuning_args: FinetuningArguments + finetuning_args: FinetuningArguments, + callbacks: Optional[List[TrainerCallback]] = [LogCallback()] ): dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo") @@ -30,7 +32,7 @@ def run_ppo( model_name=model_args.model_name_or_path, learning_rate=training_args.learning_rate, mini_batch_size=training_args.per_device_train_batch_size, - batch_size=training_args.per_device_train_batch_size, + batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps, gradient_accumulation_steps=training_args.gradient_accumulation_steps, ppo_epochs=1, max_grad_norm=training_args.max_grad_norm @@ -50,7 +52,7 @@ def run_ppo( ppo_trainer = PPOPeftTrainer( training_args=training_args, finetuning_args=finetuning_args, - callbacks=[LogCallback()], + callbacks=callbacks, config=ppo_config, model=model, ref_model=None, diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py index db81500f..cc0835ad 100644 --- a/src/llmtuner/tuner/rm/workflow.py +++ b/src/llmtuner/tuner/rm/workflow.py @@ -2,7 +2,8 @@ # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py -from transformers import Seq2SeqTrainingArguments +from typing import Optional, List +from transformers import Seq2SeqTrainingArguments, TrainerCallback from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.extras.callbacks import LogCallback @@ -18,7 +19,8 @@ def run_rm( model_args: ModelArguments, data_args: DataArguments, training_args: Seq2SeqTrainingArguments, - finetuning_args: FinetuningArguments + finetuning_args: FinetuningArguments, + callbacks: Optional[List[TrainerCallback]] = [LogCallback()] ): dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm") @@ -44,7 +46,7 @@ def run_rm( args=training_args, tokenizer=tokenizer, data_collator=data_collator, - callbacks=[LogCallback()], + callbacks=callbacks, compute_metrics=compute_accuracy, **trainer_kwargs ) diff --git a/src/llmtuner/webui/__init__.py b/src/llmtuner/webui/__init__.py new file mode 100644 index 00000000..686cc95f --- /dev/null +++ b/src/llmtuner/webui/__init__.py @@ -0,0 +1 @@ +from llmtuner.webui.interface import create_ui diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chat.py new file mode 100644 index 00000000..71018c31 --- /dev/null +++ b/src/llmtuner/webui/chat.py @@ -0,0 +1,79 @@ +import os +from typing import List, Tuple + +from llmtuner.chat.stream_chat import ChatModel +from llmtuner.extras.misc import torch_gc +from llmtuner.hparams import GeneratingArguments +from llmtuner.tuner import get_infer_args +from llmtuner.webui.common import get_model_path, get_save_dir +from llmtuner.webui.locales import ALERTS + + +class WebChatModel(ChatModel): + + def __init__(self): + self.model = None + self.tokenizer = None + self.generating_args = GeneratingArguments() + + def load_model( + self, lang: str, model_name: str, checkpoints: list, + finetuning_type: str, template: str, quantization_bit: str + ): + if self.model is not None: + yield ALERTS["err_exists"][lang] + return + + if not model_name: + yield ALERTS["err_no_model"][lang] + return + + 
model_name_or_path = get_model_path(model_name) + if not model_name_or_path: + yield ALERTS["err_no_path"][lang] + return + + if checkpoints: + checkpoint_dir = ",".join( + [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints] + ) + else: + checkpoint_dir = None + + yield ALERTS["info_loading"][lang] + args = dict( + model_name_or_path=model_name_or_path, + finetuning_type=finetuning_type, + prompt_template=template, + checkpoint_dir=checkpoint_dir, + quantization_bit=int(quantization_bit) if quantization_bit else None + ) + super().__init__(*get_infer_args(args)) + + yield ALERTS["info_loaded"][lang] + + def unload_model(self, lang: str): + yield ALERTS["info_unloading"][lang] + self.model = None + self.tokenizer = None + torch_gc() + yield ALERTS["info_unloaded"][lang] + + def predict( + self, + chatbot: List[Tuple[str, str]], + query: str, + history: List[Tuple[str, str]], + max_new_tokens: int, + top_p: float, + temperature: float + ): + chatbot.append([query, ""]) + response = "" + for new_text in self.stream_chat( + query, history, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature + ): + response += new_text + new_history = history + [(query, response)] + chatbot[-1] = [query, response] + yield chatbot, new_history diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py new file mode 100644 index 00000000..bf1d18fb --- /dev/null +++ b/src/llmtuner/webui/common.py @@ -0,0 +1,75 @@ +import json +import os +from typing import Any, Dict, Optional + +import gradio as gr +from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME +from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME + +from llmtuner.extras.constants import SUPPORTED_MODELS + + +DEFAULT_CACHE_DIR = "cache" +DEFAULT_DATA_DIR = "data" +DEFAULT_SAVE_DIR = "saves" +USER_CONFIG = "user.config" +DATA_CONFIG = "dataset_info.json" + + +def get_save_dir(model_name: str) -> str: + return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1]) + + +def get_config_path() -> os.PathLike: + return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) + + +def load_config() -> Dict[str, Any]: + try: + with open(get_config_path(), "r", encoding="utf-8") as f: + return json.load(f) + except: + return {"last_model": "", "path_dict": {}} + + +def save_config(model_name: str, model_path: str) -> None: + os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) + user_config = load_config() + user_config["last_model"] = model_name + user_config["path_dict"][model_name] = model_path + with open(get_config_path(), "w", encoding="utf-8") as f: + json.dump(user_config, f, indent=2, ensure_ascii=False) + + +def get_model_path(model_name: str) -> str: + user_config = load_config() + return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, "")) + + +def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]: + checkpoints = [] + save_dir = os.path.join(get_save_dir(model_name), finetuning_type) + if save_dir and os.path.isdir(save_dir): + for checkpoint in os.listdir(save_dir): + if ( + os.path.isdir(os.path.join(save_dir, checkpoint)) + and any([ + os.path.isfile(os.path.join(save_dir, checkpoint, name)) + for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME) + ]) + ): + checkpoints.append(checkpoint) + return gr.update(value=[], choices=checkpoints) + + +def load_dataset_info(dataset_dir: str) -> Dict[str, Any]: + try: + with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: + return json.load(f) 
+ except: + return {} + + +def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]: + dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) + return gr.update(value=[], choices=list(dataset_info.keys())) diff --git a/src/llmtuner/webui/components/__init__.py b/src/llmtuner/webui/components/__init__.py new file mode 100644 index 00000000..779cf390 --- /dev/null +++ b/src/llmtuner/webui/components/__init__.py @@ -0,0 +1,4 @@ +from llmtuner.webui.components.eval import create_eval_tab +from llmtuner.webui.components.infer import create_infer_tab +from llmtuner.webui.components.top import create_top +from llmtuner.webui.components.sft import create_sft_tab diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py new file mode 100644 index 00000000..d56dd592 --- /dev/null +++ b/src/llmtuner/webui/components/chatbot.py @@ -0,0 +1,54 @@ +from typing import Dict, Tuple + +import gradio as gr +from gradio.blocks import Block +from gradio.components import Component + +from llmtuner.webui.chat import WebChatModel + + +def create_chat_box( + chat_model: WebChatModel +) -> Tuple[Block, Component, Component, Dict[str, Component]]: + with gr.Box(visible=False) as chat_box: + chatbot = gr.Chatbot() + + with gr.Row(): + with gr.Column(scale=4): + with gr.Column(scale=12): + query = gr.Textbox(show_label=False, lines=8) + + with gr.Column(min_width=32, scale=1): + submit_btn = gr.Button(variant="primary") + + with gr.Column(scale=1): + clear_btn = gr.Button() + max_new_tokens = gr.Slider( + 10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True + ) + top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True) + temperature = gr.Slider( + 0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True + ) + + history = gr.State([]) + + submit_btn.click( + chat_model.predict, + [chatbot, query, history, max_new_tokens, top_p, temperature], + [chatbot, history], + show_progress=True + ).then( + lambda: gr.update(value=""), outputs=[query] + ) + + clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) + + return chat_box, chatbot, history, dict( + query=query, + submit_btn=submit_btn, + clear_btn=clear_btn, + max_new_tokens=max_new_tokens, + top_p=top_p, + temperature=temperature + ) diff --git a/src/llmtuner/webui/components/data.py b/src/llmtuner/webui/components/data.py new file mode 100644 index 00000000..4445f39c --- /dev/null +++ b/src/llmtuner/webui/components/data.py @@ -0,0 +1,19 @@ +import gradio as gr +from gradio.blocks import Block +from gradio.components import Component +from typing import Tuple + + +def create_preview_box() -> Tuple[Block, Component, Component, Component]: + with gr.Box(visible=False, elem_classes="modal-box") as preview_box: + with gr.Row(): + preview_count = gr.Number(interactive=False) + + with gr.Row(): + preview_samples = gr.JSON(interactive=False) + + close_btn = gr.Button() + + close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box]) + + return preview_box, preview_count, preview_samples, close_btn diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py new file mode 100644 index 00000000..67407cbc --- /dev/null +++ b/src/llmtuner/webui/components/eval.py @@ -0,0 +1,60 @@ +from typing import Dict +import gradio as gr +from gradio.components import Component + +from llmtuner.webui.common import list_dataset, 
DEFAULT_DATA_DIR +from llmtuner.webui.components.data import create_preview_box +from llmtuner.webui.runner import Runner +from llmtuner.webui.utils import can_preview, get_preview + + +def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]: + with gr.Row(): + dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2) + dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4) + preview_btn = gr.Button(interactive=False, scale=1) + + preview_box, preview_count, preview_samples, close_btn = create_preview_box() + + dataset_dir.change(list_dataset, [dataset_dir], [dataset]) + dataset.change(can_preview, [dataset_dir, dataset], [preview_btn]) + preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) + + with gr.Row(): + max_samples = gr.Textbox(value="100000", interactive=True) + batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True) + quantization_bit = gr.Dropdown([8, 4]) + predict = gr.Checkbox(value=True) + + with gr.Row(): + start_btn = gr.Button() + stop_btn = gr.Button() + + output_box = gr.Markdown() + + start_btn.click( + runner.run_eval, + [ + top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"], + top_elems["finetuning_type"], top_elems["template"], + dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict + ], + [output_box] + ) + stop_btn.click(runner.set_abort, queue=False) + + return dict( + dataset_dir=dataset_dir, + dataset=dataset, + preview_btn=preview_btn, + preview_count=preview_count, + preview_samples=preview_samples, + close_btn=close_btn, + max_samples=max_samples, + batch_size=batch_size, + quantization_bit=quantization_bit, + predict=predict, + start_btn=start_btn, + stop_btn=stop_btn, + output_box=output_box + ) diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py new file mode 100644 index 00000000..831dbea5 --- /dev/null +++ b/src/llmtuner/webui/components/infer.py @@ -0,0 +1,47 @@ +from typing import Dict + +import gradio as gr +from gradio.components import Component + +from llmtuner.webui.chat import WebChatModel +from llmtuner.webui.components.chatbot import create_chat_box + + +def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]: + with gr.Row(): + load_btn = gr.Button() + unload_btn = gr.Button() + quantization_bit = gr.Dropdown([8, 4]) + + info_box = gr.Markdown() + + chat_model = WebChatModel() + chat_box, chatbot, history, chat_elems = create_chat_box(chat_model) + + load_btn.click( + chat_model.load_model, + [ + top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"], + top_elems["finetuning_type"], top_elems["template"], + quantization_bit + ], + [info_box] + ).then( + lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box] + ) + + unload_btn.click( + chat_model.unload_model, [top_elems["lang"]], [info_box] + ).then( + lambda: ([], []), outputs=[chatbot, history] + ).then( + lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box] + ) + + return dict( + quantization_bit=quantization_bit, + info_box=info_box, + load_btn=load_btn, + unload_btn=unload_btn, + **chat_elems + ) diff --git a/src/llmtuner/webui/components/sft.py b/src/llmtuner/webui/components/sft.py new file mode 100644 index 00000000..9740673a --- /dev/null +++ b/src/llmtuner/webui/components/sft.py @@ -0,0 +1,94 @@ +from typing import Dict +from transformers.trainer_utils import SchedulerType + +import 
gradio as gr +from gradio.components import Component + +from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR +from llmtuner.webui.components.data import create_preview_box +from llmtuner.webui.runner import Runner +from llmtuner.webui.utils import can_preview, get_preview, gen_plot + + +def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]: + with gr.Row(): + dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=1) + dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4) + preview_btn = gr.Button(interactive=False, scale=1) + + preview_box, preview_count, preview_samples, close_btn = create_preview_box() + + dataset_dir.change(list_dataset, [dataset_dir], [dataset]) + dataset.change(can_preview, [dataset_dir, dataset], [preview_btn]) + preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) + + with gr.Row(): + learning_rate = gr.Textbox(value="5e-5", interactive=True) + num_train_epochs = gr.Textbox(value="3.0", interactive=True) + max_samples = gr.Textbox(value="100000", interactive=True) + quantization_bit = gr.Dropdown([8, 4]) + + with gr.Row(): + batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True) + gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True) + lr_scheduler_type = gr.Dropdown( + value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True + ) + fp16 = gr.Checkbox(value=True) + + with gr.Row(): + logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True) + save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True) + + with gr.Row(): + start_btn = gr.Button() + stop_btn = gr.Button() + + with gr.Row(): + with gr.Column(scale=4): + output_dir = gr.Textbox(interactive=True) + output_box = gr.Markdown() + + with gr.Column(scale=1): + loss_viewer = gr.Plot() + + start_btn.click( + runner.run_train, + [ + top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"], + top_elems["finetuning_type"], top_elems["template"], + dataset, dataset_dir, learning_rate, num_train_epochs, max_samples, + fp16, quantization_bit, batch_size, gradient_accumulation_steps, + lr_scheduler_type, logging_steps, save_steps, output_dir + ], + [output_box] + ) + stop_btn.click(runner.set_abort, queue=False) + + output_box.change( + gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False + ) + + return dict( + dataset_dir=dataset_dir, + dataset=dataset, + preview_btn=preview_btn, + preview_count=preview_count, + preview_samples=preview_samples, + close_btn=close_btn, + learning_rate=learning_rate, + num_train_epochs=num_train_epochs, + max_samples=max_samples, + quantization_bit=quantization_bit, + batch_size=batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + lr_scheduler_type=lr_scheduler_type, + fp16=fp16, + logging_steps=logging_steps, + save_steps=save_steps, + start_btn=start_btn, + stop_btn=stop_btn, + output_dir=output_dir, + output_box=output_box, + loss_viewer=loss_viewer + ) diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py new file mode 100644 index 00000000..8efb1d8d --- /dev/null +++ b/src/llmtuner/webui/components/top.py @@ -0,0 +1,42 @@ +from typing import Dict + +import gradio as gr +from gradio.components import Component + +from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS +from 
llmtuner.extras.template import templates +from llmtuner.webui.common import list_checkpoint, get_model_path, save_config + + +def create_top() -> Dict[str, Component]: + available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] + + with gr.Row(): + lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1) + model_name = gr.Dropdown(choices=available_models, scale=3) + model_path = gr.Textbox(scale=3) + + with gr.Row(): + finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1) + template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1) + checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4) + refresh_btn = gr.Button(scale=1) + + model_name.change( + list_checkpoint, [model_name, finetuning_type], [checkpoints] + ).then( + get_model_path, [model_name], [model_path] + ) # do not save config since the below line will save + model_path.change(save_config, [model_name, model_path]) + finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints]) + refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints]) + + return dict( + lang=lang, + model_name=model_name, + model_path=model_path, + finetuning_type=finetuning_type, + template=template, + checkpoints=checkpoints, + refresh_btn=refresh_btn + ) diff --git a/src/llmtuner/webui/css.py b/src/llmtuner/webui/css.py new file mode 100644 index 00000000..5d370c1f --- /dev/null +++ b/src/llmtuner/webui/css.py @@ -0,0 +1,18 @@ +CSS = r""" +.modal-box { + position: fixed !important; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); /* center horizontally */ + max-width: 1000px; + max-height: 750px; + overflow-y: scroll !important; + background-color: var(--input-background-fill); + border: 2px solid black !important; + z-index: 1000; +} + +.dark .modal-box { + border: 2px solid white !important; +} +""" diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py new file mode 100644 index 00000000..fc36478a --- /dev/null +++ b/src/llmtuner/webui/interface.py @@ -0,0 +1,54 @@ +import gradio as gr +from transformers.utils.versions import require_version + +from llmtuner.webui.components import ( + create_top, + create_sft_tab, + create_eval_tab, + create_infer_tab +) +from llmtuner.webui.css import CSS +from llmtuner.webui.manager import Manager +from llmtuner.webui.runner import Runner + + +require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0") + + +def create_ui() -> gr.Blocks: + runner = Runner() + + with gr.Blocks(title="Web Tuner", css=CSS) as demo: + top_elems = create_top() + + with gr.Tab("SFT"): + sft_elems = create_sft_tab(top_elems, runner) + + with gr.Tab("Evaluate"): + eval_elems = create_eval_tab(top_elems, runner) + + with gr.Tab("Inference"): + infer_elems = create_infer_tab(top_elems) + + elem_list = [top_elems, sft_elems, eval_elems, infer_elems] + manager = Manager(elem_list) + + demo.load( + manager.gen_label, + [top_elems["lang"]], + [elem for elems in elem_list for elem in elems.values()], + ) + + top_elems["lang"].change( + manager.gen_label, + [top_elems["lang"]], + [elem for elems in elem_list for elem in elems.values()], + ) + + return demo + + +if __name__ == "__main__": + demo = create_ui() + demo.queue() + demo.launch(server_name="0.0.0.0", share=False, inbrowser=True) diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py new file mode 100644 index 00000000..2bc9b3d0 --- /dev/null +++ 
b/src/llmtuner/webui/locales.py @@ -0,0 +1,384 @@ +LOCALES = { + "lang": { + "en": { + "label": "Lang" + }, + "zh": { + "label": "语言" + } + }, + "model_name": { + "en": { + "label": "Model name" + }, + "zh": { + "label": "模型名称" + } + }, + "model_path": { + "en": { + "label": "Model path", + "info": "Path to pretrained model or model identifier from Hugging Face." + }, + "zh": { + "label": "模型路径", + "info": "本地模型的文件路径或 Hugging Face 的模型标识符。" + } + }, + "checkpoints": { + "en": { + "label": "Checkpoints" + }, + "zh": { + "label": "模型断点" + } + }, + "template": { + "en": { + "label": "Prompt template" + }, + "zh": { + "label": "提示模板" + } + }, + "refresh_btn": { + "en": { + "value": "Refresh checkpoints" + }, + "zh": { + "value": "刷新断点" + } + }, + "dataset_dir": { + "en": { + "label": "Data dir", + "info": "Path of the data directory." + }, + "zh": { + "label": "数据路径", + "info": "数据文件夹的路径。" + } + }, + "dataset": { + "en": { + "label": "Dataset" + }, + "zh": { + "label": "数据集" + } + }, + "preview_btn": { + "en": { + "value": "Preview" + }, + "zh": { + "value": "预览" + } + }, + "preview_count": { + "en": { + "label": "Count" + }, + "zh": { + "label": "数量" + } + }, + "preview_samples": { + "en": { + "label": "Samples" + }, + "zh": { + "label": "样例" + } + }, + "close_btn": { + "en": { + "value": "Close" + }, + "zh": { + "value": "关闭" + } + }, + "max_samples": { + "en": { + "label": "Max samples", + "info": "Maximum samples per dataset." + }, + "zh": { + "label": "最大样本数", + "info": "每个数据集最多使用的样本数。" + } + }, + "batch_size": { + "en": { + "label": "Batch size", + "info": "Number of samples to process per GPU." + }, + "zh":{ + "label": "批处理大小", + "info": "每块 GPU 上处理的样本数量。" + } + }, + "quantization_bit": { + "en": { + "label": "Quantization bit", + "info": "Enable 4/8-bit model quantization." + }, + "zh": { + "label": "量化", + "info": "启用 4/8 比特模型量化。" + } + }, + "start_btn": { + "en": { + "value": "Start" + }, + "zh": { + "value": "开始" + } + }, + "stop_btn": { + "en": { + "value": "Abort" + }, + "zh": { + "value": "中断" + } + }, + "output_box": { + "en": { + "value": "Ready." + }, + "zh": { + "value": "准备就绪。" + } + }, + "finetuning_type": { + "en": { + "label": "Finetuning method" + }, + "zh": { + "label": "微调方法" + } + }, + "learning_rate": { + "en": { + "label": "Learning rate", + "info": "Initial learning rate for AdamW." + }, + "zh": { + "label": "学习率", + "info": "AdamW 优化器的初始学习率。" + } + }, + "num_train_epochs": { + "en": { + "label": "Epochs", + "info": "Total number of training epochs to perform." + }, + "zh": { + "label": "训练轮数", + "info": "需要执行的训练总轮数。" + } + }, + "gradient_accumulation_steps": { + "en": { + "label": "Gradient accumulation", + "info": "Number of gradient accumulation steps." + }, + "zh": { + "label": "梯度累积", + "info": "梯度累积的步数。" + } + }, + "lr_scheduler_type": { + "en": { + "label": "LR Scheduler", + "info": "Name of learning rate scheduler.", + }, + "zh": { + "label": "学习率调节器", + "info": "采用的学习率调节器名称。" + } + }, + "fp16": { + "en": { + "label": "fp16", + "info": "Whether to use fp16 mixed precision training." + }, + "zh": { + "label": "fp16", + "info": "是否启用 FP16 混合精度训练。" + } + }, + "logging_steps": { + "en": { + "label": "Logging steps", + "info": "Number of update steps between two logs." + }, + "zh": { + "label": "日志间隔", + "info": "每两次日志输出间的更新步数。" + } + }, + "save_steps": { + "en": { + "label": "Save steps", + "info": "Number of updates steps between two checkpoints." 
+ }, + "zh": { + "label": "保存间隔", + "info": "每两次断点保存间的更新步数。" + } + }, + "output_dir": { + "en": { + "label": "Checkpoint name", + "info": "Directory to save checkpoint." + }, + "zh": { + "label": "断点名称", + "info": "保存模型断点的文件夹名称。" + } + }, + "loss_viewer": { + "en": { + "label": "Loss" + }, + "zh": { + "label": "损失" + } + }, + "predict": { + "en": { + "label": "Save predictions" + }, + "zh": { + "label": "保存预测结果" + } + }, + "info_box": { + "en": { + "value": "Model unloaded, please load a model first." + }, + "zh": { + "value": "模型未加载,请先加载模型。" + } + }, + "load_btn": { + "en": { + "value": "Load model" + }, + "zh": { + "value": "加载模型" + } + }, + "unload_btn": { + "en": { + "value": "Unload model" + }, + "zh": { + "value": "卸载模型" + } + }, + "query": { + "en": { + "placeholder": "Input..." + }, + "zh": { + "placeholder": "输入..." + } + }, + "submit_btn": { + "en": { + "value": "Submit" + }, + "zh": { + "value": "提交" + } + }, + "clear_btn": { + "en": { + "value": "Clear history" + }, + "zh": { + "value": "清空历史" + } + }, + "max_new_tokens": { + "en": { + "label": "Maximum new tokens" + }, + "zh": { + "label": "最大生成长度" + } + }, + "top_p": { + "en": { + "label": "Top-p" + }, + "zh": { + "label": "Top-p 采样值" + } + }, + "temperature": { + "en": { + "label": "Temperature" + }, + "zh": { + "label": "温度系数" + } + } +} + + +ALERTS = { + "err_conflict": { + "en": "A process is in running, please abort it firstly.", + "zh": "任务已存在,请先中断训练。" + }, + "err_exists": { + "en": "You have loaded a model, please unload it first.", + "zh": "模型已存在,请先卸载模型。" + }, + "err_no_model": { + "en": "Please select a model.", + "zh": "请选择模型。" + }, + "err_no_path": { + "en": "Model not found.", + "zh": "模型未找到。" + }, + "err_no_dataset": { + "en": "Please choose a dataset.", + "zh": "请选择数据集。" + }, + "info_aborting": { + "en": "Aborted, wait for terminating...", + "zh": "训练中断,正在等待线程结束……" + }, + "info_aborted": { + "en": "Ready.", + "zh": "准备就绪。" + }, + "info_finished": { + "en": "Finished.", + "zh": "训练完毕。" + }, + "info_loading": { + "en": "Loading model...", + "zh": "加载中……" + }, + "info_unloading": { + "en": "Unloading model...", + "zh": "卸载中……" + }, + "info_loaded": { + "en": "Model loaded, now you can chat with your model!", + "zh": "模型已加载,可以开始聊天了!" 
+ }, + "info_unloaded": { + "en": "Model unloaded.", + "zh": "模型已卸载。" + } +} diff --git a/src/llmtuner/webui/manager.py b/src/llmtuner/webui/manager.py new file mode 100644 index 00000000..28c40cad --- /dev/null +++ b/src/llmtuner/webui/manager.py @@ -0,0 +1,35 @@ +import gradio as gr +from typing import Any, Dict, List +from gradio.components import Component + +from llmtuner.webui.common import get_model_path, list_dataset, load_config +from llmtuner.webui.locales import LOCALES +from llmtuner.webui.utils import get_time + + +class Manager: + + def __init__(self, elem_list: List[Dict[str, Component]]): + self.elem_list = elem_list + + def gen_refresh(self) -> Dict[str, Any]: + refresh_dict = { + "dataset": {"choices": list_dataset()["choices"]}, + "output_dir": {"value": get_time()} + } + user_config = load_config() + if user_config["last_model"]: + refresh_dict["model_name"] = {"value": user_config["last_model"]} + refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])} + + return refresh_dict + + def gen_label(self, lang: str) -> Dict[Component, dict]: + update_dict = {} + refresh_dict = self.gen_refresh() + + for elems in self.elem_list: + for name, component in elems.items(): + update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {})) + + return update_dict diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py new file mode 100644 index 00000000..599d31c3 --- /dev/null +++ b/src/llmtuner/webui/runner.py @@ -0,0 +1,177 @@ +import logging +import os +import threading +import time +import transformers +from typing import Optional, Tuple + +from llmtuner.extras.callbacks import LogCallback +from llmtuner.extras.logging import LoggerHandler +from llmtuner.extras.misc import torch_gc +from llmtuner.tuner import get_train_args, run_sft +from llmtuner.webui.common import get_model_path, get_save_dir +from llmtuner.webui.locales import ALERTS +from llmtuner.webui.utils import format_info, get_eval_results + + +class Runner: + + def __init__(self): + self.aborted = False + self.running = False + + def set_abort(self): + self.aborted = True + self.running = False + + def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]: + if self.running: + return None, ALERTS["err_conflict"][lang], None, None + + if not model_name: + return None, ALERTS["err_no_model"][lang], None, None + + model_name_or_path = get_model_path(model_name) + if not model_name_or_path: + return None, ALERTS["err_no_path"][lang], None, None + + if len(dataset) == 0: + return None, ALERTS["err_no_dataset"][lang], None, None + + self.aborted = False + self.running = True + + logger_handler = LoggerHandler() + logger_handler.setLevel(logging.INFO) + logging.root.addHandler(logger_handler) + transformers.logging.add_handler(logger_handler) + trainer_callback = LogCallback(self) + + return model_name_or_path, "", logger_handler, trainer_callback + + def finalize(self, lang: str, finish_info: Optional[str] = None) -> str: + self.running = False + torch_gc() + if self.aborted: + return ALERTS["info_aborted"][lang] + else: + return finish_info if finish_info is not None else ALERTS["info_finished"][lang] + + def run_train( + self, lang, model_name, checkpoints, finetuning_type, template, + dataset, dataset_dir, learning_rate, num_train_epochs, max_samples, + fp16, quantization_bit, batch_size, gradient_accumulation_steps, + lr_scheduler_type, logging_steps, save_steps, output_dir + ): + model_name_or_path, 
error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset) + if error: + yield error + return + + if checkpoints: + checkpoint_dir = ",".join( + [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints] + ) + else: + checkpoint_dir = None + + args = dict( + model_name_or_path=model_name_or_path, + do_train=True, + finetuning_type=finetuning_type, + prompt_template=template, + dataset=",".join(dataset), + dataset_dir=dataset_dir, + max_samples=int(max_samples), + output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir), + checkpoint_dir=checkpoint_dir, + overwrite_cache=True, + per_device_train_batch_size=batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + lr_scheduler_type=lr_scheduler_type, + logging_steps=logging_steps, + save_steps=save_steps, + learning_rate=float(learning_rate), + num_train_epochs=float(num_train_epochs), + fp16=fp16, + quantization_bit=int(quantization_bit) if quantization_bit else None + ) + model_args, data_args, training_args, finetuning_args, _ = get_train_args(args) + + run_args = dict( + model_args=model_args, + data_args=data_args, + training_args=training_args, + finetuning_args=finetuning_args, + callbacks=[trainer_callback] + ) + thread = threading.Thread(target=run_sft, kwargs=run_args) + thread.start() + + while thread.is_alive(): + time.sleep(1) + if self.aborted: + yield ALERTS["info_aborting"][lang] + else: + yield format_info(logger_handler.log, trainer_callback.tracker) + + yield self.finalize(lang) + + def run_eval( + self, lang, model_name, checkpoints, finetuning_type, template, + dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict + ): + model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset) + if error: + yield error + return + + if checkpoints: + checkpoint_dir = ",".join( + [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints] + ) + output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints)) + else: + checkpoint_dir = None + output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base") + + args = dict( + model_name_or_path=model_name_or_path, + do_eval=True, + finetuning_type=finetuning_type, + prompt_template=template, + dataset=",".join(dataset), + dataset_dir=dataset_dir, + max_samples=int(max_samples), + output_dir=output_dir, + checkpoint_dir=checkpoint_dir, + overwrite_cache=True, + predict_with_generate=True, + per_device_eval_batch_size=batch_size, + quantization_bit=int(quantization_bit) if quantization_bit else None + ) + + if predict: + args.pop("do_eval", None) + args["do_predict"] = True + + model_args, data_args, training_args, finetuning_args, _ = get_train_args(args) + + run_args = dict( + model_args=model_args, + data_args=data_args, + training_args=training_args, + finetuning_args=finetuning_args, + callbacks=[trainer_callback] + ) + thread = threading.Thread(target=run_sft, kwargs=run_args) + thread.start() + + while thread.is_alive(): + time.sleep(1) + if self.aborted: + yield ALERTS["info_aborting"][lang] + else: + yield format_info(logger_handler.log, trainer_callback.tracker) + + yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json"))) diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py new file mode 100644 index 00000000..506ea4b0 --- /dev/null +++ b/src/llmtuner/webui/utils.py @@ -0,0 
+1,74 @@ +import os +import json +import gradio as gr +import matplotlib.figure +import matplotlib.pyplot as plt +from typing import Tuple +from datetime import datetime + +from llmtuner.extras.ploting import smooth +from llmtuner.webui.common import get_save_dir, DATA_CONFIG + + +def format_info(log: str, tracker: dict) -> str: + info = log + if "current_steps" in tracker: + info += "Running **{:d}/{:d}**: {} < {}\n".format( + tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"] + ) + return info + + +def get_time() -> str: + return datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + + +def can_preview(dataset_dir: str, dataset: list) -> dict: + with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: + dataset_info = json.load(f) + if ( + len(dataset) > 0 + and "file_name" in dataset_info[dataset[0]] + and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])) + ): + return gr.update(interactive=True) + else: + return gr.update(interactive=False) + + +def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]: + with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: + dataset_info = json.load(f) + data_file = dataset_info[dataset[0]]["file_name"] + with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f: + data = json.load(f) + return len(data), data[:2], gr.update(visible=True) + + +def get_eval_results(path: os.PathLike) -> str: + with open(path, "r", encoding="utf-8") as f: + result = json.dumps(json.load(f), indent=4) + return "```json\n{}\n```\n".format(result) + + +def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure: + log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl") + if not os.path.isfile(log_file): + return None + + plt.close("all") + fig = plt.figure() + ax = fig.add_subplot(111) + steps, losses = [], [] + with open(log_file, "r", encoding="utf-8") as f: + for line in f: + log_info = json.loads(line) + if log_info.get("loss", None): + steps.append(log_info["current_steps"]) + losses.append(log_info["loss"]) + ax.plot(steps, losses, alpha=0.4, label="original") + ax.plot(steps, smooth(losses), label="smoothed") + ax.legend() + ax.set_xlabel("step") + ax.set_ylabel("loss") + return fig diff --git a/src/train_web.py b/src/train_web.py new file mode 100644 index 00000000..3f7855c0 --- /dev/null +++ b/src/train_web.py @@ -0,0 +1,11 @@ +from llmtuner import create_ui + + +def main(): + demo = create_ui() + demo.queue() + demo.launch(server_name="0.0.0.0", share=False, inbrowser=True) + + +if __name__ == "__main__": + main() diff --git a/src/web_demo.py b/src/web_demo.py deleted file mode 100644 index c60e4138..00000000 --- a/src/web_demo.py +++ /dev/null @@ -1,95 +0,0 @@ -# coding=utf-8 -# Implements user interface in browser for fine-tuned models. 
-# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint - -import gradio as gr -from threading import Thread -from transformers import TextIteratorStreamer -from transformers.utils.versions import require_version - -from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor - - -require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0") - - -model_args, data_args, finetuning_args, generating_args = get_infer_args() -model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) - -prompt_template = Template(data_args.prompt_template) -source_prefix = data_args.source_prefix if data_args.source_prefix else "" - - -def predict(query, chatbot, max_new_tokens, top_p, temperature, history): - chatbot.append((query, "")) - - input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"] - input_ids = input_ids.to(model.device) - - streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) - - gen_kwargs = generating_args.to_dict() - gen_kwargs.update({ - "input_ids": input_ids, - "top_p": top_p, - "temperature": temperature, - "max_new_tokens": max_new_tokens, - "logits_processor": get_logits_processor(), - "streamer": streamer - }) - - thread = Thread(target=model.generate, kwargs=gen_kwargs) - thread.start() - - response = "" - for new_text in streamer: - response += new_text - new_history = history + [(query, response)] - chatbot[-1] = (query, response) - yield chatbot, new_history - - -def reset_user_input(): - return gr.update(value="") - - -def reset_state(): - return [], [] - - -with gr.Blocks() as demo: - - gr.HTML(""" -

    <h1 align="center">
-        LLaMA Efficient Tuning
-    </h1>

- """) - - chatbot = gr.Chatbot() - - with gr.Row(): - with gr.Column(scale=4): - with gr.Column(scale=12): - user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False) - with gr.Column(min_width=32, scale=1): - submitBtn = gr.Button("Submit", variant="primary") - - with gr.Column(scale=1): - emptyBtn = gr.Button("Clear History") - max_new_tokens = gr.Slider(10, 2048, value=generating_args.max_new_tokens, step=1.0, - label="Maximum new tokens", interactive=True) - top_p = gr.Slider(0.01, 1, value=generating_args.top_p, step=0.01, - label="Top P", interactive=True) - temperature = gr.Slider(0.01, 1.5, value=generating_args.temperature, step=0.01, - label="Temperature", interactive=True) - - history = gr.State([]) - - submitBtn.click(predict, [user_input, chatbot, max_new_tokens, top_p, temperature, history], [chatbot, history], show_progress=True) - submitBtn.click(reset_user_input, [], [user_input]) - - emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) - -demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=True)