hiyouga 2024-06-13 00:48:44 +08:00
parent 947a34f53b
commit 713fde4259
4 changed files with 18 additions and 15 deletions

README.md

@@ -448,7 +448,7 @@ docker run -it --gpus=all \
 ```bash
 docker-compose up -d
-docker-compose exec -it llamafactory bash
+docker-compose exec llamafactory bash
 ```
 <details><summary>Details about volume</summary>
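Note: dropping `-it` is safe here because, unlike `docker exec`, `docker-compose exec` attaches an interactive TTY by default (pass `-T` to disable it), and the classic v1 `docker-compose` binary does not recognize a combined `-it` option at all.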

README_zh.md

@@ -448,7 +448,7 @@ docker run -it --gpus=all \
 ```bash
 docker-compose up -d
-docker-compose exec -it llamafactory bash
+docker-compose exec llamafactory bash
 ```
 <details><summary>数据卷详情</summary>

src/llamafactory/hparams/finetuning_args.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Literal, Optional
+from typing import List, Literal, Optional


 @dataclass
@@ -319,20 +319,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
                 return [item.strip() for item in arg.split(",")]
             return arg

-        self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
-        self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
-        self.lora_alpha = self.lora_alpha or self.lora_rank * 2
-        self.lora_target = split_arg(self.lora_target)
-        self.additional_target = split_arg(self.additional_target)
-        self.galore_target = split_arg(self.galore_target)
+        self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
+        self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
+        self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
+        self.lora_target: List[str] = split_arg(self.lora_target)
+        self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
+        self.galore_target: List[str] = split_arg(self.galore_target)
         self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
+        self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]

         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
-        self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]

         if self.stage == "ppo" and self.reward_model is None:
             raise ValueError("`reward_model` is necessary for PPO training.")
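For reference, a runnable sketch of the normalization this `__post_init__` performs. The `isinstance` guard is assumed from the two visible `return` branches, and `split_arg` is lifted to module level only to keep the sketch self-contained (in the source it is a helper nested inside `__post_init__`):

```python
from typing import List, Optional, Union

def split_arg(arg: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
    # Comma-separated strings become stripped lists; None and lists pass through.
    if isinstance(arg, str):  # guard assumed from the two return branches
        return [item.strip() for item in arg.split(",")]
    return arg

assert split_arg("q_proj, v_proj") == ["q_proj", "v_proj"]
assert split_arg(None) is None  # hence the Optional[List[str]] annotations

# The lora_alpha fallback: an unset (None) alpha defaults to twice the rank.
lora_rank, lora_alpha = 8, None
lora_alpha = lora_alpha or lora_rank * 2
assert lora_alpha == 16
```

The new annotations encode exactly this behavior: fields that may legitimately stay unset (`freeze_extra_modules`, `additional_target`) become `Optional[List[str]]`, while the always-populated targets are plain `List[str]`. Moving `use_ref_model` above the asserts groups it with the other derived attributes without changing its value; ORPO and SimPO are reference-free preference losses, so no reference model is needed for them.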

src/llamafactory/hparams/model_args.py

@@ -1,9 +1,13 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Literal, Optional
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union

+from typing_extensions import Self
+
+if TYPE_CHECKING:
+    import torch


 @dataclass
 class ModelArguments:
     r"""
@@ -194,9 +198,9 @@ class ModelArguments:
     )

     def __post_init__(self):
-        self.compute_dtype = None
-        self.device_map = None
-        self.model_max_length = None
+        self.compute_dtype: Optional["torch.dtype"] = None
+        self.device_map: Optional[Union[str, Dict[str, Any]]] = None
+        self.model_max_length: Optional[int] = None

         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
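The import hunk is what makes these annotations safe: `torch` is imported only under `typing.TYPE_CHECKING`, so type checkers see `torch.dtype` while runtime never pays for the import, and the annotation is quoted because annotations on attribute targets such as `self.compute_dtype` (unlike local variables) are evaluated at runtime. A minimal sketch of the pattern, with a hypothetical `Example` class standing in for `ModelArguments`:

```python
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

if TYPE_CHECKING:
    import torch  # seen by type checkers only; no runtime torch dependency


@dataclass
class Example:  # hypothetical stand-in for ModelArguments
    def __post_init__(self) -> None:
        # Quoted "torch.dtype" evaluates to a plain string at runtime, so
        # this line works even though torch was never actually imported.
        self.compute_dtype: Optional["torch.dtype"] = None
        self.device_map: Optional[Union[str, Dict[str, Any]]] = None

ex = Example()  # __post_init__ runs; no torch import is triggered
assert ex.compute_dtype is None
```

An unquoted `torch.dtype` here would raise `NameError` the moment `__post_init__` ran without torch installed, which is why the commit pairs the `TYPE_CHECKING` import with string annotations rather than importing torch outright.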