fix #2438
parent 85622ae757
commit ebf31b62eb
@@ -94,6 +94,9 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> List[Response]:
+        if not self.can_generate:
+            raise ValueError("The current model does not support `chat`.")
+
         gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
         generate_output = self.model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
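With this guard, chat() fails fast when the loaded checkpoint is a reward model rather than a generative one, instead of crashing later inside generate(). A minimal usage sketch; the ChatModel constructor arguments and message format are assumptions, not shown in this diff:

    # Hypothetical usage; constructor arguments are assumed, not part of this commit.
    chat_model = ChatModel(dict(model_name_or_path="path/to/reward_model"))
    try:
        chat_model.chat(messages=[{"role": "user", "content": "Hello"}])
    except ValueError as err:
        print(err)  # The current model does not support `chat`.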
@@ -123,6 +126,9 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
+        if not self.can_generate:
+            raise ValueError("The current model does not support `stream_chat`.")
+
         gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
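stream_chat applies the same capability check and then yields decoded text chunks produced by the TextIteratorStreamer. A consumption sketch, assuming a generative model was loaded into the same chat_model as above:

    # Each yielded item is a decoded text fragment; print incrementally.
    for new_text in chat_model.stream_chat(messages=[{"role": "user", "content": "Hello"}]):
        print(new_text, end="", flush=True)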
@@ -134,9 +140,11 @@ class ChatModel:
     @torch.inference_mode()
     def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
+        if self.can_generate:
+            raise ValueError("Cannot get scores using an auto-regressive model.")
+
         max_length = input_kwargs.pop("max_length", None)
         device = getattr(self.model.pretrained_model, "device", "cuda")

         inputs = self.tokenizer(
             batch_input,
             padding=True,
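get_scores is the mirror image: it only makes sense for a value-head (reward) model, so it raises when can_generate is true. A scoring sketch, assuming a reward model is loaded and using an illustrative max_length:

    # Batch scoring; returns List[float], one score per input string.
    scores = chat_model.get_scores(["first text", "second text"], max_length=1024)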
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 import torch
 from datasets import load_dataset
+from peft import PeftModel
 from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
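The new PeftModel import exists to support the isinstance check added in the hunk below.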
@@ -307,7 +308,12 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
         if isinstance(self.pretrained_model, PreTrainedModel):
             return self.pretrained_model.get_input_embeddings()

+    def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
+        if isinstance(self.pretrained_model, PeftModel):
+            self.pretrained_model.create_or_update_model_card(output_dir)
+
     ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
     setattr(model, "_keys_to_ignore_on_save", ignore_modules)
     setattr(model, "tie_weights", MethodType(tie_weights, model))
     setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
+    setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
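Because AutoModelForCausalLMWithValueHead wraps the underlying model, trainer hooks such as the model-card writer must be forwarded to the wrapped PeftModel; the commit binds the forwarding function onto the instance with MethodType, mirroring tie_weights and get_input_embeddings above. A self-contained sketch of that binding pattern, using toy classes rather than the real trl/peft types:

    from types import MethodType

    class Wrapper:
        # Stand-in for AutoModelForCausalLMWithValueHead.
        def __init__(self, inner):
            self.pretrained_model = inner

    def create_or_update_model_card(self, output_dir: str) -> None:
        # Forward only when the wrapped object knows how to write a model card.
        if hasattr(self.pretrained_model, "create_or_update_model_card"):
            self.pretrained_model.create_or_update_model_card(output_dir)

    wrapper = Wrapper(inner=object())
    wrapper.create_or_update_model_card = MethodType(create_or_update_model_card, wrapper)
    wrapper.create_or_update_model_card("outputs/")  # safely no-ops for a plain object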