diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 2570c5d7..6cc40a33 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -13,7 +13,7 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizerBase
 )
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, is_torch_npu_available
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead
 
@@ -215,7 +215,10 @@ def load_model_and_tokenizer(
     # Prepare model for inference
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
+        if is_torch_npu_available():
+            infer_dtype = torch.float16
+        else:
+            infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
         model = model.to(infer_dtype) if model_args.quantization_bit is None else model
 
     trainable_params, all_param = count_parameters(model)
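
For context, a minimal standalone sketch of the dtype-selection logic this patch introduces. The helper name `select_infer_dtype` and the explicit `torch.cuda.is_available()` guard are illustrative additions and not part of the patch; `is_torch_npu_available` is the `transformers.utils` helper imported above.

```python
import torch
from transformers.utils import is_torch_npu_available


def select_infer_dtype() -> torch.dtype:
    # Illustrative helper (not in the patch): mirrors the branch added above.
    if is_torch_npu_available():
        # On Ascend NPU devices the patch falls back to float16 for inference.
        return torch.float16
    # CUDA path: prefer bfloat16 when the device supports it, otherwise float16.
    # The is_available() guard is an extra check so this sketch also runs on CPU-only hosts.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16


if __name__ == "__main__":
    print(select_infer_dtype())
```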