From c874e764b8334c18091233c3781009a39d67e794 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sat, 14 Oct 2023 20:15:24 +0800 Subject: [PATCH] fix loading dtype --- src/llmtuner/tuner/core/loader.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index d7827000..8f35183c 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -88,11 +88,10 @@ def load_model_and_tokenizer( tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) # Set model dtype - if model_args.compute_dtype is not None: + if model_args.compute_dtype is not None: # for training setattr(config, "torch_dtype", model_args.compute_dtype) - else: # priority: bf16 > fp16 > fp32 - optim_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) - setattr(config, "torch_dtype", optim_dtype) + else: # for evaluation, priority: bf16 > fp16 > fp32 + model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) # Fix config (for Qwen) if getattr(config, "model_type", None) == "qwen": @@ -185,7 +184,7 @@ def load_model_and_tokenizer( model = AutoModelForCausalLM.from_pretrained( model_to_load, config=config, - torch_dtype=getattr(config, "torch_dtype"), + torch_dtype=model_args.compute_dtype, low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), **config_kwargs )