fix ppo args

parent 2818af0b09
commit 11bd271364
@@ -57,7 +57,15 @@ class FinetuningArguments:
     )
     ppo_score_norm: Optional[bool] = field(
         default=False,
-        metadata={"help": "Use score normalization in PPO Training."}
+        metadata={"help": "Use score normalization in PPO training."}
+    )
+    ppo_logger: Optional[str] = field(
+        default=None,
+        metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
+    )
+    ppo_target: Optional[float] = field(
+        default=6.0,
+        metadata={"help": "Target KL value for adaptive KL control in PPO training."}
     )
     dpo_beta: Optional[float] = field(
         default=0.1,
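Note: the three PPO fields above are plain dataclass fields, so they surface as command-line flags through transformers' HfArgumentParser. A minimal, self-contained sketch of that wiring (assumed here; this commit does not show the project's actual parser code, and PPOArguments is a hypothetical stand-in for FinetuningArguments):

    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser

    @dataclass
    class PPOArguments:  # hypothetical stand-in for FinetuningArguments
        ppo_score_norm: Optional[bool] = field(
            default=False,
            metadata={"help": "Use score normalization in PPO training."}
        )
        ppo_logger: Optional[str] = field(
            default=None,
            metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
        )
        ppo_target: Optional[float] = field(
            default=6.0,
            metadata={"help": "Target KL value for adaptive KL control in PPO training."}
        )

    parser = HfArgumentParser(PPOArguments)
    # The metadata "help" strings above become the --help text for these flags.
    (args,) = parser.parse_args_into_dataclasses(
        args=["--ppo_score_norm", "true", "--ppo_logger", "wandb", "--ppo_target", "5.0"]
    )
    assert args.ppo_target == 5.0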
@@ -201,7 +201,9 @@ def get_train_args(
     )

     # postprocess model_args
-    model_args.compute_dtype = torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
+    model_args.compute_dtype = (
+        torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else torch.float32)
+    )
     model_args.model_max_length = data_args.cutoff_len

     # Log on each process the small summary:
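Note: the rewritten expression changes the fallback from None to torch.float32, so model_args.compute_dtype is always a real torch.dtype, which the autocast calls below require. A tiny sketch of the selection logic in isolation (select_compute_dtype is a hypothetical helper, not a function from this repository):

    import torch

    def select_compute_dtype(bf16: bool, fp16: bool) -> torch.dtype:
        # bf16 wins over fp16; full precision is the new fallback instead of None.
        return torch.bfloat16 if bf16 else (torch.float16 if fp16 else torch.float32)

    assert select_compute_dtype(True, False) is torch.bfloat16
    assert select_compute_dtype(False, False) is torch.float32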
@@ -206,7 +206,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         replace_model(unwrapped_model, target="reward")
         batch = self.prepare_model_inputs(queries, responses)

-        with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
+        with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
             _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)

         if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
@@ -251,7 +251,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         input_ids = input_kwargs["input_ids"]
         attention_mask = input_kwargs["attention_mask"]

-        with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
+        with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
             logits, _, values = model(**input_kwargs)

         if values.size(0) != input_ids.size(0): # adapt to chatglm2
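Note: both trainer hunks make the same one-line change: the autocast dtype now comes from model_args.compute_dtype (set in get_train_args above) rather than a self.compute_dtype attribute. A runnable sketch of the pattern with placeholder objects (not the trainer's actual model or batch):

    import torch

    compute_dtype = torch.bfloat16  # stand-in for self.model_args.compute_dtype

    if torch.cuda.is_available():
        layer = torch.nn.Linear(4, 4).cuda()
        x = torch.randn(2, 4, device="cuda")
        with torch.cuda.amp.autocast(dtype=compute_dtype):  # support bf16
            y = layer(x)
        print(y.dtype)  # torch.bfloat16 inside the autocast region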
@@ -42,15 +42,14 @@ def run_ppo(
         ppo_epochs=1,
         max_grad_norm=training_args.max_grad_norm,
         seed=training_args.seed,
-        log_with=training_args.report_to,
         optimize_cuda_cache=True,
+        target=finetuning_args.ppo_target,
+        log_with=finetuning_args.ppo_logger,
+        use_score_scaling=finetuning_args.ppo_score_norm,
+        use_score_norm=finetuning_args.ppo_score_norm,
         accelerator_kwargs={"step_scheduler_with_optimizer": False}
     )

-    if finetuning_args.ppo_score_norm:
-        ppo_config.use_score_scaling = True
-        ppo_config.use_score_norm = True
-
     optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     total_train_batch_size = (
         training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
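Note: the workflow now passes everything to trl's PPOConfig constructor instead of mutating ppo_config after construction, and the PPO logger is decoupled from training_args.report_to. A minimal sketch with placeholder values (assumes a trl version that accepts use_score_scaling/use_score_norm):

    from trl import PPOConfig

    ppo_config = PPOConfig(
        learning_rate=5e-5,      # placeholder; taken from training_args in the workflow
        batch_size=8,            # placeholder
        ppo_epochs=1,
        target=6.0,              # finetuning_args.ppo_target
        log_with="wandb",        # finetuning_args.ppo_logger ('wandb' or 'tensorboard')
        use_score_scaling=True,  # finetuning_args.ppo_score_norm
        use_score_norm=True,     # finetuning_args.ppo_score_norm
        optimize_cuda_cache=True,
        accelerator_kwargs={"step_scheduler_with_optimizer": False},
    )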