fix initializing data arguments
parent 2e01abfda5
commit 18f87c1b25
README.md | 27 ++++++++++
@@ -202,6 +202,33 @@ accelerate config # configure the environment
 accelerate launch src/train_XX.py # arguments (same as above)
 ```
 
+<details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary>
+
+```yaml
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+  gradient_accumulation_steps: 4
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: false
+  zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+</details>
+
 ### Evaluation (BLEU and ROUGE_CHINESE)
 
 ```bash
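Note on the added snippet: this YAML has the same layout that the interactive `accelerate config` step writes out. Assuming it is saved to a file (say, a hypothetical `ds_zero2_config.yaml`), it should also be possible to skip the interactive step and launch with `accelerate launch --config_file ds_zero2_config.yaml src/train_XX.py ...`.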
@@ -103,11 +103,10 @@ def _init_adapter(
     lastest_checkpoint = None
 
     if model_args.checkpoint_dir is not None:
-        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)):
-            raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]))
-        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
-            raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
-                please specify `--finetuning_type full/freeze` instead.")
+        assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
+            "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
+        assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
+            "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
 
         if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
             checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
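The hunk above swaps the explicit `ValueError` checks for `assert` statements with the same messages, so a missing adapter file now surfaces as an `AssertionError` (and note that asserts are skipped entirely when Python runs with `-O`). A minimal standalone sketch of the new check, assuming `WEIGHTS_NAME`/`CONFIG_NAME` carry the usual PEFT adapter file names:

```python
# Minimal sketch of the assert-style check introduced above (not the repo's code).
# WEIGHTS_NAME / CONFIG_NAME values are assumptions based on common PEFT defaults.
import os

WEIGHTS_NAME = "adapter_model.bin"    # assumed
CONFIG_NAME = "adapter_config.json"   # assumed

def check_lora_checkpoint(checkpoint_dir: str) -> None:
    assert os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_NAME)), \
        "Provided path ({}) does not contain a LoRA weight.".format(checkpoint_dir)
    assert os.path.exists(os.path.join(checkpoint_dir, CONFIG_NAME)), \
        "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

# check_lora_checkpoint("path/to/lora_checkpoint")  # raises AssertionError if files are missing
```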
@@ -267,6 +266,8 @@ def prepare_args(
     transformers.utils.logging.enable_explicit_format()
 
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
+    data_args.init_for_training()
+
     if stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
 
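This is the change the commit title refers to: the dataset bookkeeping that used to run in `__post_init__` (see the last hunk below) is now triggered explicitly via `data_args.init_for_training()`, presumably so that merely constructing the arguments object no longer requires `dataset_info.json` to exist. A rough sketch of the difference, using a hypothetical stand-in class rather than the repo's real dataclass:

```python
# Hypothetical stand-in illustrating why the explicit call matters: constructing the
# arguments object no longer touches dataset_info.json; only init_for_training() does.
import json
import os
from dataclasses import dataclass

@dataclass
class DataArgsSketch:
    dataset: str = "alpaca_zh"
    dataset_dir: str = "data"

    def init_for_training(self):
        names = [ds.strip() for ds in self.dataset.split(",")]
        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
            info = json.load(f)
        return names, info

args = DataArgsSketch()                    # cheap, works even without dataset_info.json
# names, info = args.init_for_training()  # only the training path pays this cost
```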
@@ -134,7 +134,7 @@ class DataTrainingArguments:
     )
     source_prefix: Optional[str] = field(
         default=None,
-        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes."}
+        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
     )
     dev_ratio: Optional[float] = field(
         default=0,
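The diff only touches the help text; how the `|`-separated prefixes are consumed is not shown here, but the training code presumably splits the string along these lines (hypothetical sketch):

```python
# Hypothetical sketch of unpacking a multi-prefix source_prefix value; the actual
# consumption logic lives outside this diff.
source_prefix = "Translate to English: |Summarize: "
prefixes = source_prefix.split("|") if source_prefix else [""]
print(prefixes)  # ['Translate to English: ', 'Summarize: ']
```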
@@ -145,7 +145,7 @@ class DataTrainingArguments:
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self): # support mixing multiple datasets
|
def init_for_training(self): # support mixing multiple datasets
|
||||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
|
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
|
||||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||||
dataset_info = json.load(f)
|
dataset_info = json.load(f)
|
||||||
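For context, the renamed `init_for_training` resolves each comma-separated dataset name against `dataset_info.json`. A made-up, minimal example of that lookup (the repo's real file has additional fields per entry):

```python
# Made-up dataset_info.json content, only to illustrate the name -> entry resolution
# performed by init_for_training; the repo's real schema has additional fields.
import json

dataset_info = json.loads("""
{
  "alpaca_zh": {"file_name": "alpaca_data_zh.json"},
  "my_data": {"file_name": "my_data.json"}
}
""")

dataset = "alpaca_zh, my_data"   # e.g. the --dataset argument
dataset_names = [ds.strip() for ds in dataset.split(",")]
print([dataset_info[name]["file_name"] for name in dataset_names])
# ['alpaca_data_zh.json', 'my_data.json']
```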