From 26817143ff86a853c011be11678235bcc803ccce Mon Sep 17 00:00:00 2001 From: yhyu13 Date: Sat, 16 Dec 2023 05:16:29 +0000 Subject: [PATCH 1/2] Improve logging for unknown args --- src/llmtuner/extras/misc.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index b8424d62..4f123e14 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -5,6 +5,9 @@ import torch from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList +import logging +logger = logging.getLogger(__name__) + try: from transformers.utils import ( is_torch_bf16_cpu_available, @@ -111,7 +114,12 @@ def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): return parser.parse_json_file(os.path.abspath(sys.argv[1])) else: - return parser.parse_args_into_dataclasses() + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True) + if unknown_args: + logger.warning(parser.format_help()) + logger.error(f'\nGot unknown args, potentially deprecated arguments: {unknown_args}\n') + raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") + return (*parsed_args,) def torch_gc() -> None: From fc70a92cb6e9c22bab9a0695f476ae80461c656f Mon Sep 17 00:00:00 2001 From: yhyu13 Date: Sat, 16 Dec 2023 07:15:27 +0000 Subject: [PATCH 2/2] Use llmtuner logger --- src/llmtuner/extras/misc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 4f123e14..9a50e369 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -5,8 +5,8 @@ import torch from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList -import logging -logger = logging.getLogger(__name__) +from llmtuner.extras.logging import get_logger +logger = get_logger(__name__) try: from transformers.utils import (