update parser

hiyouga 2024-03-10 13:35:20 +08:00
parent 8664262cde
commit be99799413
2 changed files with 29 additions and 19 deletions

View File

@@ -1,6 +1,6 @@
 import json
 import datasets
-from typing import Any, Dict, List
+from typing import Any, Dict, Generator, List, Tuple

 _DESCRIPTION = "An example of dataset."
@@ -40,7 +40,7 @@ class ExampleDataset(datasets.GeneratorBasedBuilder):
             )
         ]

-    def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]:
+    def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
         example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
         for key, example in enumerate(example_dataset):
             yield key, example
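The new annotation matches what the method actually does: it yields (key, example) pairs one at a time rather than returning a mapping. A minimal standalone sketch of the same pattern (illustrative names, not the repository's file):

from typing import Any, Dict, Generator, List, Tuple

def generate_examples(records: List[Dict[str, Any]]) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
    # Lazily yield (key, example) pairs; this is exactly what the
    # Generator[Tuple[int, Dict[str, Any]], None, None] annotation describes.
    for key, example in enumerate(records):
        yield key, example

# Consuming the generator reproduces (index, record) pairs.
for key, example in generate_examples([{"text": "a"}, {"text": "b"}]):
    print(key, example)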

View File

@@ -73,19 +73,6 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments")
     if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
         raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

-    if model_args.infer_backend == "vllm":
-        if finetuning_args.stage != "sft":
-            raise ValueError("vLLM engine only supports auto-regressive models.")
-
-        if model_args.adapter_name_or_path is not None:
-            raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
-
-        if model_args.quantization_bit is not None:
-            raise ValueError("vLLM engine does not support quantization.")
-
-        if model_args.rope_scaling is not None:
-            raise ValueError("vLLM engine does not support RoPE scaling.")
-

 def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
@@ -154,6 +141,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         if training_args.fp16 or training_args.bf16:
             raise ValueError("Turn off mixed precision training when using `pure_bf16`.")

+    if model_args.infer_backend == "vllm":
+        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+
     _verify_model_args(model_args, finetuning_args)

     if (
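Since training never runs on an inference backend, get_train_args now rejects vLLM outright instead of routing through the shared checks. A hedged sketch of the guard in isolation (the function name and the non-vLLM value are assumptions, not from this diff):

def reject_vllm_for_training(infer_backend: str) -> None:
    # Fail fast: the vLLM backend is only meaningful when serving a model.
    if infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")

reject_vllm_for_training("huggingface")  # assumed default backend; passes silently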
@@ -252,12 +242,27 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

     _set_transformers_logging()
-    _verify_model_args(model_args, finetuning_args)
-    model_args.device_map = "auto"

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

+    if model_args.infer_backend == "vllm":
+        if finetuning_args.stage != "sft":
+            raise ValueError("vLLM engine only supports auto-regressive models.")
+
+        if model_args.adapter_name_or_path is not None:
+            raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
+
+        if model_args.quantization_bit is not None:
+            raise ValueError("vLLM engine does not support quantization.")
+
+        if model_args.rope_scaling is not None:
+            raise ValueError("vLLM engine does not support RoPE scaling.")
+
+    _verify_model_args(model_args, finetuning_args)
+
+    model_args.device_map = "auto"
+
     return model_args, data_args, finetuning_args, generating_args
@@ -265,12 +270,17 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

     _set_transformers_logging()
-    _verify_model_args(model_args, finetuning_args)
-    model_args.device_map = "auto"

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

+    if model_args.infer_backend == "vllm":
+        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+
+    _verify_model_args(model_args, finetuning_args)
+
+    model_args.device_map = "auto"
+
     transformers.set_seed(eval_args.seed)

     return model_args, data_args, eval_args, finetuning_args
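Taken together, the commit moves every vLLM compatibility rule out of _verify_model_args and into get_infer_args, and reorders validation so the template check runs before the model-argument checks. A self-contained sketch of that new ordering (the dataclass and function names are stand-ins, not the project's types):

from dataclasses import dataclass
from typing import Optional

@dataclass
class StubModelArgs:
    infer_backend: str = "huggingface"
    adapter_name_or_path: Optional[str] = None
    quantization_bit: Optional[int] = None
    rope_scaling: Optional[str] = None

def validate_infer_args(model_args: StubModelArgs, template: Optional[str], stage: str) -> None:
    # 1. Template first: a missing template is now reported before any model issue.
    if template is None:
        raise ValueError("Please specify which `template` to use.")
    # 2. vLLM compatibility, formerly inside _verify_model_args.
    if model_args.infer_backend == "vllm":
        if stage != "sft":
            raise ValueError("vLLM engine only supports auto-regressive models.")
        if model_args.adapter_name_or_path is not None:
            raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
        if model_args.quantization_bit is not None:
            raise ValueError("vLLM engine does not support quantization.")
        if model_args.rope_scaling is not None:
            raise ValueError("vLLM engine does not support RoPE scaling.")
    # 3. The shared model-argument checks and device_map = "auto" follow here.

validate_infer_args(StubModelArgs(), template="default", stage="sft")  # passes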