From ebf31b62eb1b75399cff7c7542c45ac72f6f41dd Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 6 Feb 2024 15:23:08 +0800
Subject: [PATCH] fix #2438

---
 src/llmtuner/chat/chat_model.py | 10 +++++++++-
 src/llmtuner/model/patcher.py   |  6 ++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/src/llmtuner/chat/chat_model.py b/src/llmtuner/chat/chat_model.py
index cbc831b2..f8a2b14f 100644
--- a/src/llmtuner/chat/chat_model.py
+++ b/src/llmtuner/chat/chat_model.py
@@ -94,6 +94,9 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> List[Response]:
+        if not self.can_generate:
+            raise ValueError("The current model does not support `chat`.")
+
         gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
         generate_output = self.model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
@@ -123,6 +126,9 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
+        if not self.can_generate:
+            raise ValueError("The current model does not support `stream_chat`.")
+
         gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
@@ -134,9 +140,11 @@ class ChatModel:
 
     @torch.inference_mode()
     def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
+        if self.can_generate:
+            raise ValueError("Cannot get scores using an auto-regressive model.")
+
         max_length = input_kwargs.pop("max_length", None)
         device = getattr(self.model.pretrained_model, "device", "cuda")
-
         inputs = self.tokenizer(
             batch_input,
             padding=True,
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 95b01f73..672d09d7 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 import torch
 from datasets import load_dataset
+from peft import PeftModel
 from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
@@ -307,7 +308,12 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
         if isinstance(self.pretrained_model, PreTrainedModel):
             return self.pretrained_model.get_input_embeddings()
 
+    def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
+        if isinstance(self.pretrained_model, PeftModel):
+            self.pretrained_model.create_or_update_model_card(output_dir)
+
     ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
     setattr(model, "_keys_to_ignore_on_save", ignore_modules)
     setattr(model, "tie_weights", MethodType(tie_weights, model))
     setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
+    setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
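
A minimal sketch of the guard pattern this diff introduces, using a stand-alone
class rather than llmtuner's actual ChatModel so the routing is easy to test in
isolation. The flag name `can_generate` and both error messages come from the
diff above; the class name, constructor, and return values are illustrative
assumptions, not the project's API.

    # Sketch only: mirrors the `can_generate` guards added in chat_model.py.
    # `can_generate` is assumed to be True for causal LMs and False for models
    # loaded with a value head for reward scoring.
    from typing import List


    class GuardedChatModel:
        def __init__(self, can_generate: bool) -> None:
            self.can_generate = can_generate

        def chat(self, prompt: str) -> str:
            # Generation path: only valid for generative models.
            if not self.can_generate:
                raise ValueError("The current model does not support `chat`.")
            return "reply to: " + prompt

        def get_scores(self, batch_input: List[str]) -> List[float]:
            # Scoring path: only valid for reward (non-generative) models.
            if self.can_generate:
                raise ValueError("Cannot get scores using an auto-regressive model.")
            return [0.0] * len(batch_input)


    # Usage: a generative model chats but cannot score; a reward model scores
    # but cannot chat. Calling the wrong method now fails fast with a clear
    # ValueError instead of producing undefined behavior downstream.
    GuardedChatModel(can_generate=True).chat("hello")
    GuardedChatModel(can_generate=False).get_scores(["hello"])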