fix #944

parent ace3f85a72
commit 338b8664ed
@@ -44,10 +44,10 @@ def preprocess_dataset(
         tokenized_examples = tokenizer(examples["prompt"], **kwargs)
         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
         total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
-        block_size = data_args.max_source_length
+        block_size = data_args.cutoff_len
         # we drop the small remainder, and if the total_length < block_size, we exclude this batch
         total_length = (total_length // block_size) * block_size
-        # split by chunks of max_source_length
+        # split by chunks of cutoff_len
         result = {
             k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
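
The pretraining path packs the tokenized corpus into fixed-size blocks, and that block size is now the single cutoff_len. A minimal, self-contained sketch of the same chunking arithmetic (the helper name and sample data are illustrative, not part of the patch):

    from itertools import chain

    def pack_into_blocks(token_lists, cutoff_len):
        # Flatten a batch of token lists into one stream, then cut it into
        # blocks of exactly cutoff_len tokens, dropping the small remainder.
        stream = list(chain(*token_lists))
        total_length = (len(stream) // cutoff_len) * cutoff_len
        return [stream[i: i + cutoff_len] for i in range(0, total_length, cutoff_len)]

    # 10 tokens packed into blocks of 4 -> two blocks, last 2 tokens dropped.
    print(pack_into_blocks([[1, 2, 3], [4, 5, 6, 7], [8, 9, 10]], 4))
    # [[1, 2, 3, 4], [5, 6, 7, 8]]
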
@@ -58,7 +58,6 @@ def preprocess_dataset(
         # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
         # for multiturn examples, we only mask the prompt part in each prompt-response pair.
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-        max_length = data_args.max_source_length + data_args.max_target_length

         for query, response, history, system in construct_example(examples):
             input_ids, labels = [], []
@@ -66,13 +65,14 @@ def preprocess_dataset(
             for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                 tokenizer, query, response, history, system
             )):
-                if len(source_ids) > data_args.max_source_length:
-                    source_ids = source_ids[:data_args.max_source_length]
-                if len(target_ids) > data_args.max_target_length:
-                    target_ids = target_ids[:data_args.max_target_length]
+                total_len = len(source_ids) + len(target_ids)
+                max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len))
+                max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len))

-                if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
-                    break
+                if len(source_ids) > max_source_len:
+                    source_ids = source_ids[:max_source_len]
+                if len(target_ids) > max_target_len:
+                    target_ids = target_ids[:max_target_len]

                 if turn_idx != 0 and template.efficient_eos:
                     source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
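
The key change: instead of capping source and target independently at max_source_length and max_target_length, each turn now receives a share of the single cutoff_len budget proportional to its length, so an unusually long prompt can borrow room from a short response and vice versa. A self-contained sketch of the arithmetic (the helper name is illustrative):

    def split_budget(source_ids, target_ids, cutoff_len):
        # Give source and target shares of cutoff_len proportional to their
        # original lengths, then truncate each side to its share.
        total_len = len(source_ids) + len(target_ids)
        max_source_len = int(cutoff_len * (len(source_ids) / total_len))
        max_target_len = int(cutoff_len * (len(target_ids) / total_len))
        return source_ids[:max_source_len], target_ids[:max_target_len]

    # A 300-token prompt and a 100-token response under cutoff_len=128
    # keep 96 and 32 tokens respectively (96 + 32 = 128).
    src, tgt = split_budget(list(range(300)), list(range(100)), 128)
    print(len(src), len(tgt))  # 96 32
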
@@ -86,6 +86,10 @@ def preprocess_dataset(
                 input_ids += [tokenizer.eos_token_id]
                 labels += [tokenizer.eos_token_id]

+            if len(input_ids) > data_args.cutoff_len:
+                input_ids = input_ids[:data_args.cutoff_len]
+                labels = labels[:data_args.cutoff_len]
+
             model_inputs["input_ids"].append(input_ids)
             model_inputs["attention_mask"].append([1] * len(input_ids))
             model_inputs["labels"].append(labels)
@@ -97,19 +101,19 @@ def preprocess_dataset(
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

         for query, response, history, system in construct_example(examples):
-            source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)
-
-            if len(source_ids) > data_args.max_source_length:
-                source_ids = source_ids[:data_args.max_source_length]
-            if len(target_ids) > data_args.max_target_length:
-                target_ids = target_ids[:data_args.max_target_length]
+            input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)

             if template.efficient_eos:
-                target_ids += [tokenizer.eos_token_id]
+                labels += [tokenizer.eos_token_id]

-            model_inputs["input_ids"].append(source_ids)
-            model_inputs["attention_mask"].append([1] * len(source_ids))
-            model_inputs["labels"].append(target_ids)
+            if len(input_ids) > data_args.cutoff_len:
+                input_ids = input_ids[:data_args.cutoff_len]
+            if len(labels) > data_args.cutoff_len:
+                labels = labels[:data_args.cutoff_len]
+
+            model_inputs["input_ids"].append(input_ids)
+            model_inputs["attention_mask"].append([1] * len(input_ids))
+            model_inputs["labels"].append(labels)

         return model_inputs
@@ -120,17 +124,21 @@ def preprocess_dataset(
             prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
             _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)

-            if len(prompt_ids) > data_args.max_source_length:
-                prompt_ids = prompt_ids[:data_args.max_source_length]
-            if len(chosen_ids) > data_args.max_target_length:
-                chosen_ids = chosen_ids[:data_args.max_target_length]
-            if len(rejected_ids) > data_args.max_target_length:
-                rejected_ids = rejected_ids[:data_args.max_target_length]
-
             if template.efficient_eos:
                 chosen_ids += [tokenizer.eos_token_id]
                 rejected_ids += [tokenizer.eos_token_id]

+            total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids))
+            max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len))
+            max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len))
+
+            if len(prompt_ids) > max_source_len:
+                prompt_ids = prompt_ids[:max_source_len]
+            if len(chosen_ids) > max_target_len:
+                chosen_ids = chosen_ids[:max_target_len]
+            if len(rejected_ids) > max_target_len:
+                rejected_ids = rejected_ids[:max_target_len]
+
             model_inputs["prompt_ids"].append(prompt_ids)
             model_inputs["chosen_ids"].append(chosen_ids)
             model_inputs["rejected_ids"].append(rejected_ids)
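
For pairwise data the response share of the budget is computed against the longer of the two candidates, so chosen and rejected are truncated to the same limits and remain directly comparable. A small sketch of that arithmetic (the helper name is illustrative):

    def pairwise_budget(prompt_len, chosen_len, rejected_len, cutoff_len):
        # Base the response share on the longer candidate so both candidates
        # are measured against one prompt/response split.
        total_len = prompt_len + max(chosen_len, rejected_len)
        max_source_len = int(cutoff_len * (prompt_len / total_len))
        max_target_len = int(cutoff_len * (max(chosen_len, rejected_len) / total_len))
        return max_source_len, max_target_len

    # Prompt of 600 tokens, candidates of 150 and 250 tokens, cutoff_len=1024:
    # the prompt keeps up to 722 tokens, each candidate up to 301.
    print(pairwise_budget(600, 150, 250, 1024))  # (722, 301)
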
@@ -42,12 +42,16 @@ class DataArguments:
         default="train",
         metadata={"help": "Which dataset split to use for training and evaluation."}
     )
+    cutoff_len: Optional[int] = field(
+        default=1024,
+        metadata={"help": "The maximum length of the model inputs after tokenization."}
+    )
     streaming: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable streaming mode."}
     )
     buffer_size: Optional[int] = field(
-        default=1024,
+        default=16384,
         metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
     )
     mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
@@ -66,14 +70,6 @@ class DataArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."}
     )
-    max_source_length: Optional[int] = field(
-        default=512,
-        metadata={"help": "The maximum total input sequence length after tokenization."}
-    )
-    max_target_length: Optional[int] = field(
-        default=512,
-        metadata={"help": "The maximum total output sequence length after tokenization."}
-    )
     max_samples: Optional[int] = field(
         default=None,
         metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
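
Because cutoff_len is a plain dataclass field, it surfaces as a single command-line flag in place of the removed pair. A minimal sketch of how such a field flows through HfArgumentParser (the trimmed-down dataclass here is illustrative, not the project's full DataArguments):

    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser

    @dataclass
    class DataArguments:
        cutoff_len: Optional[int] = field(
            default=1024,
            metadata={"help": "The maximum length of the model inputs after tokenization."}
        )

    # The field becomes one flag, e.g. --cutoff_len 2048 on the command line.
    data_args, = HfArgumentParser(DataArguments).parse_args_into_dataclasses(["--cutoff_len", "2048"])
    print(data_args.cutoff_len)  # 2048
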
@@ -63,18 +63,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
-    compute_dtype: Optional[torch.dtype] = field(
-        default=None,
-        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
-    )
-    model_max_length: Optional[int] = field(
-        default=None,
-        metadata={"help": "Used in rope scaling. Do not specify this argument manually."}
-    )

     def __post_init__(self):
-        if self.compute_dtype is not None or self.model_max_length is not None:
-            raise ValueError("These arguments cannot be specified.")
+        self.compute_dtype = None
+        self.model_max_length = None

         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
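
With the two fields deleted, compute_dtype and model_max_length become internal attributes that __post_init__ initializes and the pipeline fills in later, rather than CLI-settable fields guarded by a ValueError. A toy sketch of the pattern (trimmed-down class, illustrative field):

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class ModelArguments:
        model_name_or_path: Optional[str] = field(default=None)

        def __post_init__(self):
            # Internal state: set by the training pipeline, never by the user,
            # so deliberately not declared as dataclass fields.
            self.compute_dtype = None
            self.model_max_length = None

    args = ModelArguments()
    args.model_max_length = 1024  # filled in later, e.g. from data_args.cutoff_len
    print(args.compute_dtype, args.model_max_length)  # None 1024
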
@@ -173,7 +173,7 @@ def load_model_and_tokenizer(
     )

     # Disable custom generate method (for Qwen)
-    if "GenerationMixin" not in str(model.generate.__func__):
+    if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)

     # Fix LM head (for ChatGLM2)
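
The Qwen workaround rebinds the stock PreTrainedModel.generate onto instances whose remote code ships a custom generate; the new isinstance guard only skips objects that are not a PreTrainedModel at all. The rebinding itself is ordinary MethodType machinery, sketched here with toy classes standing in for PreTrainedModel and the remote-code model:

    from types import MethodType

    class Base:
        def greet(self):
            return "base"

    class Custom(Base):
        def greet(self):  # custom override, analogous to a remote-code generate()
            return "custom"

    obj = Custom()
    print(obj.greet())  # custom

    # Rebind the base implementation onto this instance only, the same way
    # model.generate is reset to PreTrainedModel.generate in the hunk above.
    if isinstance(obj, Base) and "Base" not in str(obj.greet.__func__):
        obj.greet = MethodType(Base.greet, obj)

    print(obj.greet())  # base
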
@@ -213,7 +213,7 @@ def get_train_args(
     else:
         model_args.compute_dtype = torch.float32

-    model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
+    model_args.model_max_length = data_args.cutoff_len

     # Log on each process the small summary:
     logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
@@ -2,13 +2,13 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

 from trl import PPOTrainer
-from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
+from trl.core import PPODecorators, logprobs_from_logits

 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
@@ -47,7 +47,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.state = TrainerState()
         self.control = TrainerControl()

-    def ppo_train(self, max_target_length: int) -> None:
+    def ppo_train(self) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
         """
@@ -81,9 +81,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             pad_token_id=self.tokenizer.pad_token_id
         ))

-        length_sampler = LengthSampler(max_target_length // 2, max_target_length)
-
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)

         dataiter = iter(self.dataloader)
         steps_trained = 0
         loss_meter = AverageMeter()
@@ -100,7 +98,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.model.eval()

             # Get inputs
-            queries, responses = self.get_inputs(batch, length_sampler, generating_args)
+            queries, responses = self.get_inputs(batch, generating_args)
             self.tokenizer.padding_side = "right" # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)

@@ -156,13 +154,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
     def get_inputs(
         self,
         batch: Dict[str, torch.Tensor],
-        length_sampler: Callable,
         generating_args: Dict[str, Any]
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
-        generating_args["max_new_tokens"] = length_sampler()
         gen_kwargs = dict(
             generation_config=GenerationConfig(**generating_args),
             logits_processor=get_logits_processor(),
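
Dropping LengthSampler means the response length is no longer re-sampled for every batch from [max_target_length // 2, max_target_length); generation simply honors whatever max_new_tokens the generating arguments carry, which the web UI now exposes directly. A small sketch of the resulting config construction (the generating_args values are illustrative):

    from transformers import GenerationConfig

    # Previously: generating_args["max_new_tokens"] = length_sampler()
    # re-sampled a response length per batch. The value is now fixed.
    generating_args = {
        "do_sample": True,
        "top_p": 0.7,
        "temperature": 0.95,
        "max_new_tokens": 128,
    }

    generation_config = GenerationConfig(**generating_args)
    print(generation_config.max_new_tokens)  # 128, fixed rather than sampled
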
@@ -79,7 +79,7 @@ def run_ppo(

     # Training
     if training_args.do_train:
-        ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
+        ppo_trainer.ppo_train()
         ppo_trainer.save_model()
         ppo_trainer.save_state() # must be called after save_model to have a folder
         if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
@@ -28,12 +28,16 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
     )

     with gr.Row():
-        max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
-        max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
+        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
         max_samples = gr.Textbox(value="100000")
         batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
         predict = gr.Checkbox(value=True)

+    with gr.Row():
+        max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
+        top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
+        temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
@@ -55,11 +59,13 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
         top_elems["system_prompt"],
         dataset_dir,
         dataset,
-        max_source_length,
-        max_target_length,
+        cutoff_len,
         max_samples,
         batch_size,
-        predict
+        predict,
+        max_new_tokens,
+        top_p,
+        temperature
     ]

     output_components = [
@@ -78,11 +84,13 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
         preview_count=preview_count,
         preview_samples=preview_samples,
         close_btn=close_btn,
-        max_source_length=max_source_length,
-        max_target_length=max_target_length,
+        cutoff_len=cutoff_len,
         max_samples=max_samples,
         batch_size=batch_size,
         predict=predict,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        temperature=temperature,
         cmd_preview_btn=cmd_preview_btn,
         start_btn=start_btn,
         stop_btn=stop_btn,
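
In both web UI tabs the input component list and the runner callback are positional mirrors of each other, which is why every hunk edits the component row, the input list, and the keyword mapping in lockstep. A toy gradio sketch of that contract (slider values copied from the hunks above; the callback is illustrative):

    import gradio as gr

    def run_eval(cutoff_len, max_new_tokens, top_p, temperature):
        # Values arrive positionally, in the order of the inputs list below.
        return f"cutoff_len={cutoff_len}, max_new_tokens={max_new_tokens}, top_p={top_p}, temperature={temperature}"

    with gr.Blocks() as demo:
        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
        max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
        top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
        temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
        start_btn = gr.Button()
        output = gr.Textbox()
        # This list must stay in sync with run_eval's parameter order.
        start_btn.click(run_eval, inputs=[cutoff_len, max_new_tokens, top_p, temperature], outputs=output)

    demo.launch()
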
@@ -35,11 +35,11 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
     )

     with gr.Row():
-        max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
-        max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
+        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
         learning_rate = gr.Textbox(value="5e-5")
         num_train_epochs = gr.Textbox(value="3.0")
         max_samples = gr.Textbox(value="100000")
+        compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")

     with gr.Row():
         batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
@@ -55,7 +55,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
             logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
             save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
             warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
-            compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
+            flash_attn = gr.Checkbox(value=False)
+            rope_scaling = gr.Checkbox(value=False)

     with gr.Accordion(label="LoRA config", open=False) as lora_tab:
         with gr.Row():
@@ -107,11 +108,11 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
         training_stage,
         dataset_dir,
         dataset,
-        max_source_length,
-        max_target_length,
+        cutoff_len,
         learning_rate,
         num_train_epochs,
         max_samples,
+        compute_type,
         batch_size,
         gradient_accumulation_steps,
         lr_scheduler_type,
@@ -120,7 +121,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
         logging_steps,
         save_steps,
         warmup_steps,
-        compute_type,
+        flash_attn,
+        rope_scaling,
         lora_rank,
         lora_dropout,
         lora_target,
@@ -151,11 +153,11 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
         preview_count=preview_count,
         preview_samples=preview_samples,
         close_btn=close_btn,
-        max_source_length=max_source_length,
-        max_target_length=max_target_length,
+        cutoff_len=cutoff_len,
         learning_rate=learning_rate,
         num_train_epochs=num_train_epochs,
         max_samples=max_samples,
+        compute_type=compute_type,
         batch_size=batch_size,
         gradient_accumulation_steps=gradient_accumulation_steps,
         lr_scheduler_type=lr_scheduler_type,
@@ -165,7 +167,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
         logging_steps=logging_steps,
         save_steps=save_steps,
         warmup_steps=warmup_steps,
-        compute_type=compute_type,
+        flash_attn=flash_attn,
+        rope_scaling=rope_scaling,
         lora_tab=lora_tab,
         lora_rank=lora_rank,
         lora_dropout=lora_dropout,
@@ -147,26 +147,16 @@ LOCALES = {
             "value": "关闭"
         }
     },
-    "max_source_length": {
+    "cutoff_len": {
         "en": {
-            "label": "Max source length",
-            "info": "Max tokens in source sequence."
+            "label": "Cutoff length",
+            "info": "Max tokens in input sequence."
         },
         "zh": {
-            "label": "输入序列最大长度",
+            "label": "截断长度",
             "info": "输入序列分词后的最大长度。"
         }
     },
-    "max_target_length": {
-        "en": {
-            "label": "Max target length",
-            "info": "Max tokens in target sequence."
-        },
-        "zh": {
-            "label": "输出序列最大长度",
-            "info": "输出序列分词后的最大长度。"
-        }
-    },
     "learning_rate": {
         "en": {
             "label": "Learning rate",
@@ -197,6 +187,16 @@ LOCALES = {
             "info": "每个数据集最多使用的样本数。"
         }
     },
+    "compute_type": {
+        "en": {
+            "label": "Compute type",
+            "info": "Whether to use fp16 or bf16 mixed precision training."
+        },
+        "zh": {
+            "label": "计算类型",
+            "info": "是否启用 FP16 或 BF16 混合精度训练。"
+        }
+    },
     "batch_size": {
         "en": {
             "label": "Batch size",
@@ -277,14 +277,20 @@ LOCALES = {
             "info": "学习率预热采用的步数。"
         }
     },
-    "compute_type": {
+    "flash_attn": {
         "en": {
-            "label": "Compute type",
-            "info": "Whether to use fp16 or bf16 mixed precision training."
+            "label": "Use FlashAttention-2"
         },
         "zh": {
-            "label": "计算类型",
-            "info": "是否启用 FP16 或 BF16 混合精度训练。"
+            "label": "使用 FlashAttention-2"
         }
     },
+    "rope_scaling": {
+        "en": {
+            "label": "Use RoPE scaling"
+        },
+        "zh": {
+            "label": "使用 RoPE 插值"
+        }
+    },
     "lora_tab": {
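
Every renamed component needs a matching LOCALES entry, keyed by element name, then language, then the properties used when the component is built. A minimal lookup sketch over the new cutoff_len entry (the helper is illustrative):

    LOCALES = {
        "cutoff_len": {
            "en": {"label": "Cutoff length", "info": "Max tokens in input sequence."},
            "zh": {"label": "截断长度", "info": "输入序列分词后的最大长度。"},
        },
    }

    def localize(elem_name, lang):
        # Fetch the label/info pair used to configure the gradio component.
        return LOCALES[elem_name][lang]

    print(localize("cutoff_len", "en")["label"])  # Cutoff length
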
@@ -73,11 +73,11 @@ class Runner:
         training_stage: str,
         dataset_dir: str,
         dataset: List[str],
-        max_source_length: int,
-        max_target_length: int,
+        cutoff_len: int,
         learning_rate: str,
         num_train_epochs: str,
         max_samples: str,
+        compute_type: str,
         batch_size: int,
         gradient_accumulation_steps: int,
         lr_scheduler_type: str,
@@ -86,7 +86,8 @@ class Runner:
         logging_steps: int,
         save_steps: int,
         warmup_steps: int,
-        compute_type: str,
+        flash_attn: bool,
+        rope_scaling: bool,
         lora_rank: int,
         lora_dropout: float,
         lora_target: str,
@@ -120,8 +121,7 @@ class Runner:
             system_prompt=system_prompt,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
-            max_source_length=max_source_length,
-            max_target_length=max_target_length,
+            cutoff_len=cutoff_len,
             learning_rate=float(learning_rate),
             num_train_epochs=float(num_train_epochs),
             max_samples=int(max_samples),
@@ -132,6 +132,8 @@ class Runner:
             logging_steps=logging_steps,
             save_steps=save_steps,
             warmup_steps=warmup_steps,
+            flash_attn=flash_attn,
+            rope_scaling="linear" if rope_scaling else None,
             lora_rank=lora_rank,
             lora_dropout=lora_dropout,
             lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
@@ -168,11 +170,13 @@ class Runner:
         system_prompt: str,
         dataset_dir: str,
         dataset: List[str],
-        max_source_length: int,
-        max_target_length: int,
+        cutoff_len: int,
         max_samples: str,
         batch_size: int,
-        predict: bool
+        predict: bool,
+        max_new_tokens: int,
+        top_p: float,
+        temperature: float
     ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
             checkpoint_dir = ",".join(
@@ -200,10 +204,12 @@ class Runner:
             system_prompt=system_prompt,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
-            max_source_length=max_source_length,
-            max_target_length=max_target_length,
+            cutoff_len=cutoff_len,
             max_samples=int(max_samples),
             per_device_eval_batch_size=batch_size,
+            max_new_tokens=max_new_tokens,
+            top_p=top_p,
+            temperature=temperature,
             output_dir=output_dir
         )
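
Taken together, the runner now forwards one cutoff_len plus the newly exposed generation controls into the argument dict it hands to the tuner. A condensed sketch of the evaluation dict after this change (keys mirror the hunk above; the values are illustrative):

    eval_args = dict(
        dataset_dir="data",
        dataset="alpaca_gpt4_en",
        cutoff_len=1024,             # replaces max_source_length / max_target_length
        max_samples=100000,
        per_device_eval_batch_size=8,
        max_new_tokens=128,          # newly exposed generation controls
        top_p=0.7,
        temperature=0.95,
        output_dir="output",
    )
    print(eval_args["cutoff_len"])  # 1024
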