update training resuming

This commit is contained in:
hiyouga 2023-08-18 01:41:17 +08:00
parent 7926432d27
commit 58f13e22da
7 changed files with 57 additions and 29 deletions

View File

@ -12,6 +12,8 @@
## Changelog
[23/08/18] Now we support **resuming training**, upgrade `transformers` to `4.31.0` to enjoy this feature.
[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
@ -158,6 +160,8 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
We strongly recommend using the all-in-one Web UI for newcomers since it can also generate training scripts **automatically**.
Currently the web UI only supports training on **a single GPU**.
### Pre-Training

View File

@ -12,6 +12,8 @@
## 更新日志
[23/08/18] 现在我们支持了**训练状态恢复**,请将 `transformers` 升级至 `4.31.0` 以启用此功能。
[23/08/12] 现在我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请尝试使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-训练)(实验性功能)。
@ -24,7 +26,7 @@
[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。使用 LLaMA-2-chat 模型时请添加 `--template llama2` 参数。
[23/07/18] 我们开发了支持训练和测试的**一体化浏览器界面**。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-13B-Base``--lora_target W_pack` 参数。使用 Baichuan-13B-Chat 模型时请添加 `--template baichuan` 参数。
@ -152,12 +154,14 @@ pip install -r requirements.txt
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
### 浏览器一键微调/测试
### 浏览器一体化界面
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
我们极力推荐新手使用浏览器一体化界面,因为它还可以**自动**生成运行所需的命令行脚本。
目前网页 UI 仅支持**单卡训练**。
### 预训练
@ -451,6 +455,8 @@ python src/export_model.py \
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
## 引用

View File

@ -121,6 +121,9 @@ def load_model_and_tokenizer(
# Quantization configurations (using bitsandbytes library).
is_mergeable = True
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
@ -144,7 +147,7 @@ def load_model_and_tokenizer(
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)

View File

@ -5,6 +5,7 @@ import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import (
@ -97,30 +98,33 @@ def get_train_args(
if general_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
raise ValueError("RM and PPO training can only be performed with the LoRA method.")
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
raise ValueError("PPO and DPO stage can only be performed at training.")
raise ValueError("PPO and DPO stages can only be performed at training.")
if general_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if general_args.stage == "ppo" and data_args.streaming:
raise ValueError("Streaming mode does not suppport PPO training currently.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
raise ValueError("Streaming mode should have an integer val size.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
@ -134,9 +138,15 @@ def get_train_args(
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if training_args.do_train and (not training_args.fp16):
logger.warning("We recommend enable fp16 mixed precision training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.")
# postprocess data_args
if data_args.max_samples is not None and data_args.streaming:
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
data_args.max_samples = None
# postprocess training_args
if (
training_args.local_rank != -1
and training_args.ddp_find_unused_parameters is None
@ -145,12 +155,26 @@ def get_train_args(
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False
if data_args.max_samples is not None and data_args.streaming:
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
data_args.max_samples = None
if training_args.optim == "adamw_hf":
training_args.optim = "adamw_torch" # suppress warning
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint
logger.info(
"Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
)
# postprocess model_args
if training_args.bf16:
if not torch.cuda.is_bf16_supported():
raise ValueError("Current device does not support bf16 training.")

View File

@ -50,7 +50,7 @@ def run_dpo(
# Training
if training_args.do_train:
train_result = trainer.train()
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()

View File

@ -39,7 +39,7 @@ def run_pt(
# Training
if training_args.do_train:
train_result = trainer.train()
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()

View File

@ -1,5 +1,4 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
import os
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq
@ -11,14 +10,11 @@ from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import reset_logging, get_logger
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
def run_sft(
model_args: "ModelArguments",
@ -62,12 +58,7 @@ def run_sft(
# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()