fix #1184
parent b240b6792f
commit af18b0dce7
@@ -22,6 +22,9 @@ def preprocess_dataset(
     column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")
+
     def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
             query, response = examples["prompt"][i], examples["response"][i]
@@ -32,13 +35,12 @@ def preprocess_dataset(
 
     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build grouped texts with format `X1 X2 X3 ...`
-        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
-            kwargs = dict(allowed_special="all")  # for tiktoken tokenizer (Qwen)
+        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
         else:
             kwargs = dict(add_special_tokens=True)
 
-        if hasattr(tokenizer, "add_bos_token") and hasattr(tokenizer, "add_eos_token"):
-            setattr(tokenizer, "add_bos_token", True)  # for LLaMA tokenizer
+        if hasattr(tokenizer, "add_eos_token"):  # for LLaMA tokenizer
             setattr(tokenizer, "add_eos_token", True)
 
         tokenized_examples = tokenizer(examples["prompt"], **kwargs)
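A minimal sketch, not part of the commit, of how the two tokenizer-kwargs branches above are meant to be used. It assumes a Qwen-style tokenizer that wraps a tiktoken.Encoding in its `tokenizer` attribute and therefore needs `allowed_special="all"`, while ordinary tokenizers take `add_special_tokens=True`; the helper name is made up for illustration.

import tiktoken

def pretrain_tokenize_kwargs(tokenizer) -> dict:
    # tiktoken-backed tokenizers (e.g. Qwen) reject special tokens inside raw text
    # unless they are explicitly allowed
    if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
        return dict(allowed_special="all")
    # regular tokenizers: let the tokenizer insert its own special tokens
    return dict(add_special_tokens=True)

# usage: tokenized = tokenizer(examples["prompt"], **pretrain_tokenize_kwargs(tokenizer))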
@@ -74,7 +76,9 @@ def preprocess_dataset(
                 if len(target_ids) > max_target_len:
                     target_ids = target_ids[:max_target_len]
 
-                if turn_idx != 0 and template.efficient_eos:
+                if data_args.train_on_prompt:
+                    source_mask = source_ids
+                elif turn_idx != 0 and template.efficient_eos:
                     source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
                 else:
                     source_mask = [IGNORE_INDEX] * len(source_ids)
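A toy illustration, not from the repository, of how the prompt mask is chosen per turn and how the labels for one turn end up as `source_mask + target_ids`. It assumes the usual Hugging Face convention that IGNORE_INDEX is -100, the label value skipped by the loss.

IGNORE_INDEX = -100

def build_source_mask(source_ids, turn_idx, train_on_prompt, efficient_eos, eos_token_id):
    if train_on_prompt:
        # keep the prompt tokens as labels so the loss also covers the prompt
        return source_ids
    if turn_idx != 0 and efficient_eos:
        # later turns begin with the eos of the previous turn: learn that token,
        # mask the rest of the prompt
        return [eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
    # default: mask the whole prompt
    return [IGNORE_INDEX] * len(source_ids)

source_ids, target_ids = [5, 6, 7], [8, 9]
print(build_source_mask(source_ids, 0, False, False, eos_token_id=2) + target_ids)
# -> [-100, -100, -100, 8, 9]
print(build_source_mask(source_ids, 0, True, False, eos_token_id=2) + target_ids)
# -> [5, 6, 7, 8, 9]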
@@ -97,15 +101,17 @@ def preprocess_dataset(
         return model_inputs
 
     def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build inputs with format `<bos> X Y <eos>` and labels with format `<bos> X Y <eos>`
-        # we do not mask the inputs in packed training.
+        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
+        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         input_ids, labels = [], []
         for query, response, history, system in construct_example(examples):
             for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                 tokenizer, query, response, history, system
             )):
-                if turn_idx != 0 and template.efficient_eos:
+                if data_args.train_on_prompt:
+                    source_mask = source_ids
+                elif turn_idx != 0 and template.efficient_eos:
                     source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
                 else:
                     source_mask = [IGNORE_INDEX] * len(source_ids)
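A toy example, not from the repository, of the packed layout that the new comments describe: inputs are concatenated as `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>` while the labels mask every prompt token. The token ids below are made up.

IGNORE_INDEX = -100
BOS, EOS = 1, 2

examples = [([BOS, 11, 12], [21, EOS]),   # (source_ids, target_ids) for example 1
            ([BOS, 13], [22, 23, EOS])]   # example 2

input_ids, labels = [], []
for source_ids, target_ids in examples:
    input_ids += source_ids + target_ids
    labels += [IGNORE_INDEX] * len(source_ids) + target_ids  # prompt masked by default

print(input_ids)  # [1, 11, 12, 21, 2, 1, 13, 22, 23, 2]
print(labels)     # [-100, -100, -100, 21, 2, -100, -100, 22, 23, 2]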
@@ -229,5 +235,9 @@ def preprocess_dataset(
             **kwargs
         )
 
-    print_function(next(iter(dataset)))
+    try:
+        print_function(next(iter(dataset)))
+    except StopIteration:
+        raise ValueError("Empty dataset!")
+
     return dataset
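Why the try/except works: `next(iter(...))` raises StopIteration when the (possibly streamed) dataset yields no examples at all, for instance when every sample was filtered out upstream. A minimal sketch, assuming nothing about the real dataset object beyond it being iterable:

def first_example(dataset):
    try:
        return next(iter(dataset))
    except StopIteration:
        raise ValueError("Empty dataset!")

print(first_example([{"input_ids": [1, 2, 3]}]))  # {'input_ids': [1, 2, 3]}
# first_example([])  # would raise ValueError("Empty dataset!")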
@@ -31,7 +31,7 @@ class DataArguments:
         metadata={"help": "Which template to use for constructing prompts in training and inference."}
     )
     dataset: Optional[str] = field(
-        default="alpaca_en",
+        default=None,
         metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
     )
     dataset_dir: Optional[str] = field(
@@ -46,13 +46,17 @@ class DataArguments:
         default=1024,
         metadata={"help": "The maximum length of the model inputs after tokenization."}
     )
+    train_on_prompt: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to disable the mask on the prompt or not."}
+    )
     streaming: Optional[bool] = field(
         default=False,
-        metadata={"help": "Enable streaming mode."}
+        metadata={"help": "Enable dataset streaming."}
     )
     buffer_size: Optional[int] = field(
         default=16384,
-        metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
+        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
     )
     mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
         default="concat",
@@ -95,10 +99,20 @@ class DataArguments:
         metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
     )
 
+    def __post_init__(self):
+        if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
+            raise ValueError("Streaming mode should have an integer val size.")
+
+        if self.streaming and self.max_samples is not None:
+            raise ValueError("`max_samples` is incompatible with `streaming`.")
+
     def init_for_training(self): # support mixing multiple datasets
-        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
-        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
-            dataset_info = json.load(f)
+        dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
+        try:
+            with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
+                dataset_info = json.load(f)
+        except Exception:
+            dataset_info = None
 
         prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
         prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
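A simplified reproduction, not the real class, of the two new __post_init__ checks, just to show when they fire. Only the three relevant fields are modeled; everything else about DataArguments is omitted.

from dataclasses import dataclass
from typing import Optional

@dataclass
class _Args:
    streaming: bool = False
    val_size: float = 0
    max_samples: Optional[int] = None

    def __post_init__(self):
        # a fractional val_size needs a sized dataset, which streaming does not provide
        if self.streaming and 1e-6 < self.val_size < 1:
            raise ValueError("Streaming mode should have an integer val size.")
        # max_samples cannot be applied to an unbounded stream
        if self.streaming and self.max_samples is not None:
            raise ValueError("`max_samples` is incompatible with `streaming`.")

_Args(streaming=True, val_size=1000)        # ok: integer-sized validation split
# _Args(streaming=True, val_size=0.1)       # raises ValueError
# _Args(streaming=True, max_samples=10000)  # raises ValueError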
@@ -1,4 +1,3 @@
-import torch
 from typing import Literal, Optional
 from dataclasses import dataclass, field
 
@@ -19,6 +18,10 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
     )
+    split_special_tokens: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
+    )
     use_auth_token: Optional[bool] = field(
         default=False,
         metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
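A hedged sketch of what the new flag does in transformers (the exact behavior depends on the installed version; the flag was added around transformers 4.32 and, at the time of this commit, is honored only by slow tokenizers, which is why the next hunk rejects it together with use_fast_tokenizer). The checkpoint name is only an example of a public LLaMA-style tokenizer.

from transformers import AutoTokenizer

# default: a literal "</s>" typed by a user is treated as the eos special token
tok = AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=False)
print(tok.tokenize("hello </s> world"))

# with split_special_tokens=True the same string is split into ordinary sub-word pieces
tok_split = AutoTokenizer.from_pretrained(
    "huggyllama/llama-7b", use_fast=False, split_special_tokens=True
)
print(tok_split.tokenize("hello </s> world"))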
@@ -76,6 +79,9 @@ class ModelArguments:
         self.compute_dtype = None
         self.model_max_length = None
 
+        if self.split_special_tokens and self.use_fast_tokenizer:
+            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
+
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
@@ -122,9 +122,6 @@ def get_train_args(
     if general_args.stage == "ppo" and model_args.reward_model is None:
         raise ValueError("Reward model is necessary for PPO training.")
 
-    if general_args.stage == "ppo" and training_args.deepspeed is not None:
-        raise ValueError("PPO training is incompatible with DeepSpeed, use Accelerate instead.")
-
     if general_args.stage == "ppo" and data_args.streaming:
         raise ValueError("Streaming mode does not suppport PPO training currently.")
 
@@ -134,9 +131,6 @@ def get_train_args(
     if training_args.max_steps == -1 and data_args.streaming:
         raise ValueError("Please specify `max_steps` in streaming mode.")
 
-    if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
-        raise ValueError("Streaming mode should have an integer val size.")
-
     if training_args.do_train and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True while training.")
 
@@ -166,11 +160,6 @@ def get_train_args(
     if (not training_args.do_train) and model_args.quantization_bit is not None:
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 
-    # postprocess data_args
-    if data_args.max_samples is not None and data_args.streaming:
-        logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
-        data_args.max_samples = None
-
     # postprocess training_args
     if (
         training_args.local_rank != -1