rename files

This commit is contained in:
hiyouga 2024-06-07 00:09:06 +08:00
parent 45d8be8f93
commit 74f96efef9
43 changed files with 53 additions and 53 deletions

View File

@ -6,7 +6,7 @@ from ..extras.logging import get_logger
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response

View File

@ -1,16 +1,16 @@
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"KTODataCollatorWithPadding",
"PairwiseDataCollatorWithPadding",
"get_dataset",
"Template",
"get_template_and_fix_tokenizer",
"templates",
"Role",
"split_dataset",
"get_dataset",
"TEMPLATES",
"Template",
"get_template_and_fix_tokenizer",
]

View File

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from ..extras.logging import get_logger
from .utils import Role
from .data_utils import Role
if TYPE_CHECKING:

View File

@ -10,10 +10,10 @@ from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
from .utils import merge_dataset
if TYPE_CHECKING:

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:

View File

@ -1,8 +1,8 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.logging import get_logger
from ..utils import Role
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
from ..data_utils import Role
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:

View File

@ -2,8 +2,8 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.logging import get_logger
from .data_utils import Role, infer_max_len
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role, infer_max_len
if TYPE_CHECKING:
@ -196,7 +196,7 @@ class Llama2Template(Template):
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
templates: Dict[str, Template] = {}
TEMPLATES: Dict[str, Template] = {}
def _register_template(
@ -248,7 +248,7 @@ def _register_template(
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
templates[name] = template_class(
TEMPLATES[name] = template_class(
format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter,
format_system=format_system or default_user_formatter,
@ -348,9 +348,9 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None,
) -> Template:
if name is None:
template = templates["empty"] # placeholder
template = TEMPLATES["empty"] # placeholder
else:
template = templates.get(name, None)
template = TEMPLATES.get(name, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))

View File

@ -1,12 +1,12 @@
from .loader import load_config, load_model, load_tokenizer
from .utils.misc import find_all_linear_modules
from .utils.valuehead import load_valuehead_params
from .model_utils.misc import find_all_linear_modules
from .model_utils.valuehead import load_valuehead_params
__all__ = [
"load_config",
"load_model",
"load_tokenizer",
"load_valuehead_params",
"find_all_linear_modules",
"load_valuehead_params",
]

View File

@ -7,9 +7,9 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger
from .utils.misc import find_all_linear_modules, find_expanded_modules
from .utils.quantization import QuantizationMethod
from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
if TYPE_CHECKING:

View File

@ -6,11 +6,11 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, try_download_model_from_ms
from .adapter import init_adapter
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils.misc import register_autoclass
from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .utils.unsloth import load_unsloth_pretrained_model
from .utils.valuehead import load_valuehead_params
if TYPE_CHECKING:

View File

@ -10,15 +10,15 @@ from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger
from ..extras.misc import infer_optim_dtype
from .utils.attention import configure_attn_implementation, print_attn_implementation
from .utils.checkpointing import prepare_model_for_training
from .utils.embedding import resize_embedding_layer
from .utils.longlora import configure_longlora
from .utils.moe import add_z3_leaf_module, configure_moe
from .utils.quantization import configure_quantization
from .utils.rope import configure_rope
from .utils.valuehead import prepare_valuehead_model
from .utils.visual import autocast_projector_dtype, configure_visual_model
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training
from .model_utils.embedding import resize_embedding_layer
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
from .model_utils.visual import autocast_projector_dtype, configure_visual_model
if TYPE_CHECKING:

View File

@ -10,7 +10,7 @@ from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
if TYPE_CHECKING:

View File

@ -7,7 +7,7 @@ from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..utils import create_modelcard_and_push, create_ref_model
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .trainer import CustomDPOTrainer

View File

@ -9,7 +9,7 @@ from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
if TYPE_CHECKING:

View File

@ -5,7 +5,7 @@ from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..utils import create_modelcard_and_push, create_ref_model
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .trainer import CustomKTOTrainer

View File

@ -19,8 +19,8 @@ from trl.models.utils import unwrap_model_for_generation
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..utils import create_custom_optimzer, create_custom_scheduler
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
if TYPE_CHECKING:

View File

@ -9,7 +9,7 @@ from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..utils import create_ref_model, create_reward_model
from ..trainer_utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer

View File

@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, Optional
from transformers import Trainer
from ...extras.logging import get_logger
from ..utils import create_custom_optimzer, create_custom_scheduler
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:

View File

@ -8,7 +8,7 @@ from transformers import DataCollatorForLanguageModeling
from ...data import get_dataset, split_dataset
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..utils import create_modelcard_and_push
from ..trainer_utils import create_modelcard_and_push
from .trainer import CustomTrainer

View File

@ -7,7 +7,7 @@ import torch
from transformers import Trainer
from ...extras.logging import get_logger
from ..utils import create_custom_optimzer, create_custom_scheduler
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:

View File

@ -7,7 +7,7 @@ from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..utils import create_modelcard_and_push
from ..trainer_utils import create_modelcard_and_push
from .metric import compute_accuracy
from .trainer import PairwiseTrainer

View File

@ -9,7 +9,7 @@ from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ..utils import create_custom_optimzer, create_custom_scheduler
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:

View File

@ -9,7 +9,7 @@ from ...extras.constants import IGNORE_INDEX
from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..utils import create_modelcard_and_push
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeMetrics
from .trainer import CustomSeq2SeqTrainer

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Dict
from ...data import templates
from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_info, list_checkpoints, save_config
@ -30,7 +30,7 @@ def create_top() -> Dict[str, "Component"]:
with gr.Accordion(open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
visual_inputs = gr.Checkbox(scale=1)