support resize embeddings #1786

hiyouga 2023-12-11 17:50:02 +08:00
parent 9ce1b0e2f2
commit 64744dde89
2 changed files with 16 additions and 1 deletion


@@ -28,7 +28,7 @@ from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.hparams import FinetuningArguments
from llmtuner.model.adapter import init_adapter
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer
@@ -185,6 +185,9 @@ def load_model_and_tokenizer(
        **config_kwargs
    )

    # Resize token embeddings
    resize_embedding_layer(model, tokenizer)

    # Disable custom generate method (for Qwen and Baichuan2)
    if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)
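Note on the call site: the resize runs right after the base model is instantiated and before the adapter / value-head logic later in this function, so tokens added on the tokenizer side are reflected in the embedding matrix first. A minimal sketch of the mismatch this guards against (checkpoint name and added token are placeholders, not part of this commit):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any causal LM without a pad token shows the same effect.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

tokenizer.add_special_tokens({"pad_token": "<pad>"})  # vocabulary grows by one token

print(len(tokenizer))                               # 50258
print(model.get_input_embeddings().weight.size(0))  # 50257 -> embeddings must be resized
```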


@@ -11,6 +11,7 @@ from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel
    from transformers.tokenization_utils import PreTrainedTokenizer
    from llmtuner.hparams import DataArguments
@@ -181,3 +182,14 @@ def prepare_model_for_training(
output_layer.register_forward_hook(fp32_forward_post_hook)
return model
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
    r"""
    Resize token embeddings.
    """
    old_vocab_size = model.get_input_embeddings().weight.size(0)
    new_vocab_size = len(tokenizer)
    if new_vocab_size != old_vocab_size:
        model.resize_token_embeddings(new_vocab_size, pad_to_multiple_of=64)
        logger.info("Resized embedding tokens from {} to {}.".format(old_vocab_size, new_vocab_size))
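A usage sketch of the new helper (checkpoint name and the extra token are illustrative only). `pad_to_multiple_of=64` rounds the resized embedding matrix up to the next multiple of 64 rows, a shape that is generally friendlier to GPU GEMM kernels than an odd vocabulary size:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmtuner.model.utils import resize_embedding_layer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

resize_embedding_layer(model, tokenizer)  # sizes already match -> no-op

tokenizer.add_special_tokens({"additional_special_tokens": ["<extra_0>"]})
resize_embedding_layer(model, tokenizer)  # embedding matrix grows from 50257 to 50304 rows (padded to a multiple of 64)

assert model.get_input_embeddings().weight.size(0) % 64 == 0
```

Because the helper is a no-op when the sizes already match, it is safe to call unconditionally during model loading.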