forked from p04798526/LLaMA-Factory-Mirror
fix generating args
This commit is contained in:
parent
cec6524d6b
commit
531a3764d9
|
@ -30,8 +30,8 @@ def main():
|
|||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
training_args.generation_max_length = training_args.generation_max_length if \
|
||||
training_args.generation_max_length is not None else data_args.max_target_length
|
||||
training_args.generation_num_beams = data_args.num_beams if \
|
||||
data_args.num_beams is not None else training_args.generation_num_beams
|
||||
training_args.generation_num_beams = data_args.eval_num_beams if \
|
||||
data_args.eval_num_beams is not None else training_args.generation_num_beams
|
||||
|
||||
# Split the dataset
|
||||
if training_args.do_train:
|
||||
|
|
|
@ -195,8 +195,6 @@ def load_pretrained(
|
|||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
is_mergeable = False
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
@ -273,8 +271,8 @@ def prepare_args(
|
|||
if training_args.do_predict and (not training_args.predict_with_generate):
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type == "full":
|
||||
raise ValueError("Quantization is incompatible with the full-parameter tuning.")
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if model_args.quantization_bit is not None and (not training_args.do_train):
|
||||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
@ -358,7 +356,14 @@ def prepare_data(
|
|||
)
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
|
||||
|
||||
extension = dataset_attr.file_name.split(".")[-1]
|
||||
if extension == "csv":
|
||||
file_type = "csv"
|
||||
elif extension == "json" or extension == "jsonl":
|
||||
file_type = "json"
|
||||
else:
|
||||
file_type = "text"
|
||||
|
||||
if dataset_attr.file_sha1 is not None:
|
||||
checksum(data_file, dataset_attr.file_sha1)
|
||||
|
@ -366,7 +371,7 @@ def prepare_data(
|
|||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
||||
|
||||
raw_datasets = load_dataset(
|
||||
extension if extension in ["csv", "json"] else "text",
|
||||
file_type,
|
||||
data_files=data_file,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None
|
||||
|
|
|
@ -87,6 +87,8 @@ class ModelArguments:
|
|||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
|
||||
if self.quantization_bit is not None:
|
||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
|
@ -125,7 +127,7 @@ class DataTrainingArguments:
|
|||
default=None,
|
||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
eval_num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
|
||||
)
|
||||
|
@ -164,7 +166,7 @@ class DataTrainingArguments:
|
|||
dataset_attr = DatasetAttr(
|
||||
"file",
|
||||
file_name=dataset_info[name]["file_name"],
|
||||
file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
|
||||
file_sha1=dataset_info[name].get("file_sha1", None)
|
||||
)
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
|
@ -262,7 +264,7 @@ class GeneratingArguments:
|
|||
default=50,
|
||||
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
|
||||
)
|
||||
infer_num_beams: Optional[int] = field(
|
||||
num_beams: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
||||
)
|
||||
|
@ -276,7 +278,4 @@ class GeneratingArguments:
|
|||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
data_dict = asdict(self)
|
||||
num_beams = data_dict.pop("infer_num_beams")
|
||||
data_dict["num_beams"] = num_beams
|
||||
return data_dict
|
||||
return asdict(self)
|
||||
|
|
|
@ -81,7 +81,7 @@ class PeftTrainer(Seq2SeqTrainer):
|
|||
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
if os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
|
||||
if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
|
||||
logger.warning("Previous log file in this folder will be deleted.")
|
||||
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
|
|||
"do_sample": True,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"num_beams": generating_args.infer_num_beams,
|
||||
"num_beams": generating_args.num_beams,
|
||||
"max_length": max_length,
|
||||
"repetition_penalty": generating_args.repetition_penalty,
|
||||
"logits_processor": get_logits_processor(),
|
||||
|
|
Loading…
Reference in New Issue