diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index fa2989ef..49b86345 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -156,5 +156,6 @@ class DataArguments:
                 dataset_attr.history = dataset_info[name]["columns"].get("history", None)
 
             dataset_attr.ranking = dataset_info[name].get("ranking", False)
+            dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
             dataset_attr.system_prompt = prompt_list[i]
             self.dataset_list.append(dataset_attr)
diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py
index f8b935fb..c04a5c36 100644
--- a/src/llmtuner/hparams/generating_args.py
+++ b/src/llmtuner/hparams/generating_args.py
@@ -28,7 +28,7 @@ class GeneratingArguments:
         metadata={"help": "Number of beams for beam search. 1 means no beam search."}
     )
     max_length: Optional[int] = field(
-        default=None,
+        default=512,
         metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
     )
     max_new_tokens: Optional[int] = field(
@@ -46,6 +46,8 @@ class GeneratingArguments:
 
     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
-        if args.get("max_new_tokens", None):
+        if args.get("max_new_tokens", -1) > 0:
             args.pop("max_length", None)
+        else:
+            args.pop("max_new_tokens", None)
         return args
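
The `to_dict` hunk makes the two length limits mutually exclusive: a positive `max_new_tokens` wins and `max_length` is dropped, otherwise `max_new_tokens` is dropped so the (now 512 by default) `max_length` applies. Below is a minimal standalone sketch of that precedence rule; the `GenArgs` class name and the `max_new_tokens=512` default are assumptions for illustration, not the library's actual definitions.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class GenArgs:
    # Defaults chosen to mirror the patched GeneratingArguments (assumed here).
    max_length: Optional[int] = 512
    max_new_tokens: Optional[int] = 512

    def to_dict(self) -> Dict[str, Any]:
        args = asdict(self)
        # A positive max_new_tokens takes precedence: drop max_length.
        # Otherwise drop max_new_tokens so max_length is the effective limit.
        if args.get("max_new_tokens", -1) > 0:
            args.pop("max_length", None)
        else:
            args.pop("max_new_tokens", None)
        return args


print(GenArgs().to_dict())                  # {'max_new_tokens': 512}
print(GenArgs(max_new_tokens=0).to_dict())  # {'max_length': 512}
```

With this change, exactly one of the two keys is ever forwarded to generation, avoiding the ambiguity of passing both.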