diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 6ef5b933..35599f81 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -55,9 +55,13 @@ class LoraArguments:
                   Phi choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
                   Others choices: the same as LLaMA."}
     )
+    lora_bf16_mode: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
+    )
     create_new_adapter: Optional[bool] = field(
         default=False,
-        metadata={"help": "Whether to create a new adapter with randomly initialized weight or not."}
+        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
     )
diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index f0d7ce21..83a63b96 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -125,7 +125,7 @@ def init_adapter(
             model = get_peft_model(model, lora_config)
 
         for param in filter(lambda p: p.requires_grad, model.parameters()):
-            param.data = param.data.to(torch.float32)
+            param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)
 
         if model_args.adapter_name_or_path is not None:
            logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
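
The behavioral change is confined to the dtype of the trainable adapter parameters: with lora_bf16_mode enabled, they are kept in bf16 instead of being upcast to fp32. The snippet below is a minimal standalone sketch of that cast, assuming a PEFT-wrapped causal LM outside of llmtuner; the checkpoint name, LoRA hyperparameters, and the helper cast_trainable_params are illustrative stand-ins, and only the conditional dtype expression mirrors the diff.

    # Sketch only: mirrors the dtype cast from the patch, not the full init_adapter logic.
    import torch
    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM


    def cast_trainable_params(model: torch.nn.Module, lora_bf16_mode: bool = False) -> None:
        # Only trainable (LoRA) parameters are touched; frozen base weights keep
        # whatever dtype they were loaded in.
        target_dtype = torch.bfloat16 if lora_bf16_mode else torch.float32
        for param in filter(lambda p: p.requires_grad, model.parameters()):
            param.data = param.data.to(target_dtype)


    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
        torch_dtype=torch.bfloat16,
    )
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],  # illustrative targets
    )
    model = get_peft_model(model, lora_config)
    cast_trainable_params(model, lora_bf16_mode=True)  # adapters stay in bf16 instead of fp32

Keeping the adapters in bf16 stores their weights and gradients at half the width of the default fp32 upcast, trading some numerical precision in the updates for lower memory use; leaving the flag at its default preserves the previous fp32 behavior.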