update parser
parent 8664262cde
commit be99799413
@@ -1,6 +1,6 @@
 import json
 import datasets
-from typing import Any, Dict, List
+from typing import Any, Dict, Generator, List, Tuple


 _DESCRIPTION = "An example of dataset."
@@ -40,7 +40,7 @@ class ExampleDataset(datasets.GeneratorBasedBuilder):
             )
         ]

-    def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]:
+    def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
         example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
         for key, example in enumerate(example_dataset):
             yield key, example
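The two hunks above fix the return-type annotation of the example dataset loader: _generate_examples yields (key, example) pairs rather than returning a dict, so Generator[Tuple[int, Dict[str, Any]], None, None] is the accurate type. A minimal standalone sketch of the annotated pattern (the file name and the with-block are illustrative additions, not part of the commit):

import json
from typing import Any, Dict, Generator, Tuple

def generate_examples(filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
    # Load the JSON list once, then yield numbered (key, example) pairs lazily.
    with open(filepath, "r", encoding="utf-8") as f:
        example_dataset = json.load(f)
    for key, example in enumerate(example_dataset):
        yield key, example

# Hypothetical usage with a local file:
# for key, example in generate_examples("example_dataset.json"):
#     print(key, example)

The with-block also closes the file handle explicitly, which the one-line open() in the diff leaves to the garbage collector. The remaining hunks below touch the argument parser module.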
@@ -73,19 +73,6 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

-    if model_args.infer_backend == "vllm":
-        if finetuning_args.stage != "sft":
-            raise ValueError("vLLM engine only supports auto-regressive models.")
-
-        if model_args.adapter_name_or_path is not None:
-            raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
-
-        if model_args.quantization_bit is not None:
-            raise ValueError("vLLM engine does not support quantization.")
-
-        if model_args.rope_scaling is not None:
-            raise ValueError("vLLM engine does not support RoPE scaling.")
-

 def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
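For context on the _parse_train_args pattern above: HfArgumentParser (from transformers) can populate a tuple of dataclasses from a single dict via parse_dict. A minimal sketch with a made-up dataclass standing in for the real _TRAIN_ARGS classes:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class DemoArguments:
    # Hypothetical stand-in for one of the real argument dataclasses.
    model_name_or_path: Optional[str] = field(default=None)
    infer_backend: str = field(default="huggingface")

parser = HfArgumentParser((DemoArguments,))
(demo_args,) = parser.parse_dict({"model_name_or_path": "my-model", "infer_backend": "vllm"})
print(demo_args.infer_backend)  # -> vllm

The real code passes _TRAIN_ARGS, a tuple of several argument dataclasses, and receives one parsed instance per class.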
@@ -154,6 +141,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         if training_args.fp16 or training_args.bf16:
             raise ValueError("Turn off mixed precision training when using `pure_bf16`.")

+    if model_args.infer_backend == "vllm":
+        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+
     _verify_model_args(model_args, finetuning_args)

     if (
@@ -252,12 +242,27 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

     _set_transformers_logging()
-    _verify_model_args(model_args, finetuning_args)
-    model_args.device_map = "auto"

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

+    if model_args.infer_backend == "vllm":
+        if finetuning_args.stage != "sft":
+            raise ValueError("vLLM engine only supports auto-regressive models.")
+
+        if model_args.adapter_name_or_path is not None:
+            raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
+
+        if model_args.quantization_bit is not None:
+            raise ValueError("vLLM engine does not support quantization.")
+
+        if model_args.rope_scaling is not None:
+            raise ValueError("vLLM engine does not support RoPE scaling.")
+
+    _verify_model_args(model_args, finetuning_args)
+
+    model_args.device_map = "auto"
+
     return model_args, data_args, finetuning_args, generating_args

@@ -265,12 +270,17 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

     _set_transformers_logging()
-    _verify_model_args(model_args, finetuning_args)
-    model_args.device_map = "auto"

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

+    if model_args.infer_backend == "vllm":
+        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+
+    _verify_model_args(model_args, finetuning_args)
+
+    model_args.device_map = "auto"
+
     transformers.set_seed(eval_args.seed)

     return model_args, data_args, eval_args, finetuning_args
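Net effect of the parser hunks: _verify_model_args no longer knows about vLLM at all; get_infer_args performs the full backend compatibility check before verifying model arguments and setting device_map, while get_train_args and get_eval_args reject the backend outright. A standalone mock of the relocated checks (MockArgs is a hypothetical reduction of the real ModelArguments/FinetuningArguments pair):

from dataclasses import dataclass
from typing import Optional

@dataclass
class MockArgs:
    # Hypothetical: collapses the relevant model/finetuning fields into one object.
    infer_backend: str = "huggingface"
    stage: str = "sft"
    adapter_name_or_path: Optional[str] = None
    quantization_bit: Optional[int] = None
    rope_scaling: Optional[str] = None

def check_vllm_compat(args: MockArgs) -> None:
    # Mirrors the checks this commit adds to get_infer_args.
    if args.infer_backend != "vllm":
        return
    if args.stage != "sft":
        raise ValueError("vLLM engine only supports auto-regressive models.")
    if args.adapter_name_or_path is not None:
        raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
    if args.quantization_bit is not None:
        raise ValueError("vLLM engine does not support quantization.")
    if args.rope_scaling is not None:
        raise ValueError("vLLM engine does not support RoPE scaling.")

check_vllm_compat(MockArgs(infer_backend="vllm"))  # passes silently
# check_vllm_compat(MockArgs(infer_backend="vllm", quantization_bit=4))  # would raise ValueError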