parent
28d5de7e78
commit
9ce1b0e2f2
|
@ -2,7 +2,7 @@ torch>=1.13.1
|
||||||
transformers>=4.31.0,<4.35.0
|
transformers>=4.31.0,<4.35.0
|
||||||
datasets>=2.14.3
|
datasets>=2.14.3
|
||||||
accelerate>=0.21.0
|
accelerate>=0.21.0
|
||||||
peft==0.6.0
|
peft>=0.7.0
|
||||||
trl>=0.7.4
|
trl>=0.7.4
|
||||||
gradio>=3.38.0,<4.0.0
|
gradio>=3.38.0,<4.0.0
|
||||||
scipy
|
scipy
|
||||||
|
|
|
@ -102,6 +102,9 @@ def init_adapter(
|
||||||
)
|
)
|
||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
|
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
if model_args.checkpoint_dir is not None:
|
if model_args.checkpoint_dir is not None:
|
||||||
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
|
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ logger = get_logger(__name__)
|
||||||
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
|
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
|
||||||
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||||
require_version("peft==0.6.0", "To fix: pip install peft==0.6.0")
|
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
|
||||||
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
|
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue