add logits processor

This commit is contained in:
hiyouga 2023-06-03 16:34:54 +08:00
parent ed6161fa6a
commit dca27b4412
5 changed files with 22 additions and 16 deletions

View File

@ -4,7 +4,7 @@
import torch import torch
from utils import ModelArguments, FinetuningArguments, load_pretrained from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
from transformers import HfArgumentParser from transformers import HfArgumentParser
@ -35,12 +35,12 @@ def main():
input_ids = input_ids.to(model.device) input_ids = input_ids.to(model.device)
gen_kwargs = { gen_kwargs = {
"do_sample": True, "do_sample": True,
"top_p": 0.9, "top_p": 0.7,
"top_k": 40, "temperature": 0.95,
"temperature": 0.7,
"num_beams": 1, "num_beams": 1,
"max_new_tokens": 256, "max_new_tokens": 256,
"repetition_penalty": 1.5 "repetition_penalty": 1.5,
"logits_processor": get_logits_processor()
} }
with torch.no_grad(): with torch.no_grad():
generation_output = model.generate(input_ids=input_ids, **gen_kwargs) generation_output = model.generate(input_ids=input_ids, **gen_kwargs)

View File

@ -46,7 +46,8 @@ from .other import (
) )
check_min_version("4.29.1") check_min_version("4.29.1")
require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
@ -84,8 +85,7 @@ def init_adapter(
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None: if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
if len(model_args.checkpoint_dir) > 1: assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
logger.warning("Only LoRA tuning accepts multiple checkpoints.")
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
if finetuning_args.finetuning_type == "lora": if finetuning_args.finetuning_type == "lora":
@ -154,8 +154,7 @@ def load_pretrained(
config_kwargs = {} config_kwargs = {}
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
assert model_args.quantization_bit == 8, "We only accept 8-bit quantization." assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1")
require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
cuda = get_cuda_lib_handle() cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda) cc = get_compute_capability(cuda)
@ -179,7 +178,6 @@ def load_pretrained(
if not is_trainable: if not is_trainable:
model.requires_grad_(False) # fix all model params model.requires_grad_(False) # fix all model params
model = model.half() # cast all params to float16 for inference
if stage == "rm" or stage == "ppo": # add value head if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model) model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

View File

@ -49,6 +49,14 @@ class ModelArguments:
default=None, default=None,
metadata={"help": "The number of bits to quantize the model."} metadata={"help": "The number of bits to quantize the model."}
) )
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
default="nf4",
metadata={"help": "Quantization data type to use."}
)
double_quantization: Optional[bool] = field(
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."}
)
checkpoint_dir: Optional[str] = field( checkpoint_dir: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."} metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
@ -206,14 +214,14 @@ class FinetuningArguments:
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method." assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
def save_to_json(self, json_path: str): def save_to_json(self, json_path: str):
"""Save the content of this instance in JSON format inside `json_path`.""" """Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f: with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string) f.write(json_string)
@classmethod @classmethod
def load_from_json(cls, json_path: str): def load_from_json(cls, json_path: str):
"""Create an instance from the content of `json_path`.""" """Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f: with open(json_path, "r", encoding="utf-8") as f:
text = f.read() text = f.read()
return cls(**json.loads(text)) return cls(**json.loads(text))

View File

@ -9,7 +9,6 @@ from transformers.modeling_utils import PreTrainedModel
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler from trl.core import LengthSampler
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
from .peft_trainer import PeftTrainer, LogCallback from .peft_trainer import PeftTrainer, LogCallback

View File

@ -7,7 +7,7 @@ import torch
import mdtex2html import mdtex2html
import gradio as gr import gradio as gr
from utils import ModelArguments, FinetuningArguments, load_pretrained from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
from transformers import HfArgumentParser from transformers import HfArgumentParser
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
@ -93,7 +93,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
"temperature": temperature, "temperature": temperature,
"num_beams": 1, "num_beams": 1,
"max_length": max_length, "max_length": max_length,
"repetition_penalty": 1.0 "repetition_penalty": 1.5,
"logits_processor": get_logits_processor()
} }
with torch.no_grad(): with torch.no_grad():
generation_output = model.generate(input_ids=input_ids, **gen_kwargs) generation_output = model.generate(input_ids=input_ids, **gen_kwargs)