hiyouga 2024-08-05 23:48:19 +08:00
parent c2921b9960
commit b7ca6c8dc1
13 changed files with 111 additions and 69 deletions

View File

@@ -300,20 +300,20 @@ huggingface-cli login

 | Mandatory    | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.11      |
-| torch        | 1.13.1  | 2.3.0     |
-| transformers | 4.41.2  | 4.41.2    |
-| datasets     | 2.16.0  | 2.19.2    |
-| accelerate   | 0.30.1  | 0.30.1    |
-| peft         | 0.11.1  | 0.11.1    |
-| trl          | 0.8.6   | 0.9.4     |
+| torch        | 1.13.1  | 2.4.0     |
+| transformers | 4.41.2  | 4.43.4    |
+| datasets     | 2.16.0  | 2.20.0    |
+| accelerate   | 0.30.1  | 0.32.0    |
+| peft         | 0.11.1  | 0.12.0    |
+| trl          | 0.8.6   | 0.9.6     |

 | Optional     | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | CUDA         | 11.6    | 12.2      |
 | deepspeed    | 0.10.0  | 0.14.0    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
-| vllm         | 0.4.3   | 0.4.3     |
-| flash-attn   | 2.3.0   | 2.5.9     |
+| vllm         | 0.4.3   | 0.5.0     |
+| flash-attn   | 2.3.0   | 2.6.3     |

 ### Hardware Requirement

View File

@@ -166,8 +166,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [Llama 2](https://huggingface.co/meta-llama)            | 7B/13B/70B    | llama2  |
 | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama)  | 8B/70B        | llama3  |
 | [LLaVA-1.5](https://huggingface.co/llava-hf)            | 7B/13B        | vicuna  |
-| [Mistral/Mixtral](https://huggingface.co/mistralai)     | 7B/8x7B/8x22B | mistral |
 | [MiniCPM](https://huggingface.co/openbmb)               | 1B/2B         | cpm     |
+| [Mistral/Mixtral](https://huggingface.co/mistralai)     | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai)                  | 1B/7B         | -       |
 | [PaliGemma](https://huggingface.co/google)              | 3B            | gemma   |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft)       | 1.3B/2.7B     | -       |
@@ -300,20 +300,20 @@ huggingface-cli login

 | 必需项       | 至少    | 推荐      |
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.11      |
-| torch        | 1.13.1  | 2.3.0     |
-| transformers | 4.41.2  | 4.41.2    |
-| datasets     | 2.16.0  | 2.19.2    |
-| accelerate   | 0.30.1  | 0.30.1    |
-| peft         | 0.11.1  | 0.11.1    |
-| trl          | 0.8.6   | 0.9.4     |
+| torch        | 1.13.1  | 2.4.0     |
+| transformers | 4.41.2  | 4.43.4    |
+| datasets     | 2.16.0  | 2.20.0    |
+| accelerate   | 0.30.1  | 0.32.0    |
+| peft         | 0.11.1  | 0.12.0    |
+| trl          | 0.8.6   | 0.9.6     |

 | 可选项       | 至少    | 推荐      |
 | ------------ | ------- | --------- |
 | CUDA         | 11.6    | 12.2      |
 | deepspeed    | 0.10.0  | 0.14.0    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
-| vllm         | 0.4.3   | 0.4.3     |
-| flash-attn   | 2.3.0   | 2.5.9     |
+| vllm         | 0.4.3   | 0.5.0     |
+| flash-attn   | 2.3.0   | 2.6.3     |

 ### 硬件依赖

View File

@@ -1,8 +1,8 @@
-transformers>=4.41.2
-datasets>=2.16.0
-accelerate>=0.30.1
-peft>=0.11.1
-trl>=0.8.6
+transformers>=4.41.2,<=4.43.4
+datasets>=2.16.0,<=2.20.0
+accelerate>=0.30.1,<=0.32.0
+peft>=0.11.1,<=0.12.0
+trl>=0.8.6,<=0.9.6
 gradio>=4.0.0
 pandas>=2.0.0
 scipy
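
Note: a quick way to confirm that a local environment falls inside the newly pinned ranges is sketched below. The snippet is illustrative only (it is not part of this commit) and assumes the `packaging` library is installed.

```python
# Illustrative check (not repository code): compare installed versions against the
# pinned ranges introduced in requirements.txt above.
from importlib.metadata import PackageNotFoundError, version as installed_version

from packaging.specifiers import SpecifierSet

PINS = {
    "transformers": ">=4.41.2,<=4.43.4",
    "datasets": ">=2.16.0,<=2.20.0",
    "accelerate": ">=0.30.1,<=0.32.0",
    "peft": ">=0.11.1,<=0.12.0",
    "trl": ">=0.8.6,<=0.9.6",
}

for name, spec in PINS.items():
    try:
        current = installed_version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed")
        continue
    status = "ok" if current in SpecifierSet(spec) else "outside pinned range"
    print(f"{name}=={current}: {status} ({spec})")
```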

View File

@@ -20,19 +20,17 @@ Level:

 Dependency graph:
   main:
-    transformers>=4.41.2
-    datasets>=2.16.0
-    accelerate>=0.30.1
-    peft>=0.11.1
-    trl>=0.8.6
+    transformers>=4.41.2,<=4.43.4
+    datasets>=2.16.0,<=2.20.0
+    accelerate>=0.30.1,<=0.32.0
+    peft>=0.11.1,<=0.12.0
+    trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.42.4
+    transformers>=4.41.2,<=4.43.4
   packing:
-    transformers>=4.41.2,<=4.42.4
-  patcher:
-    transformers==4.41.2 (chatglm)
+    transformers>=4.41.2,<=4.43.4
 """

 from .cli import VERSION

View File

@@ -535,10 +535,6 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-2-2b",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b",
         },
-        "Gemma-2-2B-Chat": {
-            DownloadSource.DEFAULT: "google/gemma-2-2b-it",
-            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
-        },
         "Gemma-2-9B": {
             DownloadSource.DEFAULT: "google/gemma-2-9b",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
@@ -547,6 +543,10 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-2-27b",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
         },
+        "Gemma-2-2B-Chat": {
+            DownloadSource.DEFAULT: "google/gemma-2-2b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
+        },
         "Gemma-2-9B-Chat": {
             DownloadSource.DEFAULT: "google/gemma-2-9b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",

View File

@@ -79,11 +79,11 @@ def check_dependencies() -> None:
     if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
-        require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
-        require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
-        require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
-        require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
+        require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+        require_version("datasets>=2.16.0,<=2.20.0", "To fix: pip install datasets>=2.16.0,<=2.20.0")
+        require_version("accelerate>=0.30.1,<=0.32.0", "To fix: pip install accelerate>=0.30.1,<=0.32.0")
+        require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
+        require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")


 def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
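
Note: `require_version` from `transformers.utils.versions` accepts comma-separated specifiers, so each call above enforces the lower and the upper bound in one check. A hedged usage sketch (not repository code) of the two behaviors a user can expect:

```python
# Illustrative sketch (not repository code) of how the version gate above behaves.
import os

from transformers.utils.versions import require_version

# Raises an error carrying the "To fix: ..." hint if the installed peft falls outside the range.
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")

# Setting this before the check runs (e.g. in the shell environment) skips it entirely,
# at the cost of the "may lead to unexpected behaviors" warning.
os.environ["DISABLE_VERSION_CHECK"] = "1"
```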

View File

@@ -70,6 +70,11 @@ def is_starlette_available():
     return _is_package_available("sse_starlette")


+@lru_cache
+def is_transformers_version_greater_than_4_43():
+    return _get_package_version("transformers") >= version.parse("4.43.0")
+
+
 def is_uvicorn_available():
     return _is_package_available("uvicorn")

View File

@@ -36,7 +36,7 @@ def configure_attn_implementation(
         if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
             if is_flash_attn_2_available():
                 require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
-                require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
+                require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
                 logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                 model_args.flash_attn = "fa2"
             else:

View File

@@ -35,6 +35,7 @@ from transformers.utils.versions import require_version

 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_greater_than_4_43


 if TYPE_CHECKING:
@@ -50,14 +51,15 @@ transformers_logger = logging.get_logger(__name__)
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_attention_forward(
     self: "LlamaAttention",
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    hidden_states: "torch.Tensor",
+    attention_mask: Optional["torch.Tensor"] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     bsz, q_len, _ = hidden_states.size()

     query_states: "torch.Tensor" = self.q_proj(hidden_states)
@@ -68,7 +70,11 @@ def llama_attention_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

     if past_key_value is not None:
@@ -130,14 +136,15 @@ def llama_attention_forward(
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_flash_attention_2_forward(
     self: "LlamaFlashAttention2",
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    hidden_states: "torch.Tensor",
+    attention_mask: Optional["torch.Tensor"] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     # LlamaFlashAttention2 attention does not support output_attentions
     output_attentions = False
@@ -151,7 +158,11 @@ def llama_flash_attention_2_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

     if past_key_value is not None:
@@ -198,9 +209,24 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    attn_output: "torch.Tensor" = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
-    )
+    if is_transformers_version_greater_than_4_43():
+        from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+        attn_output: "torch.Tensor" = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            query_states.size(1),
+            dropout=dropout_rate,
+            sliding_window=getattr(self, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+    else:
+        attn_output: "torch.Tensor" = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
+        )

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
         attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
@@ -225,14 +251,15 @@ def llama_flash_attention_2_forward(
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_sdpa_attention_forward(
     self: "LlamaSdpaAttention",
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    hidden_states: "torch.Tensor",
+    attention_mask: Optional["torch.Tensor"] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     if output_attentions:
         transformers_logger.warning_once(
             "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@@ -258,7 +285,11 @@ def llama_sdpa_attention_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

     if past_key_value is not None:
@@ -322,7 +353,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
+    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
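
Note: the recurring `position_embeddings is None` fallback exists because transformers releases around 4.43 compute the rotary cos/sin once at the model level and pass them into every attention layer, while older releases compute them inside each layer via `self.rotary_emb`; the patched forwards accept both calling conventions. A minimal sketch of that pattern (function and argument names here are illustrative, not from the repo):

```python
# Illustrative helper mirroring the dual calling convention handled by the patched forwards.
from typing import Callable, Optional, Tuple

import torch


def resolve_rotary(
    rotary_emb: Callable[..., Tuple[torch.Tensor, torch.Tensor]],
    value_states: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if position_embeddings is None:  # older transformers: compute cos/sin inside the layer
        return rotary_emb(value_states, position_ids)
    return position_embeddings  # transformers around 4.43+: cos/sin precomputed by the model
```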

View File

@@ -41,11 +41,11 @@ from typing import TYPE_CHECKING, Tuple

 import torch
 import torch.nn.functional as F
-import transformers.models
 from transformers.utils.versions import require_version

 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_greater_than_4_43


 if TYPE_CHECKING:
@@ -114,7 +114,15 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
+    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+    if is_transformers_version_greater_than_4_43():
+        import transformers.modeling_flash_attention_utils
+
+        transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
+        return
+
+    import transformers.models
+
     if model_type == "cohere":
         transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
     elif model_type == "falcon":

View File

@@ -162,11 +162,12 @@ class PissaConvertCallback(TrainerCallback):
             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
             model.save_pretrained(
                 pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
-            )
+            )  # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
             model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
             model.set_adapter("default")
-            if "pissa_init" in model.peft_config.keys():
+            if "pissa_init" in model.peft_config.keys():  # backward compatibility (peft<0.12.0)
                 model.delete_adapter("pissa_init")
             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
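
Note: the TODO refers to peft 0.12.0, where `convert_pissa_to_lora` is deprecated in favor of `path_initial_model_for_weight_conversion` on `save_pretrained`. A hedged sketch of the newer call (directory names are illustrative; `model` must be a PiSSA-initialized PeftModel):

```python
# Hedged sketch of the peft>=0.12.0 form referenced by the TODO above; not repository code.
from peft import PeftModel


def convert_pissa_adapter(model: "PeftModel", convert_dir: str, init_dir: str) -> None:
    # Replaces the convert_pissa_to_lora=init_dir keyword used with peft<=0.11.x.
    model.save_pretrained(
        convert_dir,
        safe_serialization=True,
        path_initial_model_for_weight_conversion=init_dir,
    )
```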

View File

@@ -71,7 +71,7 @@ def create_web_demo() -> "gr.Blocks":
     engine = Engine(pure_chat=True)

     with gr.Blocks(title="Web Demo", css=CSS) as demo:
-        lang = gr.Dropdown(choices=["en", "zh"])
+        lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1)
         engine.manager.add_elems("top", dict(lang=lang))

         _, _, chat_elems = create_chat_box(engine, visible=True)
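
Note: a standalone illustration of the enlarged language selector is sketched below (illustrative only; the real demo registers the dropdown with `engine.manager` and localizes it elsewhere):

```python
# Standalone sketch of the expanded language dropdown; not the project's actual wiring.
import gradio as gr

with gr.Blocks(title="Web Demo") as demo:
    lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], value="en", scale=1, label="Language")

if __name__ == "__main__":
    demo.launch()
```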

View File

@@ -362,7 +362,6 @@ LOCALES = {
             "label": "학습률",
             "info": "AdamW의 초기 학습률.",
         },
-        },
     },
     "num_train_epochs": {
         "en": {