fix initializing data arguments

This commit is contained in:
hiyouga 2023-06-27 22:50:23 +08:00
parent 2e01abfda5
commit 18f87c1b25
3 changed files with 35 additions and 7 deletions

View File

@ -202,6 +202,33 @@ accelerate config # configure the environment
accelerate launch src/train_XX.py # arguments (same as above) accelerate launch src/train_XX.py # arguments (same as above)
``` ```
<details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary>
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
gradient_accumulation_steps: 4
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</details>
### Evaluation (BLEU and ROUGE_CHINESE) ### Evaluation (BLEU and ROUGE_CHINESE)
```bash ```bash

View File

@ -103,11 +103,10 @@ def _init_adapter(
lastest_checkpoint = None lastest_checkpoint = None
if model_args.checkpoint_dir is not None: if model_args.checkpoint_dir is not None:
if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)): assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])) "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)): assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
raise ValueError("The given checkpoint may be not a LoRA checkpoint, \ "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
please specify `--finetuning_type full/freeze` instead.")
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
@ -267,6 +266,8 @@ def prepare_args(
transformers.utils.logging.enable_explicit_format() transformers.utils.logging.enable_explicit_format()
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints) # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
data_args.init_for_training()
if stage != "sft" and training_args.predict_with_generate: if stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.") raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")

View File

@ -134,7 +134,7 @@ class DataTrainingArguments:
) )
source_prefix: Optional[str] = field( source_prefix: Optional[str] = field(
default=None, default=None,
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes."} metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
) )
dev_ratio: Optional[float] = field( dev_ratio: Optional[float] = field(
default=0, default=0,
@ -145,7 +145,7 @@ class DataTrainingArguments:
metadata={"help": "Which template to use for constructing prompts in training and inference."} metadata={"help": "Which template to use for constructing prompts in training and inference."}
) )
def __post_init__(self): # support mixing multiple datasets def init_for_training(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")] dataset_names = [ds.strip() for ds in self.dataset.split(",")]
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f) dataset_info = json.load(f)