From 421d4de604493e1e26ec8348dab3eae138f46b86 Mon Sep 17 00:00:00 2001 From: samge Date: Fri, 1 Dec 2023 11:35:02 +0800 Subject: [PATCH] =?UTF-8?q?Improve=EF=BC=9A"CUDA=5FVISIBLE=5FDEVICES"=20re?= =?UTF-8?q?ad=20from=20the=20env?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llmtuner/webui/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index 7c624db4..fc2c5c2c 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -44,7 +44,8 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]: def gen_cmd(args: Dict[str, Any]) -> str: args.pop("disable_tqdm", None) args["plot_loss"] = args.get("do_train", None) - cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "] + cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES') or "0" + cmd_lines = [f"CUDA_VISIBLE_DEVICES={cuda_visible_devices} python src/train_bash.py "] for k, v in args.items(): if v is not None and v != "": cmd_lines.append(" --{} {} ".format(k, str(v)))