fix generating args

This commit is contained in:
hiyouga 2023-06-13 01:33:56 +08:00
parent cec6524d6b
commit 531a3764d9
5 changed files with 20 additions and 16 deletions

View File

@@ -30,8 +30,8 @@ def main():
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length if \
training_args.generation_max_length is not None else data_args.max_target_length
-training_args.generation_num_beams = data_args.num_beams if \
-data_args.num_beams is not None else training_args.generation_num_beams
+training_args.generation_num_beams = data_args.eval_num_beams if \
+data_args.eval_num_beams is not None else training_args.generation_num_beams
# Split the dataset
if training_args.do_train:
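For clarity, the precedence after this change is: an explicitly set `generation_max_length` wins over `data_args.max_target_length`, and an explicitly set `eval_num_beams` overrides the trainer's `generation_num_beams`. Below is a minimal, self-contained sketch of that resolution using hypothetical stand-in dataclasses, not the project's real argument classes.

# Sketch of the override precedence; TrainArgs/DataArgs are stand-ins,
# not the real Seq2SeqTrainingArguments/DataTrainingArguments.
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainArgs:
    generation_max_length: Optional[int] = None
    generation_num_beams: int = 1

@dataclass
class DataArgs:
    max_target_length: int = 512
    eval_num_beams: Optional[int] = None

def resolve_decoding(training_args: TrainArgs, data_args: DataArgs) -> TrainArgs:
    # max_length: an explicit trainer value wins, otherwise fall back to the data side
    training_args.generation_max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.max_target_length
    )
    # num_beams: an explicit eval_num_beams overrides the trainer default
    training_args.generation_num_beams = (
        data_args.eval_num_beams
        if data_args.eval_num_beams is not None
        else training_args.generation_num_beams
    )
    return training_args

print(resolve_decoding(TrainArgs(), DataArgs(eval_num_beams=4)))
# TrainArgs(generation_max_length=512, generation_num_beams=4)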

View File

@@ -195,8 +195,6 @@ def load_pretrained(
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
-else:
-raise NotImplementedError
is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@@ -273,8 +271,8 @@ def prepare_args(
if training_args.do_predict and (not training_args.predict_with_generate):
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
-if model_args.quantization_bit is not None and finetuning_args.finetuning_type == "full":
-raise ValueError("Quantization is incompatible with the full-parameter tuning.")
+if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -358,7 +356,14 @@ def prepare_data(
)
elif dataset_attr.load_from == "file":
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
extension = dataset_attr.file_name.split(".")[-1]
+if extension == "csv":
+file_type = "csv"
+elif extension == "json" or extension == "jsonl":
+file_type = "json"
+else:
+file_type = "text"
if dataset_attr.file_sha1 is not None:
checksum(data_file, dataset_attr.file_sha1)
@@ -366,7 +371,7 @@ def prepare_data(
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
raw_datasets = load_dataset(
-extension if extension in ["csv", "json"] else "text",
+file_type,
data_files=data_file,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None
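The two hunks above map the file extension to a datasets builder name ("csv", "json", or "text") and then hand that name to load_dataset. A small self-contained sketch of the same mapping, assuming the datasets library; the data file path is a placeholder.

# Sketch of the extension -> builder mapping, then loading; the path is a placeholder.
from datasets import load_dataset

def builder_for(file_name: str) -> str:
    extension = file_name.split(".")[-1]
    if extension == "csv":
        return "csv"
    elif extension in ("json", "jsonl"):  # the json builder also reads JSON Lines
        return "json"
    else:
        return "text"                     # fallback: one example per line

print(builder_for("alpaca_data.jsonl"))   # json

data_file = "data/example_dataset.jsonl"  # placeholder path to a local file
raw_datasets = load_dataset(builder_for(data_file), data_files=data_file)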

View File

@@ -87,6 +87,8 @@ class ModelArguments:
if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
+if self.quantization_bit is not None:
+assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
@dataclass
class DataTrainingArguments:
@@ -125,7 +127,7 @@ class DataTrainingArguments:
default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
)
-num_beams: Optional[int] = field(
+eval_num_beams: Optional[int] = field(
default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
)
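Because HfArgumentParser derives command-line flags from dataclass field names, the renamed field is now supplied as --eval_num_beams. A small sketch with a stand-in dataclass, not the project's real DataTrainingArguments:

# Sketch of how the renamed field surfaces on the CLI; DataArgs is a stand-in.
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class DataArgs:
    eval_num_beams: Optional[int] = field(
        default=None,
        metadata={"help": "Number of beams to use for evaluation."}
    )

parser = HfArgumentParser(DataArgs)
(data_args,) = parser.parse_args_into_dataclasses(args=["--eval_num_beams", "4"])
print(data_args.eval_num_beams)  # 4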
@@ -164,7 +166,7 @@ class DataTrainingArguments:
dataset_attr = DatasetAttr(
"file",
file_name=dataset_info[name]["file_name"],
-file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
+file_sha1=dataset_info[name].get("file_sha1", None)
)
if "columns" in dataset_info[name]:
@@ -262,7 +264,7 @@ class GeneratingArguments:
default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
)
-infer_num_beams: Optional[int] = field(
+num_beams: Optional[int] = field(
default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
)
@@ -276,7 +278,4 @@ class GeneratingArguments:
)
def to_dict(self) -> Dict[str, Any]:
-data_dict = asdict(self)
-num_beams = data_dict.pop("infer_num_beams")
-data_dict["num_beams"] = num_beams
-return data_dict
+return asdict(self)
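With the generation field now literally named num_beams, asdict already yields keys that line up with model.generate, which is why the pop-and-rename above can be dropped. A stand-alone sketch with a stand-in dataclass:

# Sketch of the simplified to_dict; GenArgs is a stand-in for GeneratingArguments.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional

@dataclass
class GenArgs:
    do_sample: Optional[bool] = field(default=True)
    top_k: Optional[int] = field(default=50)
    num_beams: Optional[int] = field(default=1)

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

gen_kwargs = GenArgs().to_dict()
print(gen_kwargs)  # {'do_sample': True, 'top_k': 50, 'num_beams': 1}
# the dict can be passed straight through, e.g. model.generate(**gen_kwargs)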

View File

@@ -81,7 +81,7 @@ class PeftTrainer(Seq2SeqTrainer):
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
-if os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
+if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
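The added is_world_process_zero() check matters under distributed launches: only the main process should warn about and delete the stale trainer_log.jsonl, instead of every rank racing to remove the same file. A rough stand-alone sketch of the guard, using an environment-variable check in place of the Trainer method; the output directory is a placeholder.

# Sketch of the rank-zero guard; a RANK-based check stands in for is_world_process_zero().
import os

def is_main_process() -> bool:
    # torchrun/DDP sets RANK per process; single-process runs default to 0
    return int(os.environ.get("RANK", "0")) == 0

output_dir = "output"  # placeholder
log_path = os.path.join(output_dir, "trainer_log.jsonl")
if is_main_process() and os.path.exists(log_path):
    print("Previous log file in this folder will be deleted.")
    os.remove(log_path)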

View File

@@ -87,7 +87,7 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
"do_sample": True,
"top_p": top_p,
"temperature": temperature,
"num_beams": generating_args.infer_num_beams,
"num_beams": generating_args.num_beams,
"max_length": max_length,
"repetition_penalty": generating_args.repetition_penalty,
"logits_processor": get_logits_processor(),