release v0.6.2

This commit is contained in:
hiyouga 2024-04-11 20:08:51 +08:00
parent 51d0a1a19e
commit 9d4c949461
10 changed files with 37 additions and 26 deletions

View File

@ -305,7 +305,7 @@ cd LLaMA-Factory
pip install -e .[metrics]
```
Extra dependencies available: deepspeed, metrics, unsloth, galore, vllm, bitsandbytes, gptq, awq, aqlm, qwen, quality
Extra dependencies available: deepspeed, metrics, unsloth, galore, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
<details><summary>For Windows users</summary>

View File

@ -305,7 +305,7 @@ cd LLaMA-Factory
pip install -e .[metrics]
```
可选的额外依赖项deepspeed、metrics、unsloth、galore、vllm、bitsandbytes、gptq、awq、aqlm、qwen、quality
可选的额外依赖项deepspeed、metrics、unsloth、galore、vllm、bitsandbytes、gptq、awq、aqlm、qwen、modelscope、quality
<details><summary>Windows 用户指南</summary>

View File

@ -30,6 +30,7 @@ extra_require = {
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"qwen": ["tiktoken", "transformers_stream_generator"],
"modelscope": ["modelscope"],
"quality": ["ruff"],
}

View File

@ -7,5 +7,5 @@ from .train import export_model, run_exp
from .webui import create_ui, create_web_demo
__version__ = "0.6.2.dev0"
__version__ = "0.6.2"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@ -6,7 +6,7 @@ from datasets import load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import is_path_available
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
@ -81,7 +81,9 @@ def load_single_dataset(
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
).to_hf_dataset()
)
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
@ -125,7 +127,7 @@ def get_dataset(
# Load tokenized dataset
if data_args.tokenized_path is not None:
if not is_path_available(data_args.tokenized_path):
if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))

View File

@ -53,22 +53,35 @@ class DatasetAttr:
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if data_args.dataset is not None:
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
if data_args.dataset is not None:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
else:
dataset_names = []
if data_args.dataset_dir == "ONLINE":
dataset_info = None
else:
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if len(dataset_names) != 0:
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
dataset_info = None
if data_args.interleave_probs is not None:
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if dataset_info is None:
load_from = "ms_hub" if use_modelscope() else "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr)
continue
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

View File

@ -193,16 +193,11 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
return torch.float32
def is_path_available(path: os.PathLike) -> bool:
def has_tokenized_data(path: os.PathLike) -> bool:
r"""
Checks if the path is empty or not exist.
Checks if the path has a tokenized dataset.
"""
if not os.path.exists(path):
return True
elif os.path.isdir(path) and not os.listdir(path):
return True
else:
return False
return os.path.isdir(path) and len(os.listdir(path)) > 0
def torch_gc() -> None:

View File

@ -139,7 +139,6 @@ def init_adapter(
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
@ -151,6 +150,7 @@ def init_adapter(
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
modules_to_save=finetuning_args.additional_target,
use_dora=finetuning_args.use_dora,
**peft_kwargs,
)

View File

@ -18,7 +18,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({dataset_dir, dataset})

View File

@ -23,7 +23,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
dataset = gr.Dropdown(multiselect=True, scale=4, allow_custom_value=True)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({training_stage, dataset_dir, dataset})