support DDP in webui

hiyouga 2024-05-28 19:24:22 +08:00
parent c8765349ba
commit 7c016b22aa
19 changed files with 78 additions and 166 deletions

View File

@@ -110,19 +110,20 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_l
#### Supervised Fine-Tuning with Accelerate on Single Node
```bash
bash examples/lora_multi_gpu/single_node.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
```
#### Supervised Fine-Tuning with Accelerate on Multiple Nodes
```bash
bash examples/lora_multi_gpu/multi_node.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
```
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
```bash
bash examples/lora_multi_gpu/ds_zero3.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
```
### LoRA Fine-Tuning on Multiple NPUs
@@ -130,7 +131,7 @@ bash examples/lora_multi_gpu/ds_zero3.sh
#### Supervised Fine-Tuning with DeepSpeed ZeRO-0
```bash
bash examples/lora_multi_npu/ds_zero0.sh
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
```
### Full-Parameter Fine-Tuning on Multiple GPUs
@@ -138,19 +139,20 @@ bash examples/lora_multi_npu/ds_zero0.sh
#### Supervised Fine-Tuning with Accelerate on Single Node
```bash
bash examples/full_multi_gpu/single_node.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
```
#### Supervised Fine-Tuning with Accelerate on Multiple Nodes
```bash
bash examples/full_multi_gpu/multi_node.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
bash examples/full_multi_gpu/predict.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
```
### Merging LoRA Adapters and Quantization

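The commands above replace the deleted helper scripts: instead of wrapping `accelerate launch` or `torchrun`, `llamafactory-cli train` now reads `NNODES`, `RANK`, `MASTER_ADDR`, and `MASTER_PORT` from the environment (see the `cli.py` hunk below). A minimal sketch of the full variable set, assuming one also wants to pin the per-node worker count with the optional `NPROC_PER_NODE` variable:

```bash
# All distributed settings ride on environment variables; NPROC_PER_NODE is
# optional and defaults to the number of visible devices.
NPROC_PER_NODE=4 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 \
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
```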
View File

@@ -110,19 +110,20 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_l
#### Single-Node Training with Accelerate
```bash
bash examples/lora_multi_gpu/single_node.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
```
#### Multi-Node Training with Accelerate
```bash
bash examples/lora_multi_gpu/multi_node.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
```
#### Distributing VRAM Evenly with DeepSpeed ZeRO-3
```bash
bash examples/lora_multi_gpu/ds_zero3.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
```
### LoRA Fine-Tuning on Multiple NPUs
@@ -130,7 +131,7 @@ bash examples/lora_multi_gpu/ds_zero3.sh
#### Training with DeepSpeed ZeRO-0
```bash
bash examples/lora_multi_npu/ds_zero0.sh
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
```
### Full-Parameter Fine-Tuning on Multiple GPUs
@@ -138,19 +139,20 @@ bash examples/lora_multi_npu/ds_zero0.sh
#### Single-Node Training with DeepSpeed
```bash
bash examples/full_multi_gpu/single_node.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
```
#### Multi-Node Training with DeepSpeed
```bash
bash examples/full_multi_gpu/multi_node.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
```
#### Batch Prediction and Computing BLEU and ROUGE Scores
```bash
bash examples/full_multi_gpu/predict.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
```
### Merging LoRA Adapters and Model Quantization

View File

@@ -1,18 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -1,16 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1 # the number of nodes
num_processes: 4 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -1,18 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 1
main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

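The three deleted Accelerate configs map one-to-one onto the new environment variables: `num_machines` → `NNODES`, `machine_rank` → `RANK`, `main_process_ip`/`main_process_port` → `MASTER_ADDR`/`MASTER_PORT`, and `num_processes` divided by the node count → `NPROC_PER_NODE`. A hedged sketch of the two-node launch that the master/slave pair used to describe:

```bash
# On the node that master_config.yaml described (machine_rank: 0):
NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29555 \
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml

# On the node that slave_config.yaml described (machine_rank: 1):
NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29555 \
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
```

The configs' `mixed_precision: fp16` field has no environment-variable counterpart; precision now comes from the training YAML instead.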
View File

@@ -1,15 +0,0 @@
#!/bin/bash

NPROC_PER_NODE=4
NNODES=2
RANK=0
MASTER_ADDR=192.168.0.1
MASTER_PORT=29500

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT \
    src/train.py examples/full_multi_gpu/llama3_full_sft.yaml

View File

@@ -1,5 +0,0 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --config_file examples/accelerate/single_config.yaml \
    src/train.py examples/full_multi_gpu/llama3_full_predict.yaml

View File

@@ -1,15 +0,0 @@
#!/bin/bash

NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT \
    src/train.py examples/full_multi_gpu/llama3_full_sft.yaml

View File

@@ -1,15 +0,0 @@
#!/bin/bash

NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT \
    src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml

View File

@@ -1,6 +0,0 @@
#!/bin/bash
# also launch it on the slave machine using slave_config.yaml

CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --config_file examples/accelerate/master_config.yaml \
    src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml

View File

@@ -1,5 +0,0 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --config_file examples/accelerate/single_config.yaml \
    src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml

View File

@@ -1,15 +0,0 @@
#!/bin/bash

NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500

ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 torchrun \
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT \
    src/train.py examples/lora_multi_npu/llama3_lora_sft_ds.yaml

View File

@@ -1,9 +1,16 @@
import os
import random
import subprocess
import sys
from enum import Enum, unique

from llamafactory import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras.logging import get_logger
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
@@ -37,6 +44,8 @@ WELCOME = (
    + "-" * 58
)

logger = get_logger(__name__)
@unique
class Command(str, Enum):
@@ -62,7 +71,32 @@ def main():
    elif command == Command.EXPORT:
        export_model()
    elif command == Command.TRAIN:
        run_exp()
        if get_device_count() > 1:
            nnodes = os.environ.get("NNODES", "1")
            node_rank = os.environ.get("RANK", "0")
            nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
            logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
            subprocess.run(
                [
                    "torchrun",
                    "--nnodes",
                    nnodes,
                    "--node_rank",
                    node_rank,
                    "--nproc_per_node",
                    nproc_per_node,
                    "--master_addr",
                    master_addr,
                    "--master_port",
                    master_port,
                    launcher.__file__,
                    *sys.argv[1:],
                ]
            )
        else:
            run_exp()
    elif command == Command.WEBDEMO:
        run_web_demo()
    elif command == Command.WEBUI:

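In effect, the `Command.TRAIN` branch assembles the same `torchrun` invocation the deleted shell scripts spelled out by hand. For `NNODES=2 RANK=0` on a node with four visible GPUs, the `subprocess.run` call resolves to roughly the following (the launcher path is illustrative; it comes from `launcher.__file__` and differs per install):

```bash
torchrun --nnodes 2 --node_rank 0 --nproc_per_node 4 \
    --master_addr 192.168.0.1 --master_port 29500 \
    /path/to/site-packages/llamafactory/launcher.py \
    examples/lora_multi_gpu/llama3_lora_sft.yaml
```

Note that when `MASTER_PORT` is unset, each invocation draws a random port in 20001-29999, so multi-node runs must set it explicitly or the nodes will not rendezvous.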
View File

@@ -165,13 +165,15 @@ def get_current_device() -> torch.device:

def get_device_count() -> int:
    r"""
    Gets the number of available GPU devices.
    Gets the number of available GPU or NPU devices.
    """
    if not torch.cuda.is_available():
    if is_torch_npu_available():
        return torch.npu.device_count()
    elif is_torch_cuda_available():
        return torch.cuda.device_count()
    else:
        return 0
    return torch.cuda.device_count()


def get_logits_processor() -> "LogitsProcessorList":
    r"""
@@ -194,6 +196,13 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
        return torch.float32


def is_gpu_or_npu_available() -> bool:
    r"""
    Checks if the GPU or NPU is available.
    """
    return is_torch_npu_available() or is_torch_cuda_available()


def has_tokenized_data(path: os.PathLike) -> bool:
    r"""
    Checks if the path has a tokenized dataset.

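A quick way to check what the new helpers report on a given machine (a usage sketch, assuming the `llamafactory` package is importable):

```bash
# Prints the device count that gates the torchrun path in cli.py, plus
# whether any GPU or NPU is visible at all.
python -c "from llamafactory.extras.misc import get_device_count, is_gpu_or_npu_available; print(get_device_count(), is_gpu_or_npu_available())"
```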
View File

@@ -10,7 +10,6 @@ from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version

from ..extras.constants import TRAINER_CONFIG
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
@@ -252,10 +251,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
        and can_resume_from_checkpoint
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        files = os.listdir(training_args.output_dir)
        if last_checkpoint is None and len(files) > 0 and (len(files) != 1 or files[0] != TRAINER_CONFIG):
            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")

        if last_checkpoint is not None:
            training_args.resume_from_checkpoint = last_checkpoint
            logger.info(

View File

@@ -0,0 +1,9 @@
from llamafactory.train.tuner import run_exp


def launch():
    run_exp()


if __name__ == "__main__":
    launch()

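The new `launcher.py` exists because `torchrun` expects a script path rather than a console entry point: `cli.py` passes `launcher.__file__` so that every spawned rank re-enters `run_exp()` directly. The same file can be targeted by hand (a sketch, assuming a single node with two GPUs):

```bash
torchrun --nnodes 1 --node_rank 0 --nproc_per_node 2 \
    --master_addr 127.0.0.1 --master_port 29500 \
    "$(python -c 'from llamafactory import launcher; print(launcher.__file__)')" \
    examples/lora_multi_gpu/llama3_lora_sft.yaml
```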
View File

@@ -1469,11 +1469,6 @@ ALERTS = {
        "ru": "Обучение недоступно в демонстрационном режиме, сначала скопируйте пространство в частное.",
        "zh": "展示模式不支持训练,请先复制到私人空间。",
    },
    "err_device_count": {
        "en": "Multiple GPUs are not supported yet.",
        "ru": "Пока не поддерживается множественные GPU.",
        "zh": "尚不支持多 GPU 训练。",
    },
    "err_tool_name": {
        "en": "Tool name not found.",
        "ru": "Имя инструмента не найдено.",

View File

@@ -6,10 +6,9 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
import psutil
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_cuda_available

from ..extras.constants import TRAINING_STAGES
from ..extras.misc import get_device_count, torch_gc
from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.packages import is_gradio_available
from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS
@@ -64,16 +63,13 @@ class Runner:
        if not from_preview and self.demo_mode:
            return ALERTS["err_demo"][lang]

        if not from_preview and get_device_count() > 1:
            return ALERTS["err_device_count"][lang]

        if do_train:
            stage = TRAINING_STAGES[get("train.training_stage")]
            reward_model = get("train.reward_model")
            if stage == "ppo" and not reward_model:
                return ALERTS["err_no_reward_model"][lang]

        if not from_preview and not is_torch_cuda_available():
        if not from_preview and not is_gpu_or_npu_available():
            gr.Warning(ALERTS["warn_no_cuda"][lang])

        return ""
@@ -273,7 +269,6 @@ class Runner:
        self.do_train, self.running_data = do_train, data
        args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)

        env = deepcopy(os.environ)
        env["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
        env["LLAMABOARD_ENABLED"] = "1"
        self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
        yield from self.monitor()

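With the `err_device_count` guard and the forced `CUDA_VISIBLE_DEVICES=0` both gone, the Web UI simply spawns `llamafactory-cli train`, which takes the torchrun path whenever more than one device is visible; this is what makes DDP work from the Web UI. A sketch of the resulting behavior, assuming a four-GPU host:

```bash
# The Web UI inherits whatever devices the shell exposes; its child
# llamafactory-cli train process then fans out across all four GPUs.
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli webui
```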
View File

@@ -42,8 +42,7 @@ def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:

def gen_cmd(args: Dict[str, Any]) -> str:
    current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
    cmd_lines = ["CUDA_VISIBLE_DEVICES={} llamafactory-cli train ".format(current_devices)]
    cmd_lines = ["llamafactory-cli train "]
    for k, v in clean_cmd(args).items():
        cmd_lines.append(" --{} {} ".format(k, str(v)))
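Consistent with the runner change, the command preview no longer pins `CUDA_VISIBLE_DEVICES`: the generated command is device-agnostic and users scope it themselves at launch time. For example (arguments illustrative, in the `--key value` form `gen_cmd` emits):

```bash
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train \
    --stage sft \
    --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct
```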