From b3e41c6d49728e239b61d8cd9b3603c8dc877549 Mon Sep 17 00:00:00 2001
From: statelesshz
Date: Wed, 20 Sep 2023 10:15:59 +0800
Subject: [PATCH] support export model on Ascend NPU

---
 src/llmtuner/tuner/core/loader.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 2570c5d7..6cc40a33 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -13,7 +13,7 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizerBase
 )
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, is_torch_npu_available
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead
 
@@ -215,7 +215,10 @@ def load_model_and_tokenizer(
     # Prepare model for inference
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
+        if is_torch_npu_available():
+            infer_dtype = torch.float16
+        else:
+            infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
         model = model.to(infer_dtype) if model_args.quantization_bit is None else model
 
     trainable_params, all_param = count_parameters(model)
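
Note: the sketch below isolates the dtype-selection logic from this patch as a
standalone helper, for readers who want the same device-aware fallback outside
this loader. It assumes a transformers release recent enough to export
is_torch_npu_available (and, on Ascend hardware, the torch-npu package); the
helper name pick_infer_dtype and the torch.cuda.is_available() guard are
illustrative additions, not part of the patch or of llmtuner.

    # Standalone sketch of the dtype selection introduced by this patch.
    import torch
    from transformers.utils import is_torch_npu_available

    def pick_infer_dtype() -> torch.dtype:
        """Choose a half-precision inference dtype for the current accelerator."""
        if is_torch_npu_available():
            # Ascend NPUs handle float16 reliably; bfloat16 support is limited,
            # so the patch prefers float16 there.
            return torch.float16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            # CUDA devices that report bfloat16 support (Ampere or newer).
            # The is_available() guard is added here so CPU-only machines do
            # not hit a CUDA query; the original patch omits it.
            return torch.bfloat16
        return torch.float16

    # Usage mirroring the patched loader:
    # model = model.to(pick_infer_dtype()) if model_args.quantization_bit is None else model

Falling back to float16 on NPU rather than bfloat16 reflects the state of
Ascend support in PyTorch at the time of this patch, where bfloat16 coverage
was incomplete; float16 is the safe half-precision default on that backend.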