Fix abort in web UI DDP mode

This commit is contained in:
hiyouga 2024-06-04 00:10:24 +08:00
parent ae18e1e251
commit 2187518762
3 changed files with 32 additions and 54 deletions

View File

@@ -71,28 +71,23 @@ def main():
export_model()
elif command == Command.TRAIN:
if get_device_count() > 1:
nnodes = os.environ.get("NNODES", "1")
node_rank = os.environ.get("RANK", "0")
nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
subprocess.run(
[
"torchrun",
"--nnodes",
nnodes,
"--node_rank",
node_rank,
"--nproc_per_node",
nproc_per_node,
"--master_addr",
master_addr,
"--master_port",
master_port,
launcher.__file__,
*sys.argv[1:],
]
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
nnodes=os.environ.get("NNODES", "1"),
node_rank=os.environ.get("RANK", "0"),
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,
args=" ".join(sys.argv[1:]),
),
shell=True,
)
else:
run_exp()

View File

@@ -1,20 +1,17 @@
import os
import signal
import random
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
import psutil
from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc, get_device_count
from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.packages import is_gradio_available
from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
from .. import launcher
from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
if is_gradio_available():
import gradio as gr
@@ -41,12 +38,7 @@ class Runner:
def set_abort(self) -> None:
self.aborted = True
if self.trainer is not None:
for children in psutil.Process(self.trainer.pid).children(): # abort the child process
grand_children = children.children()
if len(grand_children) > 0:
for grand_child in grand_children:
os.kill(grand_child.pid, signal.SIGABRT)
os.kill(children.pid, signal.SIGABRT)
abort_leaf_process(self.trainer.pid)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
@@ -285,30 +277,7 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
env = deepcopy(os.environ)
env["LLAMABOARD_ENABLED"] = "1"
if get_device_count() > 1:
nnodes = os.environ.get("NNODES", "1")
node_rank = os.environ.get("RANK", "0")
nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
self.trainer = Popen([
"torchrun",
"--nnodes",
nnodes,
"--node_rank",
node_rank,
"--nproc_per_node",
nproc_per_node,
"--master_addr",
master_addr,
"--master_port",
master_port,
launcher.__file__,
save_cmd(args)
], env=env, shell=True)
else:
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
yield from self.monitor()
def preview_train(self, data):

View File

@@ -1,8 +1,10 @@
import json
import os
import signal
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import psutil
from transformers.trainer_utils import get_last_checkpoint
from yaml import safe_dump, safe_load
@@ -17,6 +19,18 @@ if is_gradio_available():
import gradio as gr
def abort_leaf_process(pid: int) -> None:
    r"""
    Recursively aborts the leaf processes of the process tree rooted at ``pid``.

    Only processes with no children (the actual worker processes, e.g. the
    per-rank trainers spawned by torchrun) receive SIGABRT; intermediate
    launcher processes are left to exit on their own once their children die.

    Args:
        pid: PID of the root of the process tree to abort.

    Processes that disappear while the tree is being walked (a worker may
    exit between enumeration and the kill) are silently ignored — the goal
    is that the process is dead, so an already-dead process is a success.
    """
    try:
        children = psutil.Process(pid).children()
    except psutil.NoSuchProcess:
        return  # process already exited before we could inspect it

    if children:
        for child in children:
            abort_leaf_process(child.pid)
    else:
        try:
            os.kill(pid, signal.SIGABRT)
        except ProcessLookupError:
            pass  # leaf exited between children() and kill — nothing to do
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
r"""
Judges if the quantization is available in this finetuning type.