Merge pull request #3987 from injet-zhou/main

Fix cann't interrupt training when using multi GPUs in webui
This commit is contained in:
hoshi-hiyouga 2024-06-04 00:04:07 +08:00 committed by GitHub
commit ae18e1e251
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 31 additions and 3 deletions

View File

@ -1,5 +1,6 @@
import os
import signal
import random
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
@ -8,12 +9,12 @@ import psutil
from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.misc import is_gpu_or_npu_available, torch_gc, get_device_count
from ..extras.packages import is_gradio_available
from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
from .. import launcher
if is_gradio_available():
import gradio as gr
@ -41,6 +42,10 @@ class Runner:
self.aborted = True
if self.trainer is not None:
for children in psutil.Process(self.trainer.pid).children(): # abort the child process
grand_children = children.children()
if len(grand_children) > 0:
for grand_child in grand_children:
os.kill(grand_child.pid, signal.SIGABRT)
os.kill(children.pid, signal.SIGABRT)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
@ -280,6 +285,29 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
env = deepcopy(os.environ)
env["LLAMABOARD_ENABLED"] = "1"
if get_device_count() > 1:
nnodes = os.environ.get("NNODES", "1")
node_rank = os.environ.get("RANK", "0")
nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
self.trainer = Popen([
"torchrun",
"--nnodes",
nnodes,
"--node_rank",
node_rank,
"--nproc_per_node",
nproc_per_node,
"--master_addr",
master_addr,
"--master_port",
master_port,
launcher.__file__,
save_cmd(args)
], env=env, shell=True)
else:
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
yield from self.monitor()