forked from p04798526/LLaMA-Factory-Mirror
Update finetuning_args.py

This commit moves the BAdamArgument dataclass below GaloreArguments, renames its layer-wise fields with a badam_ prefix (start_block -> badam_start_block, switch_block_every -> badam_switch_block_every, switch_mode -> badam_switch_mode), and touches up the help strings, including the use_galore description.

parent 57dcd91e17
commit ec899cccf3
@@ -163,47 +163,6 @@ class RLHFArguments:
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
 
-
-@dataclass
-class BAdamArgument:
-    r"""
-    Arguments for BAdam optimizer.
-    """
-
-    use_badam: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use BAdam optimizer."},
-    )
-    badam_mode: Literal["layer", "ratio"] = field(
-        default="layer",
-        metadata={"help": "The mode of BAdam optimizer. 'layer' for layer-wise, 'ratio' for ratio-wise."},
-    )
-
-    # ======== Arguments for layer-wise update ========
-    start_block: Optional[int] = field(
-        default=None,
-        metadata={"help": "The starting block index for block-wise fine-tuning."}
-    )
-    switch_block_every: Optional[int] = field(
-        default=50,
-        metadata={"help": "how often to switch model's block update. Set to -1 to disable the block update."}
-    )
-    switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
-        default="ascending",
-        metadata={"help": "the strategy of picking block to update."}
-    )
-
-    # ======== Arguments for ratio-wise update ========
-    badam_update_ratio: float = field(
-        default=0.,
-        metadata={"help": "The ratio of the update for the BAdam optimizer."}
-    )
-    badam_mask_mode: Literal["adjacent", "scatter"] = field(
-        default="adjacent",
-        metadata={"help": "The mode of the mask for BAdam optimizer. `adjacent` means that the trainable parameters are adjacent to each other; `scatter` means that trainable parameters are randomly choosed from the weight."}
-    )
-    badam_verbose: int = field(
-        default=0,
-        metadata={"help": "The verbosity level of BAdam optimizer. 0 for no print, 1 for print the block prefix, 2 for print trainable parameters"}
-    )
-
 
 @dataclass
 class GaloreArguments:
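Aside (not part of the commit): the switch-mode field removed here, re-added below as badam_switch_mode, controls the order in which BAdam makes blocks trainable. A conceptual Python sketch of the four strategies, under assumed semantics rather than the BAdam package's actual implementation:

import random
from typing import Iterator, List


def block_order(num_blocks: int, mode: str, start: int = 0) -> Iterator[int]:
    """Conceptual sketch of the switch_mode strategies; not the BAdam source."""
    blocks: List[int] = list(range(num_blocks))
    if mode == "ascending":      # walk blocks bottom-up, wrapping around
        while True:
            for i in blocks[start:] + blocks[:start]:
                yield i
    elif mode == "descending":   # walk blocks top-down
        while True:
            for i in reversed(blocks):
                yield i
    elif mode == "random":       # reshuffle after every full pass
        while True:
            random.shuffle(blocks)
            for i in blocks:
                yield i
    elif mode == "fixed":        # assumed: never leave the starting block
        while True:
            yield blocks[start]
    else:
        raise ValueError(f"unknown mode: {mode}")


# Every `switch_block_every` optimizer steps the trainer would advance this
# iterator and mark only the returned block's parameters as trainable.
order = block_order(num_blocks=4, mode="ascending", start=2)
print([next(order) for _ in range(6)])  # [2, 3, 0, 1, 2, 3]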
@@ -213,7 +172,7 @@ class GaloreArguments:
     use_galore: bool = field(
         default=False,
-        metadata={"help": "Whether or not to use gradient low-Rank projection."},
+        metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
     )
     galore_target: str = field(
         default="all",
@@ -244,6 +203,53 @@ class GaloreArguments:
     )
 
 
+@dataclass
+class BAdamArgument:
+    r"""
+    Arguments pertaining to the BAdam optimizer.
+    """
+
+    use_badam: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the BAdam optimizer."},
+    )
+    badam_mode: Literal["layer", "ratio"] = field(
+        default="layer",
+        metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
+    )
+    badam_start_block: Optional[int] = field(
+        default=None,
+        metadata={"help": "The starting block index for layer-wise BAdam."},
+    )
+    badam_switch_block_every: Optional[int] = field(
+        default=50,
+        metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
+    )
+    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+        default="ascending",
+        metadata={"help": "The strategy of picking block to update for layer-wise BAdam."},
+    )
+    badam_update_ratio: float = field(
+        default=0.0,
+        metadata={"help": "The ratio of the update for ratio-wise BAdam."},
+    )
+    badam_mask_mode: Literal["adjacent", "scatter"] = field(
+        default="adjacent",
+        metadata={
+            "help": """The mode of the mask for BAdam optimizer. \
+            `adjacent` means that the trainable parameters are adjacent to each other, \
+            `scatter` means that trainable parameters are randomly chosen from the weight."""
+        },
+    )
+    badam_verbose: int = field(
+        default=0,
+        metadata={
+            "help": """The verbosity level of BAdam optimizer. \
+            0 for no print, 1 for print the block prefix, 2 for print trainable parameters."""
+        },
+    )
+
+
 @dataclass
 class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
     r"""
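For context (not part of the commit): LLaMA-Factory builds these dataclasses into command-line arguments via transformers.HfArgumentParser, so the badam_ prefix introduced here is the user-visible flag rename. A minimal, self-contained sketch with a trimmed copy of the class for illustration; the real class is the one defined in the diff above:

from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


# Trimmed, illustrative copy of the dataclass added by this commit.
@dataclass
class BAdamArgument:
    use_badam: bool = field(default=False)
    badam_mode: Literal["layer", "ratio"] = field(default="layer")
    badam_start_block: Optional[int] = field(default=None)
    badam_switch_block_every: Optional[int] = field(default=50)
    badam_update_ratio: float = field(default=0.0)


# HfArgumentParser turns each field name into a --flag of the same name.
parser = HfArgumentParser(BAdamArgument)
(badam_args,) = parser.parse_args_into_dataclasses(
    args=["--use_badam", "true", "--badam_mode", "ratio", "--badam_update_ratio", "0.05"]
)
print(badam_args.badam_mode, badam_args.badam_update_ratio)  # ratio 0.05

Under that mechanism the rename is visible on the command line: what was --switch_block_every 100 before this commit becomes --badam_switch_block_every 100 after it.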