Merge pull request #3987 from injet-zhou/main

Fix: can't interrupt training when using multiple GPUs in the webui
hoshi-hiyouga 2024-06-04 00:04:07 +08:00 committed by GitHub
commit ae18e1e251
1 changed file with 31 additions and 3 deletions


@@ -1,5 +1,6 @@
 import os
 import signal
+import random
 from copy import deepcopy
 from subprocess import Popen, TimeoutExpired
 from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
@@ -8,12 +9,12 @@ import psutil
 from transformers.trainer import TRAINING_ARGS_NAME

 from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_gpu_or_npu_available, torch_gc
+from ..extras.misc import is_gpu_or_npu_available, torch_gc, get_device_count
 from ..extras.packages import is_gradio_available
 from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
 from .locales import ALERTS
 from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
+from .. import launcher

 if is_gradio_available():
     import gradio as gr
@@ -41,6 +42,10 @@ class Runner:
         self.aborted = True
         if self.trainer is not None:
             for children in psutil.Process(self.trainer.pid).children():  # abort the child process
+                grand_children = children.children()
+                if len(grand_children) > 0:
+                    for grand_child in grand_children:
+                        os.kill(grand_child.pid, signal.SIGABRT)
                 os.kill(children.pid, signal.SIGABRT)

     def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
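The hunk above reaches only one extra level of the process tree: torchrun's per-GPU workers are grandchildren of the Popen handle, so killing direct children alone left them running. For comparison, a minimal sketch of the same idea using psutil's recursive traversal; abort_process_tree is a hypothetical helper name, not code from this PR:

    import os
    import signal

    import psutil

    def abort_process_tree(pid: int) -> None:
        # children(recursive=True) returns every descendant, so workers at
        # any depth are covered without nested loops; reversing the list
        # signals descendants before their parents.
        for proc in reversed(psutil.Process(pid).children(recursive=True)):
            os.kill(proc.pid, signal.SIGABRT)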
@@ -280,7 +285,30 @@ class Runner:
         args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
         env = deepcopy(os.environ)
         env["LLAMABOARD_ENABLED"] = "1"
-        self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
+        if get_device_count() > 1:
+            nnodes = os.environ.get("NNODES", "1")
+            node_rank = os.environ.get("RANK", "0")
+            nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
+            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
+            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
+            self.trainer = Popen([
+                "torchrun",
+                "--nnodes",
+                nnodes,
+                "--node_rank",
+                node_rank,
+                "--nproc_per_node",
+                nproc_per_node,
+                "--master_addr",
+                master_addr,
+                "--master_port",
+                master_port,
+                launcher.__file__,
+                save_cmd(args)
+            ], env=env, shell=True)
+        else:
+            self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
         yield from self.monitor()

     def preview_train(self, data):
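One caveat on the launch hunk: Python's subprocess documentation specifies that when args is a sequence and shell=True, only the first item is treated as the command string and the remaining items become arguments to the shell itself rather than to torchrun. A sketch of the same launch under that reading, keeping the argument list but relying on the default shell=False; launch_distributed and its parameters are hypothetical names, while the environment-variable defaults mirror the diff:

    import os
    import random
    from subprocess import Popen

    def launch_distributed(launcher_path: str, saved_args: str, device_count: int) -> Popen:
        # With the default shell=False, the program is exec'd directly and
        # each list item reaches torchrun as its own argument.
        cmd = [
            "torchrun",
            "--nnodes", os.environ.get("NNODES", "1"),
            "--node_rank", os.environ.get("RANK", "0"),
            "--nproc_per_node", os.environ.get("NPROC_PER_NODE", str(device_count)),
            "--master_addr", os.environ.get("MASTER_ADDR", "127.0.0.1"),
            "--master_port", os.environ.get("MASTER_PORT", str(random.randint(20001, 29999))),
            launcher_path,
            saved_args,
        ]
        return Popen(cmd, env=dict(os.environ, LLAMABOARD_ENABLED="1"))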