fix #175
This commit is contained in:
parent 1e1358431d
commit 85c2210452
@@ -3,7 +3,7 @@ transformers>=4.29.1
 datasets>=2.12.0
 accelerate>=0.19.0
 peft>=0.3.0
-trl>=0.4.4
+trl==0.4.4
 sentencepiece
 jieba
 rouge-chinese
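Note: the trl requirement changes from a floor (>=0.4.4) to an exact pin (==0.4.4), presumably to guard against breaking API changes in later trl releases; the same pin is mirrored in the require_version check further down.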
@@ -1,3 +1,4 @@
+import torch
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
@@ -41,10 +42,10 @@ class ChatModel:
         gen_kwargs = self.generating_args.to_dict()
         gen_kwargs.update(dict(
             input_ids=inputs["input_ids"],
-            temperature=temperature if temperature else gen_kwargs["temperature"],
-            top_p=top_p if top_p else gen_kwargs["top_p"],
-            top_k=top_k if top_k else gen_kwargs["top_k"],
-            repetition_penalty=repetition_penalty if repetition_penalty else gen_kwargs["repetition_penalty"],
+            temperature=temperature or gen_kwargs["temperature"],
+            top_p=top_p or gen_kwargs["top_p"],
+            top_k=top_k or gen_kwargs["top_k"],
+            repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
             logits_processor=get_logits_processor()
         ))
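A brief aside on the "temperature if temperature else ..." to "temperature or ..." rewrite: the two forms are equivalent, and both share the caveat that a falsy override (e.g. temperature=0.0 or top_k=0) silently falls back to the default instead of being honored. A minimal sketch, with made-up default values:

# Sketch only, not part of the commit; defaults are illustrative.
defaults = {"temperature": 0.95, "top_p": 0.7}

def resolve(temperature=None, top_p=None):
    return {
        "temperature": temperature or defaults["temperature"],  # same as: temperature if temperature else defaults[...]
        "top_p": top_p or defaults["top_p"],
    }

print(resolve(temperature=0.5))   # {'temperature': 0.5, 'top_p': 0.7}
print(resolve())                  # falls back to both defaults
print(resolve(temperature=0.0))   # caveat: 0.0 is falsy, so it also falls back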
@@ -58,6 +59,7 @@ class ChatModel:

         return gen_kwargs, prompt_length

+    @torch.inference_mode()
     def chat(
         self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Tuple[str, Tuple[int, int]]:
@@ -68,6 +70,7 @@ class ChatModel:
         response_length = len(outputs)
         return response, (prompt_length, response_length)

+    @torch.inference_mode()
     def stream_chat(
         self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Generator[str, None, None]:
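Both chat and stream_chat are now decorated with @torch.inference_mode(), which disables autograd for the whole call, like torch.no_grad(), while also skipping view tracking and version-counter updates. A minimal standalone sketch of the decorator form (function and values are illustrative, not from the commit):

import torch

@torch.inference_mode()
def generate_step(x: torch.Tensor) -> torch.Tensor:
    # no autograd graph is recorded inside the decorated call
    return x * 2

y = generate_step(torch.ones(3, requires_grad=True))
print(y.requires_grad)  # False: the output is an inference tensor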
@@ -28,7 +28,7 @@ check_min_version("4.29.1")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
 require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
-require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
+require_version("trl==0.4.4", "To fix: pip install trl==0.4.4")


 def load_model_and_tokenizer(
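require_version here is the helper from transformers.utils.versions; it raises ImportError with the supplied hint when the installed package does not satisfy the spec. A small usage sketch:

from transformers.utils.versions import require_version

# raises ImportError (including the hint text) if the installed
# trl version is not exactly 0.4.4
require_version("trl==0.4.4", "To fix: pip install trl==0.4.4")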
@@ -153,7 +153,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             if self.control.should_training_stop:
                 break

-    @torch.no_grad()
+    @torch.inference_mode()
     def generate(
         self,
         inputs: Dict[str, torch.Tensor],
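Replacing @torch.no_grad() with @torch.inference_mode() keeps the no-gradient behavior but is slightly cheaper, since inference mode also skips tensor version counting and view tracking; the trade-off is that tensors created inside cannot later take part in autograd-recorded computation (see the sketch after the stream_chat hunk above).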
@@ -32,17 +32,40 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         Subclass and override to inject custom behavior.
         """
         prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-        if self.tokenizer.padding_side == "right": # pads the labels to the same length as the inputs
-            inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
-        else:
-            inputs["labels"] = torch.cat((torch.zeros_like(inputs["input_ids"])[:, label_len:], inputs["labels"]), dim=-1)
+        if prompt_len > label_len:
+            inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
+        if label_len > prompt_len:
+            inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
+
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = generated_tokens[:, prompt_len:] if generated_tokens is not None else None
+        generated_tokens = generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None

         return (loss, generated_tokens, labels)

+    def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Pads the tensor to the same length as the target tensor.
+
+        Should only be called when predict_with_generate=True.
+        """
+        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
+            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensors."
+            # if the PAD token is not defined, at least the EOS token has to be defined
+            pad_token_id = (
+                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+            )
+        else:
+            if self.model.config.pad_token_id is not None:
+                pad_token_id = self.model.config.pad_token_id
+            else:
+                raise ValueError("pad_token_id must be set in the model configuration in order to pad tensors.")
+
+        padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
+        padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor  # adopt left-padding
+        return padded_tensor
+
     def save_predictions(
         self,
         predict_results: PredictionOutput
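Why the slice index became max(prompt_len, label_len): input_ids and labels are now left-padded to a common length before prediction_step, so the generated continuation starts after that common length rather than after the raw prompt length. A standalone sketch of the left-padding helper's effect (tensor values are made up for illustration):

import torch

def pad_tensors_to_target_len(src: torch.Tensor, tgt: torch.Tensor, pad_token_id: int = 0) -> torch.Tensor:
    # same idea as Seq2SeqPeftTrainer._pad_tensors_to_target_len: fill a
    # tensor shaped like the target with the pad id, then right-align src
    padded = pad_token_id * torch.ones_like(tgt)
    padded[:, -src.shape[-1]:] = src  # left-padding keeps content at the right end
    return padded

input_ids = torch.tensor([[11, 12, 13, 14]])  # prompt_len = 4
labels = torch.tensor([[21, 22]])             # label_len = 2
print(pad_tensors_to_target_len(labels, input_ids))
# tensor([[ 0,  0, 21, 22]]) -> padded to length max(4, 2) = 4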