diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index 7c624db4..fc2c5c2c 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -44,7 +44,8 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]: def gen_cmd(args: Dict[str, Any]) -> str: args.pop("disable_tqdm", None) args["plot_loss"] = args.get("do_train", None) - cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "] + cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES') or "0" + cmd_lines = [f"CUDA_VISIBLE_DEVICES={cuda_visible_devices} python src/train_bash.py "] for k, v in args.items(): if v is not None and v != "": cmd_lines.append(" --{} {} ".format(k, str(v)))