allow non-packing pretraining

This commit is contained in:
hiyouga 2024-03-09 22:21:46 +08:00
parent 412c52e325
commit bdb496644c
22 changed files with 64 additions and 67 deletions
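Taken together, the commit replaces the SFT-only `sft_packing` flag with a single tri-state `packing` data argument (default `None`): pre-training (`stage = "pt"`) packs sequences automatically unless the user overrides it, and most of the remaining hunks are annotation cleanups that drop misleading `Optional[...]` hints plus the matching Web UI renames. A minimal sketch of the three configurations, written as argument dicts of the kind accepted by `get_train_args` further down; the model and dataset values are placeholders, not part of this commit:

# packing left unset: auto-enabled for pre-training, off for every other stage
train_args = {"stage": "pt", "model_name_or_path": "...", "dataset": "..."}

# the new capability: pre-training without packing, one truncated sample per document
train_args = {"stage": "pt", "packing": False, "cutoff_len": 1024,
              "model_name_or_path": "...", "dataset": "..."}

# supervised fine-tuning with packing, previously spelled `sft_packing`
train_args = {"stage": "sft", "packing": True,
              "model_name_or_path": "...", "dataset": "..."}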

View File

@@ -59,7 +59,7 @@ class ChatCompletionMessage(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
tools: Optional[list] = []
tools: list = []
do_sample: bool = True
temperature: Optional[float] = None
top_p: Optional[float] = None
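A quick hedged check of what the annotation change does and does not affect: the default stays `[]`, so a request that omits `tools` still validates, and only the declared type is narrowed. The class below is a reduced stand-in limited to the fields visible in this hunk.

from typing import List, Optional
from pydantic import BaseModel

class ChatMessage(BaseModel):  # simplified stand-in for the real model
    role: str
    content: str

class ChatCompletionRequest(BaseModel):  # only the fields shown above
    model: str
    messages: List[ChatMessage]
    tools: list = []
    do_sample: bool = True
    temperature: Optional[float] = None
    top_p: Optional[float] = None

req = ChatCompletionRequest(model="any-model", messages=[{"role": "user", "content": "hi"}])
assert req.tools == []  # an omitted field falls back to the empty-list default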

View File

@@ -21,8 +21,11 @@ logger = get_logger(__name__)
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
if not data_args.packing:
return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
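For readers who have not seen the packing path, a self-contained sketch of both branches, assuming an ordinary Hugging Face tokenizer: `block_size` stands in for `data_args.cutoff_len`, the sample texts are made up, and `truncation=True` is added here so `max_length` takes effect in this standalone form. The real function continues past the `total_length` line by slicing the concatenated ids into fixed-size blocks, which the last comprehension mimics.

from itertools import chain
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any causal-LM tokenizer with an eos_token
texts = ["first document" + tokenizer.eos_token, "second document" + tokenizer.eos_token]
block_size = 32

# packing disabled: one (possibly truncated) sample per document
unpacked = tokenizer(texts, add_special_tokens=False, truncation=True, max_length=block_size)

# packing enabled: concatenate every document, then cut into equal-sized blocks
tokenized = tokenizer(texts, add_special_tokens=False)
concatenated = {k: list(chain(*tokenized[k])) for k in tokenized}
total_length = (len(concatenated["input_ids"]) // block_size) * block_size
packed = {
    k: [values[i : i + block_size] for i in range(0, total_length, block_size)]
    for k, values in concatenated.items()
}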
@@ -245,7 +248,7 @@ def get_preprocess_and_print_func(
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.sft_packing:
if data_args.packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)

View File

@@ -36,8 +36,8 @@ class Template:
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
reserved_label_len: Optional[int] = 1,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
@@ -56,8 +56,8 @@ class Template:
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
reserved_label_len: Optional[int] = 1,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
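The same annotation cleanup recurs throughout this commit, so one illustration may help: `Optional[int]` means "int or None", not "this parameter has a default", so a parameter that is never meant to receive None is better annotated with the plain type. The function names below are invented for the illustration.

from typing import Optional

def encode_old(cutoff_len: Optional[int] = 1_000_000) -> None:
    ...  # the annotation claims None is a legal argument, although the code never handles None

def encode_new(cutoff_len: int = 1_000_000) -> None:
    ...  # same default value, but the declared contract now matches how the parameter is used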
@@ -207,11 +207,11 @@ def _register_template(
format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None,
default_system: Optional[str] = "",
stop_words: Optional[List[str]] = [],
efficient_eos: Optional[bool] = False,
replace_eos: Optional[bool] = False,
force_system: Optional[bool] = False,
default_system: str = "",
stop_words: List[str] = [],
efficient_eos: bool = False,
replace_eos: bool = False,
force_system: bool = False,
) -> None:
r"""
Registers a chat template.
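One aside on `stop_words: List[str] = []`: unlike the `Optional` cleanups, this keeps a mutable default, which Python evaluates once and shares across calls. That is harmless as long as the argument is only read; the usual defensive spelling, shown below purely as a general pattern rather than something this commit changes, avoids the shared object entirely.

from typing import List, Optional

def register(stop_words: Optional[List[str]] = None) -> List[str]:
    # build a fresh list per call instead of sharing one module-level default object
    return list(stop_words) if stop_words is not None else []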
@@ -279,9 +279,7 @@ def _jinja_escape(content: str) -> str:
return content.replace("\n", r"\n").replace("'", r"\'")
def _convert_slots_to_jinja(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: Optional[str] = "content"
) -> str:
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
slot_items = []
for slot in slots:
if isinstance(slot, str):

View File

@@ -1,7 +1,7 @@
import json
import math
import os
from typing import List, Optional
from typing import List
from transformers.trainer import TRAINER_STATE_NAME
@@ -30,7 +30,7 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f)

View File

@@ -78,9 +78,11 @@ class DataArguments:
default=0.0,
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
)
sft_packing: bool = field(
default=False,
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."},
packing: Optional[bool] = field(
default=None,
metadata={
"help": "Whether or not to pack the sequences in training. Will be automatically enabled in pre-training."
},
)
cache_path: Optional[str] = field(
default=None,
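The field is deliberately tri-state: `None` means "decide from the training stage", and the decision itself is the one-line addition to `get_train_args` further down. A small worked example of that resolution rule; the stage strings follow the repository's convention, where "pt" is pre-training.

# resolution rule: packing = packing if packing is not None else (stage == "pt")
cases = [
    (None, "pt", True),    # default during pre-training: packing auto-enabled
    (None, "sft", False),  # default in any other stage: packing off
    (False, "pt", False),  # explicit override: non-packing pre-training, the point of this commit
    (True, "sft", True),   # packing during SFT, formerly `sft_packing`
]
for packing, stage, expected in cases:
    resolved = packing if packing is not None else (stage == "pt")
    assert resolved is expected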

View File

@@ -135,7 +135,6 @@ class ModelArguments:
)
def __post_init__(self):
self.aqlm_optimization = None
self.compute_dtype = None
self.device_map = None
self.model_max_length = None

View File

@@ -230,7 +230,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args.compute_dtype = torch.float16
model_args.model_max_length = data_args.cutoff_len
model_args.aqlm_optimization = not training_args.predict_with_generate
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
# Log on each process the small summary:
logger.info(
@@ -253,7 +253,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
model_args.aqlm_optimization = False
model_args.device_map = "auto"
if data_args.template is None:
@@ -267,7 +266,6 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
model_args.aqlm_optimization = True
model_args.device_map = "auto"
if data_args.template is None:

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
@@ -52,8 +52,8 @@ def load_model(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
add_valuehead: Optional[bool] = False,
is_trainable: bool = False,
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model. Must be called after load_tokenizer.
@@ -137,8 +137,8 @@ def load_model(
def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
add_valuehead: Optional[bool] = False,
is_trainable: bool = False,
add_valuehead: bool = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.

View File

@@ -3,7 +3,7 @@ import os
import random
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
import torch
from datasets import load_dataset
@@ -219,7 +219,7 @@ def _configure_quantization(
def _prepare_model_for_training(
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
r"""
Includes:

View File

@@ -22,7 +22,7 @@ class CustomDPOTrainer(DPOTrainer):
ftx_gamma: float,
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
disable_dropout: bool = True,
**kwargs,
):
if disable_dropout:
@@ -95,7 +95,7 @@ class CustomDPOTrainer(DPOTrainer):
self,
model: "PreTrainedModel",
batch: Dict[str, torch.Tensor],
train_eval: Optional[Literal["train", "eval"]] = "train",
train_eval: Literal["train", "eval"] = "train",
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.

View File

@@ -292,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: Optional[bool] = False,
return_logits: bool = False,
response_masks: Optional[torch.Tensor] = None,
):
r"""

View File

@@ -1,6 +1,6 @@
import json
import os
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
import torch
from transformers import Trainer
@@ -26,7 +26,7 @@ class PairwiseTrainer(Trainer):
self.can_return_loss = True # override property to return eval_loss
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

View File

@@ -46,7 +46,7 @@ def create_modelcard_and_push(
def create_ref_model(
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: Optional[bool] = False
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.

View File

@@ -18,9 +18,7 @@ if TYPE_CHECKING:
class WebChatModel(ChatModel):
def __init__(
self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
) -> None:
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
self.demo_mode = demo_mode
self.engine: Optional["BaseEngine"] = None

View File

@@ -104,10 +104,12 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return {}
def list_dataset(
dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
) -> Dict[str, Any]:
def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
return gr.update(value=[], choices=datasets)
def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
return gr.update(value=(TRAINING_STAGES[training_stage] == "pt"))
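For orientation, a sketch of how `autoset_packing` is meant to be used: it returns a `gr.update(...)` that ticks the packing checkbox whenever the selected stage is pre-training, and the train tab (see the hunk further down) chains it onto the training-stage change event. The import paths and the direct `.change(...)` binding are assumptions made to keep the example self-contained; the repository itself attaches it with a chained `.then(...)` call.

import gradio as gr

from llmtuner.extras.constants import TRAINING_STAGES  # assumed import path
from llmtuner.webui.common import autoset_packing      # assumed import path

with gr.Blocks() as demo:
    training_stage = gr.Dropdown(
        choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0]
    )
    packing = gr.Checkbox(label="Pack sequences")
    # re-evaluate the checkbox whenever the stage dropdown changes
    training_stage.change(autoset_packing, [training_stage], [packing], queue=False)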

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Tuple
import gradio as gr
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
def create_chat_box(
engine: "Engine", visible: Optional[bool] = False
engine: "Engine", visible: bool = False
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, Tuple
import gradio as gr
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, "Component"]:
def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
@@ -44,7 +44,7 @@ def create_top() -> Dict[str, "Component"]:
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
return dict(
return lang, dict(
lang=lang,
model_name=model_name,
model_path=model_path,

View File

@@ -4,7 +4,7 @@ import gradio as gr
from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box
from ..utils import gen_plot
@@ -78,7 +78,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
resize_vocab = gr.Checkbox()
sft_packing = gr.Checkbox()
packing = gr.Checkbox()
upcast_layernorm = gr.Checkbox()
use_llama_pro = gr.Checkbox()
shift_attn = gr.Checkbox()
@@ -91,7 +91,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
neftune_alpha,
optim,
resize_vocab,
sft_packing,
packing,
upcast_layernorm,
use_llama_pro,
shift_attn,
@@ -106,7 +106,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
neftune_alpha=neftune_alpha,
optim=optim,
resize_vocab=resize_vocab,
sft_packing=sft_packing,
packing=packing,
upcast_layernorm=upcast_layernorm,
use_llama_pro=use_llama_pro,
shift_attn=shift_attn,
@@ -166,7 +166,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
[reward_model],
queue=False,
)
).then(autoset_packing, [training_stage], [packing], queue=False)
input_elems.update({dpo_beta, dpo_ftx, reward_model})
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Generator, Optional
from typing import Any, Dict, Generator
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
@@ -12,7 +12,7 @@ from .utils import get_time
class Engine:
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
self.demo_mode = demo_mode
self.pure_chat = pure_chat
self.manager = Manager()

View File

@@ -1,5 +1,3 @@
from typing import Optional
import gradio as gr
from transformers.utils.versions import require_version
@@ -19,7 +17,7 @@ from .engine import Engine
require_version("gradio>=3.38.0,<4.0.0", 'To fix: pip install "gradio>=3.38.0,<4.0.0"')
def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
def create_ui(demo_mode: bool = False) -> gr.Blocks:
engine = Engine(demo_mode=demo_mode, pure_chat=False)
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
@@ -31,8 +29,7 @@ def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
engine.manager.all_elems["top"] = create_top()
lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
lang, engine.manager.all_elems["top"] = create_top()
with gr.Tab("Train"):
engine.manager.all_elems["train"] = create_train_tab(engine)

View File

@@ -480,18 +480,18 @@ LOCALES = {
"info": "更改分词器词表和嵌入层的大小。",
},
},
"sft_packing": {
"packing": {
"en": {
"label": "Pack sequences",
"info": "Pack sequences into samples of fixed length in supervised fine-tuning.",
"info": "Pack sequences into samples of fixed length.",
},
"ru": {
"label": "Упаковка последовательностей",
"info": "Упаковка последовательностей в образцы фиксированной длины при контролируемой тонкой настройке.",
"info": "Упаковка последовательностей в образцы фиксированной длины.",
},
"zh": {
"label": "序列打包",
"info": "在指令监督微调时将序列打包为等长样本。",
"info": "将序列打包为等长样本。",
},
},
"upcast_layernorm": {

View File

@@ -2,7 +2,7 @@ import logging
import os
import time
from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
import gradio as gr
import transformers
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
class Runner:
def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
@@ -136,7 +136,7 @@ class Runner:
neftune_noise_alpha=get("train.neftune_alpha") or None,
optim=get("train.optim"),
resize_vocab=get("train.resize_vocab"),
sft_packing=get("train.sft_packing"),
packing=get("train.packing"),
upcast_layernorm=get("train.upcast_layernorm"),
use_llama_pro=get("train.use_llama_pro"),
shift_attn=get("train.shift_attn"),