diff --git a/src/train_sft.py b/src/train_sft.py
index 971fbb19..f82d254e 100644
--- a/src/train_sft.py
+++ b/src/train_sft.py
@@ -30,8 +30,8 @@ def main():
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length if \
                 training_args.generation_max_length is not None else data_args.max_target_length
-    training_args.generation_num_beams = data_args.num_beams if \
-                data_args.num_beams is not None else training_args.generation_num_beams
+    training_args.generation_num_beams = data_args.eval_num_beams if \
+                data_args.eval_num_beams is not None else training_args.generation_num_beams
 
     # Split the dataset
     if training_args.do_train:
diff --git a/src/utils/common.py b/src/utils/common.py
index fb515b9a..b17e35da 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -195,8 +195,6 @@ def load_pretrained(
                 bnb_4bit_use_double_quant=model_args.double_quantization,
                 bnb_4bit_quant_type=model_args.quantization_type
             )
-        else:
-            raise NotImplementedError
         is_mergeable = False
         config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@@ -273,8 +271,8 @@ def prepare_args(
     if training_args.do_predict and (not training_args.predict_with_generate):
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type == "full":
-        raise ValueError("Quantization is incompatible with the full-parameter tuning.")
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -358,7 +356,14 @@ def prepare_data(
         )
     elif dataset_attr.load_from == "file":
         data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
+        extension = dataset_attr.file_name.split(".")[-1]
+        if extension == "csv":
+            file_type = "csv"
+        elif extension == "json" or extension == "jsonl":
+            file_type = "json"
+        else:
+            file_type = "text"
 
         if dataset_attr.file_sha1 is not None:
             checksum(data_file, dataset_attr.file_sha1)
@@ -366,7 +371,7 @@
             logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
 
         raw_datasets = load_dataset(
-            extension if extension in ["csv", "json"] else "text",
+            file_type,
            data_files=data_file,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None
diff --git a/src/utils/config.py b/src/utils/config.py
index c0f89217..5b919ec6 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -87,6 +87,8 @@ class ModelArguments:
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
 
+        if self.quantization_bit is not None:
+            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
 
 @dataclass
 class DataTrainingArguments:
@@ -125,7 +127,7 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
     )
-    num_beams: Optional[int] = field(
+    eval_num_beams: Optional[int] = field(
         default=None,
         metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
     )
@@ -164,7 +166,7 @@ class DataTrainingArguments:
             dataset_attr = DatasetAttr(
                 "file",
                 file_name=dataset_info[name]["file_name"],
-                file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
+                file_sha1=dataset_info[name].get("file_sha1", None)
             )
 
             if "columns" in dataset_info[name]:
@@ -262,7 +264,7 @@ class GeneratingArguments:
         default=50,
         metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
     )
-    infer_num_beams: Optional[int] = field(
+    num_beams: Optional[int] = field(
         default=1,
         metadata={"help": "Number of beams for beam search. 1 means no beam search."}
     )
@@ -276,7 +278,4 @@ class GeneratingArguments:
     )
 
     def to_dict(self) -> Dict[str, Any]:
-        data_dict = asdict(self)
-        num_beams = data_dict.pop("infer_num_beams")
-        data_dict["num_beams"] = num_beams
-        return data_dict
+        return asdict(self)
diff --git a/src/utils/peft_trainer.py b/src/utils/peft_trainer.py
index 160a1425..045278d3 100644
--- a/src/utils/peft_trainer.py
+++ b/src/utils/peft_trainer.py
@@ -81,7 +81,7 @@ class PeftTrainer(Seq2SeqTrainer):
 
     def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
-        if os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
+        if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
             logger.warning("Previous log file in this folder will be deleted.")
             os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
diff --git a/src/web_demo.py b/src/web_demo.py
index 35ffcf56..2cceddd3 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -87,7 +87,7 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
         "do_sample": True,
         "top_p": top_p,
         "temperature": temperature,
-        "num_beams": generating_args.infer_num_beams,
+        "num_beams": generating_args.num_beams,
         "max_length": max_length,
         "repetition_penalty": generating_args.repetition_penalty,
         "logits_processor": get_logits_processor(),
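Note on the dataset-loading change in `prepare_data`: a `.jsonl` file previously fell through to the `text` builder because only the literal extensions `csv` and `json` were recognized; the new `file_type` mapping routes both `json` and `jsonl` through the `json` builder of `datasets.load_dataset`. A minimal sketch of that mapping follows; the helper name `resolve_file_type` and the file names are hypothetical and only for illustration, not part of the commit.

```python
# Sketch only (not part of the commit): the extension -> builder mapping that
# prepare_data() now applies before calling datasets.load_dataset().
# resolve_file_type() and the file names below are hypothetical.

def resolve_file_type(file_name: str) -> str:
    extension = file_name.split(".")[-1]
    if extension == "csv":
        return "csv"                      # CSV builder
    elif extension == "json" or extension == "jsonl":
        return "json"                     # the "json" builder also reads JSON Lines
    else:
        return "text"                     # fallback: one record per line

assert resolve_file_type("data.json") == "json"
assert resolve_file_type("data.jsonl") == "json"   # previously fell back to "text"
assert resolve_file_type("notes.txt") == "text"

# With a real local file this becomes, e.g.:
# raw_datasets = load_dataset(resolve_file_type(file_name), data_files=data_file)
```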
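Note on the beam-search renames: evaluation beams now live in `DataTrainingArguments.eval_num_beams` (consumed by `train_sft.py` to override `training_args.generation_num_beams`), while `GeneratingArguments.num_beams` matches the `model.generate` keyword directly, which is why `to_dict()` collapses to a bare `asdict(self)`. A sketch with a stripped-down stand-in dataclass (not the real class; only fields shown in the diff) illustrates the simplification.

```python
# Sketch only: a stand-in for GeneratingArguments showing why to_dict() no longer
# needs to pop/rename keys once the field itself is called num_beams.
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

@dataclass
class GeneratingArgumentsSketch:          # stand-in, not the real class
    top_k: Optional[int] = 50
    num_beams: Optional[int] = 1          # was infer_num_beams before this change

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)               # keys already match model.generate kwargs

gen_kwargs = GeneratingArgumentsSketch().to_dict()
assert gen_kwargs == {"top_k": 50, "num_beams": 1}
# so the dict can be splatted straight into model.generate(**gen_kwargs)
```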