This commit is contained in:
hiyouga 2024-02-06 15:23:08 +08:00
parent 85622ae757
commit ebf31b62eb
2 changed files with 15 additions and 1 deletions

View File

@ -94,6 +94,9 @@ class ChatModel:
tools: Optional[str] = None,
**input_kwargs,
) -> List[Response]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
@ -123,6 +126,9 @@ class ChatModel:
tools: Optional[str] = None,
**input_kwargs,
) -> Generator[str, None, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
@ -134,9 +140,11 @@ class ChatModel:
@torch.inference_mode()
def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda")
inputs = self.tokenizer(
batch_input,
padding=True,

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
@ -307,7 +308,12 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
if isinstance(self.pretrained_model, PreTrainedModel):
return self.pretrained_model.get_input_embeddings()
def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
if isinstance(self.pretrained_model, PeftModel):
self.pretrained_model.create_or_update_model_card(output_dir)
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))