release v0.1.4

This commit is contained in:
hiyouga 2023-08-01 10:08:47 +08:00
parent d3a0692d4d
commit 973a638665
13 changed files with 66 additions and 49 deletions

View File

@ -14,6 +14,8 @@
[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.
[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
@ -32,7 +34,7 @@
[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model.
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
[23/05/31] Now we support training the **BLOOM & BLOOMZ** models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
@ -312,8 +314,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--predict_with_generate
```
If you want to predict the samples with empty responses, please kindly fill the `response` column with **dummy tokens** to ensure the sample will not be discarded throughout the preprocessing phase.
### API Demo
```bash
@ -373,7 +373,7 @@ Please follow the model licenses to use the corresponding model weights:
- [LLaMA-2](https://ai.meta.com/llama/license/)
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
- [Falcon](LICENSE)
- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
## Citation

View File

@ -14,6 +14,8 @@
[23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming``--max_steps 100` 参数来流式加载数据集。
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft))。
[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数。
[23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
@ -312,8 +314,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--predict_with_generate
```
如果需要预测的样本没有标签,请首先在 `response` 列中填入一些占位符,以免样本在预处理阶段被丢弃。
### API 服务
```bash
@ -373,7 +373,7 @@ python src/export_model.py \
- [LLaMA-2](https://ai.meta.com/llama/license/)
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
- [Falcon](LICENSE)
- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
## 引用

View File

@ -1,4 +1,4 @@
from llmtuner.chat import ChatModel
__version__ = "0.1.3"
__version__ = "0.1.4"

View File

@ -93,9 +93,11 @@ def load_model_and_tokenizer(
)
is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if model_args.quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
model_to_load = model_args.checkpoint_dir[0]
else:

View File

@ -32,9 +32,9 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
def parse_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
parser = HfArgumentParser((
GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
))
return _parse_args(parser, args)
@ -51,7 +51,7 @@ def parse_infer_args(
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
general_args, model_args, data_args, training_args, finetuning_args = parse_train_args(args)
model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
# Setup logging
if training_args.should_log:
@ -79,6 +79,12 @@ def get_train_args(
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
assert not (training_args.max_steps == -1 and data_args.streaming), \
"Please specify `max_steps` in streaming mode."
assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
"Streaming mode does not support evaluation currently."
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
@ -108,12 +114,6 @@ def get_train_args(
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
data_args.dev_ratio = 0
assert not (training_args.max_steps == -1 and data_args.streaming), \
"Please specify `max_steps` in streaming mode."
assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
"Streaming mode does not support evaluation currently."
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if model_args.quantization_bit is not None:

View File

@ -1,16 +1,17 @@
from typing import Dict, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
def create_chat_box(
chat_model: WebChatModel,
chat_model: "WebChatModel",
visible: Optional[bool] = False
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()

View File

@ -1,10 +1,12 @@
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from typing import Tuple
from typing import TYPE_CHECKING, Tuple
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
def create_preview_box() -> Tuple[Block, Component, Component, Component]:
def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
preview_count = gr.Number(interactive=False)

View File

@ -1,14 +1,16 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)

View File

@ -1,11 +1,13 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.utils import export_model
if TYPE_CHECKING:
from gradio.components import Component
def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
with gr.Row():
save_dir = gr.Textbox()
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)

View File

@ -1,13 +1,15 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
if TYPE_CHECKING:
from gradio.components import Component
def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()

View File

@ -1,16 +1,18 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)

View File

@ -1,15 +1,17 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
from llmtuner.webui.utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, Component]:
def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():

View File

@ -1,15 +1,17 @@
import gradio as gr
from typing import Any, Dict, List
from gradio.components import Component
from typing import TYPE_CHECKING, Any, Dict, List
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.utils import get_time
if TYPE_CHECKING:
from gradio.components import Component
class Manager:
def __init__(self, elem_list: List[Dict[str, Component]]):
def __init__(self, elem_list: List[Dict[str, "Component"]]):
self.elem_list = elem_list
def gen_refresh(self) -> Dict[str, Any]:
@ -24,7 +26,7 @@ class Manager:
return refresh_dict
def gen_label(self, lang: str) -> Dict[Component, dict]:
def gen_label(self, lang: str) -> Dict["Component", dict]:
update_dict = {}
refresh_dict = self.gen_refresh()