fix lint
parent 947a34f53b
commit 713fde4259
@@ -448,7 +448,7 @@ docker run -it --gpus=all \
 ```bash
 docker-compose up -d
-docker-compose exec -it llamafactory bash
+docker-compose exec llamafactory bash
 ```

 <details><summary>Details about volume</summary>
@@ -448,7 +448,7 @@ docker run -it --gpus=all \
 ```bash
 docker-compose up -d
-docker-compose exec -it llamafactory bash
+docker-compose exec llamafactory bash
 ```

 <details><summary>数据卷详情</summary>
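Both hunks above apply the same change to the English and Chinese READMEs: the `-it` flags are dropped from `docker-compose exec`, presumably because Compose's `exec` already keeps stdin open and allocates a pseudo-TTY by default, so the explicit flags are redundant there.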
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Literal, Optional
+from typing import List, Literal, Optional


 @dataclass
@@ -319,20 +319,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
                 return [item.strip() for item in arg.split(",")]
             return arg

-        self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
-        self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
-        self.lora_alpha = self.lora_alpha or self.lora_rank * 2
-        self.lora_target = split_arg(self.lora_target)
-        self.additional_target = split_arg(self.additional_target)
-        self.galore_target = split_arg(self.galore_target)
+        self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
+        self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
+        self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
+        self.lora_target: List[str] = split_arg(self.lora_target)
+        self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
+        self.galore_target: List[str] = split_arg(self.galore_target)
         self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
+        self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]

         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

-        self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]
-
         if self.stage == "ppo" and self.reward_model is None:
             raise ValueError("`reward_model` is necessary for PPO training.")
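The hunk above adds type annotations to the post-`__init__` normalization of comma-separated CLI arguments. A minimal, self-contained sketch of that pattern follows; `DemoArguments` and its field names are illustrative stand-ins, not code from the repository.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DemoArguments:
    # Users pass comma-separated strings on the command line; downstream code
    # expects real lists, so __post_init__ normalizes them.
    lora_target: str = field(default="q_proj,v_proj")
    additional_target: Optional[str] = field(default=None)

    def __post_init__(self):
        def split_arg(arg):
            if isinstance(arg, str):
                return [item.strip() for item in arg.split(",")]
            return arg  # already a list, or None

        # The annotations mirror the lint fix above: they record the post-init
        # type of each attribute for readers and static checkers.
        self.lora_target: List[str] = split_arg(self.lora_target)
        self.additional_target: Optional[List[str]] = split_arg(self.additional_target)


args = DemoArguments(additional_target="embed_tokens, lm_head")
assert args.lora_target == ["q_proj", "v_proj"]
assert args.additional_target == ["embed_tokens", "lm_head"]
```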
@@ -1,9 +1,13 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Literal, Optional
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union

 from typing_extensions import Self


+if TYPE_CHECKING:
+    import torch
+
+
 @dataclass
 class ModelArguments:
     r"""
@@ -194,9 +198,9 @@ class ModelArguments:
     )

     def __post_init__(self):
-        self.compute_dtype = None
-        self.device_map = None
-        self.model_max_length = None
+        self.compute_dtype: Optional["torch.dtype"] = None
+        self.device_map: Optional[Union[str, Dict[str, Any]]] = None
+        self.model_max_length: Optional[int] = None

         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
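The `TYPE_CHECKING` guard added above lets the annotations name `torch.dtype` without importing torch at runtime. A minimal sketch of that pattern, assuming a hypothetical `DemoModelArguments` class rather than the real `ModelArguments`:

```python
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

if TYPE_CHECKING:
    # Imported only for static type checkers; this block never runs at runtime,
    # so the module stays importable even when torch is not installed.
    import torch


@dataclass
class DemoModelArguments:
    model_name_or_path: str = "demo-model"

    def __post_init__(self):
        # String annotations ("torch.dtype") are forward references, so they do
        # not require torch at runtime; checkers resolve them through the
        # TYPE_CHECKING import above.
        self.compute_dtype: Optional["torch.dtype"] = None
        self.device_map: Optional[Union[str, Dict[str, Any]]] = None
        self.model_max_length: Optional[int] = None


args = DemoModelArguments()
assert args.compute_dtype is None and args.device_map is None
```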