From 3e000c2b60c2e29bcafcf8d39c1a5d567ae2491c Mon Sep 17 00:00:00 2001
From: jiongxuc
Date: Thu, 10 Aug 2023 14:57:12 +0800
Subject: [PATCH] Log in to the Hugging Face Hub at startup when an auth token
 is provided

---
 requirements.txt                   | 1 +
 src/llmtuner/hparams/model_args.py | 9 +++++++++
 2 files changed, 10 insertions(+)

diff --git a/requirements.txt b/requirements.txt
index 9b74b21d..c71b6c9c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,4 @@ pydantic==1.10.11
 fastapi==0.95.1
 sse-starlette
 matplotlib
+huggingface_hub
diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index 253d9839..94e882fe 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -1,6 +1,7 @@
 import torch
 from typing import Literal, Optional
 from dataclasses import dataclass, field
+from huggingface_hub.hf_api import HfFolder
 
 
 @dataclass
@@ -63,6 +64,11 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
     )
+    hf_hub_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token used to log in to the Hugging Face Hub."}
+    )
+
 
     def __post_init__(self):
         if self.checkpoint_dir is not None: # support merging multiple lora weights
@@ -70,3 +76,6 @@ class ModelArguments:
 
         if self.quantization_bit is not None:
             assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+
+        if self.use_auth_token and self.hf_hub_token is not None:
+            HfFolder.save_token(self.hf_hub_token)