commit 5cc7a44784
parent 044d4425b4
Author: hiyouga
Date:   2023-09-22 15:00:48 +08:00

3 changed files with 18 additions and 12 deletions

README.md

@@ -369,8 +369,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
-    --fp16
+    --output_dir path_to_export
 ```

 ### API Demo
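
The README change simply drops the trailing `--fp16` flag: export precision is now inferred from the hardware by the `_infer_dtype` helper added further down in this commit. A minimal sketch of the updated invocation driven from Python; the flags are copied from the hunk above, the `path_to_*` values stay as placeholders, and other options (such as the model path) are elided here just as they are in the diff context:

```python
import subprocess

# Sketch: run the updated export command (note: no --fp16 anymore).
subprocess.run(
    [
        "python", "src/export_model.py",
        "--template", "default",
        "--finetuning_type", "lora",
        "--checkpoint_dir", "path_to_checkpoint",  # placeholder
        "--output_dir", "path_to_export",          # placeholder
    ],
    check=True,
)
```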

README_zh.md

@@ -368,8 +368,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
-    --fp16
+    --output_dir path_to_export
 ```

 ### API Service

src/llmtuner/tuner/core/parser.py

@@ -9,10 +9,12 @@ from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint

 try:
-    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
+    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available
+    is_fp16_available = is_torch_cuda_available()
     is_bf16_available = is_torch_bf16_gpu_available()
     is_npu_available = is_torch_npu_available()
 except ImportError:
+    is_fp16_available = torch.cuda.is_available()
     is_bf16_available = torch.cuda.is_bf16_supported()
     is_npu_available = False
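
The try/except keeps the module importable on older transformers releases that lack these capability helpers. A standalone probe illustrating both branches; the CPU guard around `torch.cuda.is_bf16_supported()` is a defensive tweak of mine rather than part of the commit, since on some torch builds that call assumes a visible CUDA device:

```python
import torch

try:
    # Newer transformers exposes dedicated capability helpers.
    from transformers.utils import (
        is_torch_bf16_gpu_available,
        is_torch_cuda_available,
        is_torch_npu_available,
    )
    flags = {
        "fp16": is_torch_cuda_available(),
        "bf16": is_torch_bf16_gpu_available(),
        "npu": is_torch_npu_available(),
    }
except ImportError:
    # Older releases: fall back to plain torch queries and assume no NPU.
    cuda = torch.cuda.is_available()
    flags = {
        "fp16": cuda,
        "bf16": cuda and torch.cuda.is_bf16_supported(),
        "npu": False,
    }

print(flags)  # e.g. {'fp16': True, 'bf16': True, 'npu': False} on an A100
```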
@@ -29,6 +31,17 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)


+def _infer_dtype() -> torch.dtype:
+    if is_npu_available:
+        return torch.float16
+    elif is_bf16_available:
+        return torch.bfloat16
+    elif is_fp16_available:
+        return torch.float16
+    else:
+        return torch.float32
+
+
 def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
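
`_infer_dtype` encodes a fixed precedence: NPU first (fp16), then bf16-capable GPUs, then any other CUDA GPU (fp16), with fp32 as the CPU fallback. A self-contained restatement with the probes passed in as booleans, so the branch ordering can be checked without the module's globals:

```python
import torch

def infer_dtype(npu: bool, bf16: bool, fp16: bool) -> torch.dtype:
    # Same branch order as _infer_dtype above, parameterized for testing.
    if npu:
        return torch.float16
    elif bf16:
        return torch.bfloat16
    elif fp16:
        return torch.float16
    else:
        return torch.float32

# NPU wins even when bf16 is available; CPU-only machines get fp32.
assert infer_dtype(npu=True,  bf16=True,  fp16=True)  is torch.float16
assert infer_dtype(npu=False, bf16=True,  fp16=True)  is torch.bfloat16
assert infer_dtype(npu=False, bf16=False, fp16=True)  is torch.float16
assert infer_dtype(npu=False, bf16=False, fp16=False) is torch.float32
```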
@@ -211,7 +224,7 @@ def get_train_args(
     elif training_args.fp16:
         model_args.compute_dtype = torch.float16
     else:
-        model_args.compute_dtype = torch.float32
+        model_args.compute_dtype = _infer_dtype()

     model_args.model_max_length = data_args.cutoff_len
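
Explicit `--bf16`/`--fp16` training flags still take priority; the hardware-based fallback only fires when neither is set, where the old code hard-coded fp32. A hedged sketch of that resolution order, with `auto_dtype` standing in for the `_infer_dtype()` result:

```python
import torch

def resolve_compute_dtype(bf16_flag: bool, fp16_flag: bool,
                          auto_dtype: torch.dtype) -> torch.dtype:
    if bf16_flag:         # user asked for bf16 explicitly
        return torch.bfloat16
    if fp16_flag:         # user asked for fp16 explicitly
        return torch.float16
    return auto_dtype     # new: inferred from hardware, was torch.float32

# On a bf16-capable GPU with no precision flags set:
assert resolve_compute_dtype(False, False, torch.bfloat16) is torch.bfloat16
```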
@@ -252,11 +265,6 @@ def get_infer_args(
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")

-    # auto-detect cuda capability
-    if is_npu_available:
-        model_args.compute_dtype = torch.float16
-    elif is_bf16_available:
-        model_args.compute_dtype = torch.bfloat16
-    else:
-        model_args.compute_dtype = torch.float16
+    model_args.compute_dtype = _infer_dtype()

     return model_args, data_args, finetuning_args, generating_args