add logits processor

hiyouga 2023-06-03 16:34:54 +08:00
parent ed6161fa6a
commit dca27b4412
5 changed files with 22 additions and 16 deletions

View File

@@ -4,7 +4,7 @@
import torch
-from utils import ModelArguments, FinetuningArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
from transformers import HfArgumentParser
@@ -35,12 +35,12 @@ def main():
input_ids = input_ids.to(model.device)
gen_kwargs = {
"do_sample": True,
"top_p": 0.9,
"top_k": 40,
"temperature": 0.7,
"top_p": 0.7,
"temperature": 0.95,
"num_beams": 1,
"max_new_tokens": 256,
"repetition_penalty": 1.5
"repetition_penalty": 1.5,
"logits_processor": get_logits_processor()
}
with torch.no_grad():
generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
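
The get_logits_processor() helper is imported from utils, but its body is outside this diff. Below is a minimal sketch of what such a helper could look like, assuming it merely wraps transformers' LogitsProcessorList around the built-in InfNanRemoveLogitsProcessor so that NaN/Inf scores cannot break sampling; the actual implementation in utils may differ.

# Hypothetical sketch -- the real utils.get_logits_processor may be implemented differently.
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

def get_logits_processor() -> LogitsProcessorList:
    # Replace NaN/Inf logits before sampling so torch.multinomial does not fail.
    logits_processor = LogitsProcessorList()
    logits_processor.append(InfNanRemoveLogitsProcessor())
    return logits_processor

Passed through gen_kwargs, the list is handed to model.generate(), which applies every processor in it to the logits at each decoding step.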

View File

@@ -46,7 +46,8 @@ from .other import (
)
check_min_version("4.29.1")
require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
@@ -84,8 +85,7 @@ def init_adapter(
param.data = param.data.to(torch.float32)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
-if len(model_args.checkpoint_dir) > 1:
-logger.warning("Only LoRA tuning accepts multiple checkpoints.")
+assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
if finetuning_args.finetuning_type == "lora":
@@ -154,8 +154,7 @@ def load_pretrained(
config_kwargs = {}
if model_args.quantization_bit is not None:
assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1")
from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
@@ -179,7 +178,6 @@ def load_pretrained(
if not is_trainable:
model.requires_grad_(False) # fix all model params
model = model.half() # cast all params to float16 for inference
if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
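
The bitsandbytes floor moves from 0.37.0 to 0.39.0, and the neighbouring lines probe the GPU (get_cuda_lib_handle, get_compute_capability) before 8-bit loading. The hunk is cut off before the probe is consumed; the following is a plausible continuation, assuming the compute capability is validated with is_cublasLt_compatible and the 8-bit load is driven by config_kwargs passed to from_pretrained. Neither step is shown in this commit, and the model path below is a placeholder.

# Sketch only: how the compute-capability probe could gate an 8-bit load.
from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
from transformers import AutoModelForCausalLM

cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
assert is_cublasLt_compatible(cc), "8-bit matmul needs a cuBLASLt-compatible GPU (compute capability >= 7.5)."

config_kwargs = {
    "load_in_8bit": True,  # matches the 8-bit-only assertion in this hunk
    "device_map": "auto",  # assumption: let accelerate place the quantized weights
}
model = AutoModelForCausalLM.from_pretrained("path/to/model", **config_kwargs)  # placeholder path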

View File

@@ -49,6 +49,14 @@ class ModelArguments:
default=None,
metadata={"help": "The number of bits to quantize the model."}
)
+quantization_type: Optional[Literal["fp4", "nf4"]] = field(
+default="nf4",
+metadata={"help": "Quantization data type to use."}
+)
+double_quantization: Optional[bool] = field(
+default=True,
+metadata={"help": "Compress the quantization statistics through double quantization."}
+)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
@@ -206,14 +214,14 @@ class FinetuningArguments:
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
def save_to_json(self, json_path: str):
"""Save the content of this instance in JSON format inside `json_path`."""
"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
"""Create an instance from the content of `json_path`."""
"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
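
The two new ModelArguments fields, quantization_type and double_quantization, are only declared in this commit; nothing here consumes them yet. The sketch below shows one way they could later map onto a QLoRA-style 4-bit configuration, assuming transformers>=4.30 and bitsandbytes>=0.39; this is an assumption about future usage, not code from this commit.

# Sketch only: possible mapping of the new arguments onto BitsAndBytesConfig.
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",            # <- ModelArguments.quantization_type ("fp4" or "nf4")
    bnb_4bit_use_double_quant=True,       # <- ModelArguments.double_quantization
    bnb_4bit_compute_dtype=torch.float16  # compute dtype for the dequantized matmuls
)

The nf4 default and the double-quantization flag mirror the QLoRA recipe, which fits the simultaneous bump to bitsandbytes>=0.39.0 elsewhere in this commit.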

View File

@@ -9,7 +9,6 @@ from transformers.modeling_utils import PreTrainedModel
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
from .peft_trainer import PeftTrainer, LogCallback

View File

@@ -7,7 +7,7 @@ import torch
import mdtex2html
import gradio as gr
-from utils import ModelArguments, FinetuningArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
from transformers import HfArgumentParser
from transformers.utils.versions import require_version
@@ -93,7 +93,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
"temperature": temperature,
"num_beams": 1,
"max_length": max_length,
"repetition_penalty": 1.0
"repetition_penalty": 1.5,
"logits_processor": get_logits_processor()
}
with torch.no_grad():
generation_output = model.generate(input_ids=input_ids, **gen_kwargs)