forked from p04798526/LLaMA-Factory-Mirror
fix #1000
This commit is contained in:
parent 044d4425b4
commit 5cc7a44784
@@ -369,8 +369,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
-    --fp16
+    --output_dir path_to_export
 ```
 
 ### API Demo
@@ -368,8 +368,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
-    --fp16
+    --output_dir path_to_export
 ```
 
 ### API 服务
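Both README hunks make the same edit, once in the English README and once in the Chinese README (its "### API 服务" section corresponds to "API Demo"): the trailing `--fp16` flag is dropped from the `src/export_model.py` example so that `--output_dir path_to_export` becomes the final argument. This lines up with the parser changes below, which auto-detect a compute dtype via `_infer_dtype()` rather than relying on an explicit flag (presumably the export script goes through `get_infer_args`).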
@@ -9,10 +9,12 @@ from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint
 
 try:
-    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
+    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available
+    is_fp16_available = is_torch_cuda_available()
     is_bf16_available = is_torch_bf16_gpu_available()
     is_npu_available = is_torch_npu_available()
 except ImportError:
+    is_fp16_available = torch.cuda.is_available()
     is_bf16_available = torch.cuda.is_bf16_supported()
     is_npu_available = False
 
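This hunk introduces an `is_fp16_available` flag next to the existing bf16/NPU probes. On transformers versions that export the helpers, it comes from `is_torch_cuda_available()`; if the import fails on an older release, the `except ImportError` branch falls back to the plain `torch.cuda.is_available()` probe.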
@@ -29,6 +31,17 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)
 
 
+def _infer_dtype() -> torch.dtype:
+    if is_npu_available:
+        return torch.float16
+    elif is_bf16_available:
+        return torch.bfloat16
+    elif is_fp16_available:
+        return torch.float16
+    else:
+        return torch.float32
+
+
 def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
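The new `_infer_dtype()` helper centralizes the dtype choice with the precedence NPU → bf16 → fp16 → fp32. Below is a minimal standalone sketch of that selection order; it is illustrative only, with the availability flags passed as parameters (in the module they are the globals set in the previous hunk).

```python
import torch

def infer_dtype(npu_available: bool, bf16_available: bool, fp16_available: bool) -> torch.dtype:
    # Same branching order as _infer_dtype(), restated for illustration.
    if npu_available:        # Ascend NPU present: use float16
        return torch.float16
    elif bf16_available:     # bf16-capable GPU: prefer bfloat16
        return torch.bfloat16
    elif fp16_available:     # any other CUDA GPU: use float16
        return torch.float16
    else:                    # no accelerator detected: stay in full precision
        return torch.float32

# torch dtypes are singletons, so identity checks are safe here.
assert infer_dtype(False, False, False) is torch.float32   # CPU-only machine
assert infer_dtype(False, True, True) is torch.bfloat16    # bf16-capable GPU
assert infer_dtype(False, False, True) is torch.float16    # older CUDA GPU
```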
@@ -211,7 +224,7 @@ def get_train_args(
     elif training_args.fp16:
         model_args.compute_dtype = torch.float16
     else:
-        model_args.compute_dtype = torch.float32
+        model_args.compute_dtype = _infer_dtype()
 
     model_args.model_max_length = data_args.cutoff_len
 
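In `get_train_args`, an explicit `--fp16` (like the bf16 branch above this hunk) still takes precedence; only the final `else` changes, from a hard-coded `torch.float32` to the auto-detected `_infer_dtype()`.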
@@ -252,11 +265,6 @@ def get_infer_args(
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
     # auto-detect cuda capability
-    if is_npu_available:
-        model_args.compute_dtype = torch.float16
-    elif is_bf16_available:
-        model_args.compute_dtype = torch.bfloat16
-    else:
-        model_args.compute_dtype = torch.float16
+    model_args.compute_dtype = _infer_dtype()
 
     return model_args, data_args, finetuning_args, generating_args
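In `get_infer_args`, the inline detection block is replaced with the shared helper. The removed code fell through to `torch.float16` whenever neither an NPU nor bf16 support was found, even on CPU-only machines; `_infer_dtype()` keeps the NPU and bf16 cases the same but returns `torch.float32` when no CUDA device is detected.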