diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 09ed25d3..a153e440 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -57,7 +57,15 @@ class FinetuningArguments:
     )
     ppo_score_norm: Optional[bool] = field(
         default=False,
-        metadata={"help": "Use score normalization in PPO Training."}
+        metadata={"help": "Use score normalization in PPO training."}
+    )
+    ppo_logger: Optional[str] = field(
+        default=None,
+        metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
+    )
+    ppo_target: Optional[float] = field(
+        default=6.0,
+        metadata={"help": "Target KL value for adaptive KL control in PPO training."}
     )
     dpo_beta: Optional[float] = field(
         default=0.1,
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index a8ed914a..118138c2 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -201,7 +201,9 @@ def get_train_args(
     )
 
     # postprocess model_args
-    model_args.compute_dtype = torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
+    model_args.compute_dtype = (
+        torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else torch.float32)
+    )
     model_args.model_max_length = data_args.cutoff_len
 
     # Log on each process the small summary:
diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index 85e48279..44df9672 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -206,7 +206,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         replace_model(unwrapped_model, target="reward")
         batch = self.prepare_model_inputs(queries, responses)
 
-        with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
+        with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
             _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
 
         if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
@@ -251,7 +251,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         input_ids = input_kwargs["input_ids"]
         attention_mask = input_kwargs["attention_mask"]
 
-        with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
+        with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
             logits, _, values = model(**input_kwargs)
 
         if values.size(0) != input_ids.size(0): # adapt to chatglm2
diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py
index 7fd2f29b..1dd3205e 100644
--- a/src/llmtuner/tuner/ppo/workflow.py
+++ b/src/llmtuner/tuner/ppo/workflow.py
@@ -42,15 +42,14 @@ def run_ppo(
         ppo_epochs=1,
         max_grad_norm=training_args.max_grad_norm,
         seed=training_args.seed,
-        log_with=training_args.report_to,
         optimize_cuda_cache=True,
+        target=finetuning_args.ppo_target,
+        log_with=finetuning_args.ppo_logger,
+        use_score_scaling=finetuning_args.ppo_score_norm,
+        use_score_norm=finetuning_args.ppo_score_norm,
         accelerator_kwargs={"step_scheduler_with_optimizer": False}
     )
 
-    if finetuning_args.ppo_score_norm:
-        ppo_config.use_score_scaling = True
-        ppo_config.use_score_norm = True
-
     optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     total_train_batch_size = (
        training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
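
Note (not part of the patch): a minimal sketch of how the new `FinetuningArguments` fields are expected to reach TRL's `PPOConfig` after this change. The literal values below stand in for whatever the argument parser fills in from `--ppo_target`, `--ppo_logger`, and `--ppo_score_norm`; the flag names and the standalone `trl` usage are assumptions about the surrounding setup, not part of this diff.

```python
# Sketch only: maps the new finetuning arguments onto TRL's PPOConfig kwargs.
# Placeholder values stand in for parsed CLI arguments (assumed flag names:
# --ppo_target, --ppo_logger, --ppo_score_norm).
from trl import PPOConfig

ppo_target = 6.0            # target KL for adaptive KL control (PPOConfig.target)
ppo_logger = "tensorboard"  # 'wandb' or 'tensorboard' (PPOConfig.log_with)
ppo_score_norm = True       # one switch drives both score scaling and score normalization

ppo_config = PPOConfig(
    target=ppo_target,
    log_with=ppo_logger,
    use_score_scaling=ppo_score_norm,
    use_score_norm=ppo_score_norm,
)
```

With the `parser.py` change, `model_args.compute_dtype` is always a concrete `torch.dtype` (`torch.float32` when neither `--bf16` nor `--fp16` is set), and the two `autocast` calls in `trainer.py` now read that dtype from `self.model_args` rather than from an attribute on the trainer itself.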