fix generating args

This commit is contained in:
hiyouga 2023-06-13 01:33:56 +08:00
parent cec6524d6b
commit 531a3764d9
5 changed files with 20 additions and 16 deletions

View File

@ -30,8 +30,8 @@ def main():
# Override the decoding parameters of Seq2SeqTrainer # Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length if \ training_args.generation_max_length = training_args.generation_max_length if \
training_args.generation_max_length is not None else data_args.max_target_length training_args.generation_max_length is not None else data_args.max_target_length
training_args.generation_num_beams = data_args.num_beams if \ training_args.generation_num_beams = data_args.eval_num_beams if \
data_args.num_beams is not None else training_args.generation_num_beams data_args.eval_num_beams is not None else training_args.generation_num_beams
# Split the dataset # Split the dataset
if training_args.do_train: if training_args.do_train:

View File

@ -195,8 +195,6 @@ def load_pretrained(
bnb_4bit_use_double_quant=model_args.double_quantization, bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type bnb_4bit_quant_type=model_args.quantization_type
) )
else:
raise NotImplementedError
is_mergeable = False is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@ -273,8 +271,8 @@ def prepare_args(
if training_args.do_predict and (not training_args.predict_with_generate): if training_args.do_predict and (not training_args.predict_with_generate):
raise ValueError("Please enable `predict_with_generate` to save model predictions.") raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type == "full": if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is incompatible with the full-parameter tuning.") raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.quantization_bit is not None and (not training_args.do_train): if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@ -358,7 +356,14 @@ def prepare_data(
) )
elif dataset_attr.load_from == "file": elif dataset_attr.load_from == "file":
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
extension = dataset_attr.file_name.split(".")[-1] extension = dataset_attr.file_name.split(".")[-1]
if extension == "csv":
file_type = "csv"
elif extension == "json" or extension == "jsonl":
file_type = "json"
else:
file_type = "text"
if dataset_attr.file_sha1 is not None: if dataset_attr.file_sha1 is not None:
checksum(data_file, dataset_attr.file_sha1) checksum(data_file, dataset_attr.file_sha1)
@ -366,7 +371,7 @@ def prepare_data(
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
raw_datasets = load_dataset( raw_datasets = load_dataset(
extension if extension in ["csv", "json"] else "text", file_type,
data_files=data_file, data_files=data_file,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None use_auth_token=True if model_args.use_auth_token else None

View File

@ -87,6 +87,8 @@ class ModelArguments:
if self.checkpoint_dir is not None: # support merging multiple lora weights if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
@dataclass @dataclass
class DataTrainingArguments: class DataTrainingArguments:
@ -125,7 +127,7 @@ class DataTrainingArguments:
default=None, default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
) )
num_beams: Optional[int] = field( eval_num_beams: Optional[int] = field(
default=None, default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"} metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
) )
@ -164,7 +166,7 @@ class DataTrainingArguments:
dataset_attr = DatasetAttr( dataset_attr = DatasetAttr(
"file", "file",
file_name=dataset_info[name]["file_name"], file_name=dataset_info[name]["file_name"],
file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None file_sha1=dataset_info[name].get("file_sha1", None)
) )
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
@ -262,7 +264,7 @@ class GeneratingArguments:
default=50, default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
) )
infer_num_beams: Optional[int] = field( num_beams: Optional[int] = field(
default=1, default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."} metadata={"help": "Number of beams for beam search. 1 means no beam search."}
) )
@ -276,7 +278,4 @@ class GeneratingArguments:
) )
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
data_dict = asdict(self) return asdict(self)
num_beams = data_dict.pop("infer_num_beams")
data_dict["num_beams"] = num_beams
return data_dict

View File

@ -81,7 +81,7 @@ class PeftTrainer(Seq2SeqTrainer):
def __init__(self, finetuning_args: FinetuningArguments, **kwargs): def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
if os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")): if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
logger.warning("Previous log file in this folder will be deleted.") logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl")) os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))

View File

@ -87,7 +87,7 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
"do_sample": True, "do_sample": True,
"top_p": top_p, "top_p": top_p,
"temperature": temperature, "temperature": temperature,
"num_beams": generating_args.infer_num_beams, "num_beams": generating_args.num_beams,
"max_length": max_length, "max_length": max_length,
"repetition_penalty": generating_args.repetition_penalty, "repetition_penalty": generating_args.repetition_penalty,
"logits_processor": get_logits_processor(), "logits_processor": get_logits_processor(),