Update finetuning_args.py

hoshi-hiyouga 2024-04-16 17:26:30 +08:00 committed by GitHub
parent 57dcd91e17
commit ec899cccf3
1 changed file with 48 additions and 42 deletions


@@ -163,47 +163,6 @@ class RLHFArguments:
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
)
@dataclass
class BAdamArgument:
r"""
Arguments for BAdam optimizer.
"""
use_badam: bool = field(
default=False,
metadata={"help": "Whether or not to use BAdam optimizer."},
)
badam_mode: Literal["layer", "ratio"] = field(
default="layer",
metadata={"help": "The mode of BAdam optimizer. 'layer' for layer-wise, 'ratio' for ratio-wise."},
)
# ======== Arguments for layer-wise update ========
start_block: Optional[int] = field(
default=None,
metadata={"help": "The starting block index for block-wise fine-tuning."}
)
switch_block_every: Optional[int] = field(
default=50,
metadata={"help": "how often to switch model's block update. Set to -1 to disable the block update."}
)
switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "the strategy of picking block to update."}
)
# ======== Arguments for ratio-wise update ========
badam_update_ratio: float = field(
default=0.,
metadata={"help": "The ratio of the update for the BAdam optimizer."}
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
default="adjacent",
metadata={"help": "The mode of the mask for BAdam optimizer. `adjacent` means that the trainable parameters are adjacent to each other; `scatter` means that trainable parameters are randomly choosed from the weight."}
)
badam_verbose: int = field(
default=0,
metadata={"help": "The verbosity level of BAdam optimizer. 0 for no print, 1 for print the block prefix, 2 for print trainable parameters"}
)
@dataclass
class GaloreArguments:
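This commit moves the BAdamArgument dataclass below GaloreArguments and prefixes its layer-wise options with badam_. For anyone migrating an existing config, the renames between the removed block above and the re-added block further down boil down to a small mapping; the snippet below is an editorial aid in Python, not part of the diff.
# Old field name -> new field name; defaults are unchanged, and use_badam,
# badam_mode, badam_update_ratio, badam_mask_mode and badam_verbose keep their names.
BADAM_RENAMES = {
    "start_block": "badam_start_block",
    "switch_block_every": "badam_switch_block_every",
    "switch_mode": "badam_switch_mode",
}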
@@ -213,7 +172,7 @@ class GaloreArguments:
use_galore: bool = field(
default=False,
metadata={"help": "Whether or not to use gradient low-Rank projection."},
metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
)
galore_target: str = field(
default="all",
@@ -244,6 +203,53 @@ class GaloreArguments:
)
@dataclass
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
use_badam: bool = field(
default=False,
metadata={"help": "Whether or not to use the BAdam optimizer."},
)
badam_mode: Literal["layer", "ratio"] = field(
default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
)
badam_start_block: Optional[int] = field(
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_block_every: Optional[int] = field(
default=50,
metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_update_ratio: float = field(
default=0.0,
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
default="adjacent",
metadata={
"help": """The mode of the mask for BAdam optimizer. \
`adjacent` means that the trainable parameters are adjacent to each other, \
`scatter` means that trainable parameters are randomly choosed from the weight."""
},
)
badam_verbose: int = field(
default=0,
metadata={
"help": """The verbosity level of BAdam optimizer. \
0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
},
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
r"""