Merge pull request #975 from statelesshz/npu-support

Add Ascend NPU support
This commit is contained in:
hoshi-hiyouga 2023-09-20 14:56:50 +08:00 committed by GitHub
commit ac8648b431
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 2 deletions

View File

@ -13,7 +13,7 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase
)
from transformers.utils import check_min_version
from transformers.utils import check_min_version, is_torch_npu_available
from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead
@ -215,7 +215,10 @@ def load_model_and_tokenizer(
# Prepare model for inference
if not is_trainable:
model.requires_grad_(False) # fix all model params
infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
if is_torch_npu_available():
infer_dtype = torch.float16
else:
infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
model = model.to(infer_dtype) if model_args.quantization_bit is None else model
trainable_params, all_param = count_parameters(model)