This commit is contained in:
hiyouga 2023-11-02 16:51:52 +08:00
parent 083787dbfe
commit dff128c7e3
2 changed files with 5 additions and 2 deletions

View File

@@ -156,5 +156,6 @@ class DataArguments:
dataset_attr.history = dataset_info[name]["columns"].get("history", None) dataset_attr.history = dataset_info[name]["columns"].get("history", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False) dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
dataset_attr.system_prompt = prompt_list[i] dataset_attr.system_prompt = prompt_list[i]
self.dataset_list.append(dataset_attr) self.dataset_list.append(dataset_attr)

View File

@@ -28,7 +28,7 @@ class GeneratingArguments:
metadata={"help": "Number of beams for beam search. 1 means no beam search."} metadata={"help": "Number of beams for beam search. 1 means no beam search."}
) )
max_length: Optional[int] = field( max_length: Optional[int] = field(
default=None, default=512,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
) )
max_new_tokens: Optional[int] = field( max_new_tokens: Optional[int] = field(
@@ -46,6 +46,8 @@ class GeneratingArguments:
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
args = asdict(self) args = asdict(self)
if args.get("max_new_tokens", None): if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None) args.pop("max_length", None)
else:
args.pop("max_new_tokens", None)
return args return args