fix lint
This commit is contained in:
parent
947a34f53b
commit
713fde4259
|
@ -448,7 +448,7 @@ docker run -it --gpus=all \
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker-compose up -d
|
docker-compose up -d
|
||||||
docker-compose exec -it llamafactory bash
|
docker-compose exec llamafactory bash
|
||||||
```
|
```
|
||||||
|
|
||||||
<details><summary>Details about volume</summary>
|
<details><summary>Details about volume</summary>
|
||||||
|
|
|
@ -448,7 +448,7 @@ docker run -it --gpus=all \
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker-compose up -d
|
docker-compose up -d
|
||||||
docker-compose exec -it llamafactory bash
|
docker-compose exec llamafactory bash
|
||||||
```
|
```
|
||||||
|
|
||||||
<details><summary>数据卷详情</summary>
|
<details><summary>数据卷详情</summary>
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Literal, Optional
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -319,20 +319,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||||
return [item.strip() for item in arg.split(",")]
|
return [item.strip() for item in arg.split(",")]
|
||||||
return arg
|
return arg
|
||||||
|
|
||||||
self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
|
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
|
||||||
self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
|
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
|
||||||
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
|
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
||||||
self.lora_target = split_arg(self.lora_target)
|
self.lora_target: List[str] = split_arg(self.lora_target)
|
||||||
self.additional_target = split_arg(self.additional_target)
|
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
||||||
self.galore_target = split_arg(self.galore_target)
|
self.galore_target: List[str] = split_arg(self.galore_target)
|
||||||
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
||||||
|
self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]
|
||||||
|
|
||||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||||
|
|
||||||
self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]
|
|
||||||
|
|
||||||
if self.stage == "ppo" and self.reward_model is None:
|
if self.stage == "ppo" and self.reward_model is None:
|
||||||
raise ValueError("`reward_model` is necessary for PPO training.")
|
raise ValueError("`reward_model` is necessary for PPO training.")
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,13 @@
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from typing import Any, Dict, Literal, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
r"""
|
r"""
|
||||||
|
@ -194,9 +198,9 @@ class ModelArguments:
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.compute_dtype = None
|
self.compute_dtype: Optional["torch.dtype"] = None
|
||||||
self.device_map = None
|
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
|
||||||
self.model_max_length = None
|
self.model_max_length: Optional[int] = None
|
||||||
|
|
||||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||||
|
|
Loading…
Reference in New Issue