fix generating args
commit 531a3764d9 (parent cec6524d6b)
@@ -30,8 +30,8 @@ def main():
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length if \
                 training_args.generation_max_length is not None else data_args.max_target_length
-    training_args.generation_num_beams = data_args.num_beams if \
-                data_args.num_beams is not None else training_args.generation_num_beams
+    training_args.generation_num_beams = data_args.eval_num_beams if \
+                data_args.eval_num_beams is not None else training_args.generation_num_beams
 
     # Split the dataset
     if training_args.do_train:
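Note: `eval_num_beams` is the renamed `DataTrainingArguments` field (see the hunk at -125 below). When it is set it overrides `generation_num_beams`; otherwise the Seq2SeqTrainer value is kept. A minimal standalone sketch of that fallback, with made-up values:

    # Hypothetical values: a data-side eval_num_beams, if given, wins;
    # otherwise the existing Seq2SeqTrainer setting is kept.
    generation_num_beams = 1      # Seq2SeqTrainer default
    eval_num_beams = 4            # DataTrainingArguments.eval_num_beams

    generation_num_beams = eval_num_beams if eval_num_beams is not None else generation_num_beams
    print(generation_num_beams)   # 4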
@@ -195,8 +195,6 @@ def load_pretrained(
                 bnb_4bit_use_double_quant=model_args.double_quantization,
                 bnb_4bit_quant_type=model_args.quantization_type
             )
-        else:
-            raise NotImplementedError
         is_mergeable = False
         config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@@ -273,8 +271,8 @@ def prepare_args(
     if training_args.do_predict and (not training_args.predict_with_generate):
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type == "full":
-        raise ValueError("Quantization is incompatible with the full-parameter tuning.")
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
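The guard is tightened from rejecting only full-parameter tuning to rejecting every non-LoRA method. A small sketch of which combinations now raise, assuming the finetuning types are "lora", "freeze" and "full":

    quantization_bit = 4  # quantization enabled
    for finetuning_type in ("lora", "freeze", "full"):
        if quantization_bit is not None and finetuning_type != "lora":
            print(finetuning_type, "-> rejected")   # freeze, full
        else:
            print(finetuning_type, "-> allowed")    # lora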
@@ -358,7 +356,14 @@ def prepare_data(
             )
         elif dataset_attr.load_from == "file":
             data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
 
             extension = dataset_attr.file_name.split(".")[-1]
+            if extension == "csv":
+                file_type = "csv"
+            elif extension == "json" or extension == "jsonl":
+                file_type = "json"
+            else:
+                file_type = "text"
+
             if dataset_attr.file_sha1 is not None:
                 checksum(data_file, dataset_attr.file_sha1)
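This is the substantive fix in this hunk: the old one-liner removed in the next hunk, `extension if extension in ["csv", "json"] else "text"`, sent `.jsonl` files to the plain-text loader, while the explicit mapping routes them to the `json` builder of `datasets.load_dataset`. A standalone sketch with a hypothetical file name:

    file_name = "my_dataset.jsonl"       # hypothetical dataset file
    extension = file_name.split(".")[-1]

    old_file_type = extension if extension in ["csv", "json"] else "text"

    if extension == "csv":
        new_file_type = "csv"
    elif extension == "json" or extension == "jsonl":
        new_file_type = "json"
    else:
        new_file_type = "text"

    print(old_file_type, new_file_type)  # text json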
@@ -366,7 +371,7 @@ def prepare_data(
                 logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
 
             raw_datasets = load_dataset(
-                extension if extension in ["csv", "json"] else "text",
+                file_type,
                 data_files=data_file,
                 cache_dir=model_args.cache_dir,
                 use_auth_token=True if model_args.use_auth_token else None
@@ -87,6 +87,8 @@ class ModelArguments:
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
 
+        if self.quantization_bit is not None:
+            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
 
 @dataclass
 class DataTrainingArguments:
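Validating the bit width in `ModelArguments.__post_init__` lets the `else: raise NotImplementedError` branch in `load_pretrained` (removed above) go away: unsupported values are now rejected as soon as the arguments are parsed. A trimmed-down sketch of the pattern, not the project's full `ModelArguments`:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class QuantArgs:  # hypothetical stand-in with a single field
        quantization_bit: Optional[int] = None

        def __post_init__(self):
            if self.quantization_bit is not None:
                assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."

    QuantArgs(quantization_bit=4)     # accepted
    # QuantArgs(quantization_bit=2)   # would raise AssertionError at construction time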
@@ -125,7 +127,7 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
     )
-    num_beams: Optional[int] = field(
+    eval_num_beams: Optional[int] = field(
         default=None,
         metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
     )
@@ -164,7 +166,7 @@ class DataTrainingArguments:
                 dataset_attr = DatasetAttr(
                     "file",
                     file_name=dataset_info[name]["file_name"],
-                    file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
+                    file_sha1=dataset_info[name].get("file_sha1", None)
                 )
 
             if "columns" in dataset_info[name]:
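A pure simplification: `dict.get` with a default is equivalent to the `key in dict` conditional it replaces. For a hypothetical `dataset_info.json` entry without a checksum:

    entry = {"file_name": "my_dataset.json"}   # no "file_sha1" key

    old = entry["file_sha1"] if "file_sha1" in entry else None
    new = entry.get("file_sha1", None)
    assert old is None and new is None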
@@ -262,7 +264,7 @@ class GeneratingArguments:
         default=50,
         metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
     )
-    infer_num_beams: Optional[int] = field(
+    num_beams: Optional[int] = field(
         default=1,
         metadata={"help": "Number of beams for beam search. 1 means no beam search."}
     )
@@ -276,7 +278,4 @@ class GeneratingArguments:
     )
 
     def to_dict(self) -> Dict[str, Any]:
-        data_dict = asdict(self)
-        num_beams = data_dict.pop("infer_num_beams")
-        data_dict["num_beams"] = num_beams
-        return data_dict
+        return asdict(self)
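With the inference-side field renamed from `infer_num_beams` back to plain `num_beams` (previous hunk), `to_dict` no longer has to rewrite the key, so it collapses to a single `asdict` call whose result can be fed to `model.generate(**...)` as-is. A trimmed-down sketch, not the project's full `GeneratingArguments`:

    from dataclasses import dataclass, asdict, field
    from typing import Optional

    @dataclass
    class GenArgs:  # hypothetical two-field stand-in
        num_beams: Optional[int] = field(default=1)
        top_k: Optional[int] = field(default=50)

    print(asdict(GenArgs()))  # {'num_beams': 1, 'top_k': 50}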
@@ -81,7 +81,7 @@ class PeftTrainer(Seq2SeqTrainer):
     def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
-        if os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
+        if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
             logger.warning("Previous log file in this folder will be deleted.")
             os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
 
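In a distributed run every process constructs the trainer, so without the `is_world_process_zero()` guard (an existing `transformers.Trainer` method) each rank would try to warn about and delete the same `trainer_log.jsonl`. A rough standalone analogue of the rank-0 guard, using a hypothetical path:

    import os

    def is_main_process() -> bool:
        # torch.distributed launchers export RANK; single-process runs default to 0
        return int(os.environ.get("RANK", "0")) == 0

    log_path = os.path.join("output", "trainer_log.jsonl")  # hypothetical output_dir
    if is_main_process() and os.path.exists(log_path):
        os.remove(log_path)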
@@ -87,7 +87,7 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
         "do_sample": True,
         "top_p": top_p,
         "temperature": temperature,
-        "num_beams": generating_args.infer_num_beams,
+        "num_beams": generating_args.num_beams,
         "max_length": max_length,
         "repetition_penalty": generating_args.repetition_penalty,
         "logits_processor": get_logits_processor(),