support longlora for main branch

This commit is contained in:
hiyouga 2024-01-20 19:25:22 +08:00
parent bb92cdd0db
commit 38af076a75
7 changed files with 168 additions and 204 deletions

View File

@ -3,222 +3,166 @@ import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
try:
from transformers.models.llama.modeling_llama import repeat_kv
except ImportError:
print("Please upgrade `transformers`.")
from ..packages import is_flash_attn2_available
if is_flash_attn2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
from transformers.models.llama.modeling_llama import (
Cache, LlamaAttention, LlamaFlashAttention2, apply_rotary_pos_emb, repeat_kv
)
logger = logging.get_logger(__name__)
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
class LlamaShiftShortAttention(LlamaAttention):
def llama_torch_attn_forward(
self: "LlamaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if past_key_value is not None: # reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
past_key_value = (key_states, value_states) if use_cache else None
if getattr(self, "num_key_value_groups"):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat((
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
), dim=2)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat((
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
), dim=2)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat((
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
))
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous()
if not output_attentions:
attn_weights = None
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat((
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
))
return attn_output, attn_weights, past_key_value
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaFlashAttention2(LlamaAttention):
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_flash_attn_forward(
self: "LlamaFlashAttention2",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if past_key_value is not None: # reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
past_key_value = (key_states, value_states) if use_cache else None
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
# cast to half precision
input_dtype = query_states.dtype
if input_dtype == torch.float32:
logger.warning_once("The input hidden states seems to be silently casted in float32.")
query_states = query_states.to(self.config.torch_dtype)
key_states = key_states.to(self.config.torch_dtype)
value_states = value_states.to(self.config.torch_dtype)
dropout_rate = self.attention_dropout if self.training else 0.0
if getattr(self, "num_key_value_groups", None):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = torch.cat((
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
), dim=2)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = torch.cat((
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
), dim=2)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
logger.warning_once("Padded sequences are less efficient in FlashAttention.")
# -q_len: assumes left padding when q_len != kv_len
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
attn_output_unpad = flash_attn_varlen_func(
unpadded_q,
unpadded_k,
unpadded_v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=True,
)
attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
)
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat((
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
))
attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat((
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
))
if not output_attentions:
attn_weights = None
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Disable the transformation of the attention mask in LlamaModel as flash attention
# takes a boolean padding_mask. Fills in the past kv length for use in forward.
def _prepare_decoder_attention_mask(
self,
attention_mask: torch.Tensor,
input_shape: torch.Tensor,
inputs_embeds: torch.Tensor,
past_key_values_length: int
) -> torch.Tensor:
if attention_mask is not None and torch.all(attention_mask):
return None # This uses the faster call when training with full samples
return attention_mask
def apply_llama_patch() -> None:
LlamaAttention.forward = llama_torch_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward

View File

@ -1,4 +1,5 @@
import torch
import inspect
from typing import TYPE_CHECKING
from transformers.integrations import is_deepspeed_zero3_enabled
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
@ -108,6 +109,9 @@ def init_adapter(
if model_args.use_unsloth:
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
unsloth_peft_kwargs["loftq_config"] = {}
if getattr(model.config, "model_type", None) == "llama":
model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
elif getattr(model.config, "model_type", None) == "mistral":

View File

@ -15,6 +15,7 @@ from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
@ -23,7 +24,7 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
@ -39,26 +40,25 @@ def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToke
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(model.get_input_embeddings().weight, modifier_rank=None):
current_embedding_size = model.get_input_embeddings().weight.size(0)
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
logger.warning("Current model does not support resizing token embeddings.")
return
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
if is_deepspeed_zero3_enabled():
import deepspeed
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context = nullcontext()
with context:
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
@ -136,6 +136,7 @@ def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
def _configure_longlora(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")

View File

@ -1,4 +1,5 @@
import torch
from contextlib import nullcontext
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
@ -93,7 +94,8 @@ class CustomDPOTrainer(DPOTrainer):
all_logps = self.get_batch_logps(
all_logits,
batch["labels"],
average_log_prob=False
average_log_prob=False,
label_pad_token_id=self.label_pad_token_id,
)
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
@ -118,20 +120,19 @@ class CustomDPOTrainer(DPOTrainer):
) = self.concatenated_forward(model, batch)
with torch.no_grad():
if self.ref_model is None:
with self.accelerator.unwrap_model(self.model).disable_adapter():
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.model, batch)
ref_model = self.model
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()
with ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.ref_model, batch)
) = self.concatenated_forward(ref_model, batch)
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps,

View File

@ -95,7 +95,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
reward_model = gr.Dropdown(scale=3, allow_custom_value=True)
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
reward_model = gr.Dropdown(scale=2, allow_custom_value=True)
refresh_btn = gr.Button(scale=1)
refresh_btn.click(
@ -105,8 +106,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
queue=False
)
input_elems.update({dpo_beta, reward_model})
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
input_elems.update({dpo_beta, dpo_ftx, reward_model})
elem_dict.update(dict(
rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn
))
with gr.Row():
cmd_preview_btn = gr.Button()

View File

@ -421,6 +421,16 @@ LOCALES = {
"info": "DPO 损失函数中 beta 超参数大小。"
}
},
"dpo_ftx": {
"en": {
"label": "DPO-ftx weight",
"info": "The weight of SFT loss in the DPO-ftx."
},
"zh": {
"label": "DPO-ftx 权重",
"info": "DPO-ftx 中 SFT 损失的权重大小。"
}
},
"reward_model": {
"en": {
"label": "Reward model",

View File

@ -146,6 +146,7 @@ class Runner:
if args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta")
args["dpo_ftx"] = get("train.dpo_ftx")
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")