From dca27b4412e8e41cadcd623582222e1c216db78b Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sat, 3 Jun 2023 16:34:54 +0800 Subject: [PATCH] add logits processor --- src/cli_demo.py | 10 +++++----- src/utils/common.py | 10 ++++------ src/utils/config.py | 12 ++++++++++-- src/utils/ppo.py | 1 - src/web_demo.py | 5 +++-- 5 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/cli_demo.py b/src/cli_demo.py index 3da88aa6..44262000 100644 --- a/src/cli_demo.py +++ b/src/cli_demo.py @@ -4,7 +4,7 @@ import torch -from utils import ModelArguments, FinetuningArguments, load_pretrained +from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor from transformers import HfArgumentParser @@ -35,12 +35,12 @@ def main(): input_ids = input_ids.to(model.device) gen_kwargs = { "do_sample": True, - "top_p": 0.9, - "top_k": 40, - "temperature": 0.7, + "top_p": 0.7, + "temperature": 0.95, "num_beams": 1, "max_new_tokens": 256, - "repetition_penalty": 1.5 + "repetition_penalty": 1.5, + "logits_processor": get_logits_processor() } with torch.no_grad(): generation_output = model.generate(input_ids=input_ids, **gen_kwargs) diff --git a/src/utils/common.py b/src/utils/common.py index 9009906b..65ae2931 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -46,7 +46,8 @@ from .other import ( ) check_min_version("4.29.1") -require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0") +require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") +require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0") require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") @@ -84,8 +85,7 @@ def init_adapter( param.data = param.data.to(torch.float32) if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None: - if len(model_args.checkpoint_dir) > 1: - logger.warning("Only LoRA tuning accepts multiple checkpoints.") + assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods if finetuning_args.finetuning_type == "lora": @@ -154,8 +154,7 @@ def load_pretrained( config_kwargs = {} if model_args.quantization_bit is not None: assert model_args.quantization_bit == 8, "We only accept 8-bit quantization." - - require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.") + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1") from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible cuda = get_cuda_lib_handle() cc = get_compute_capability(cuda) @@ -179,7 +178,6 @@ def load_pretrained( if not is_trainable: model.requires_grad_(False) # fix all model params - model = model.half() # cast all params to float16 for inference if stage == "rm" or stage == "ppo": # add value head model = AutoModelForCausalLMWithValueHead.from_pretrained(model) diff --git a/src/utils/config.py b/src/utils/config.py index 22b66fb9..98d5907e 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -49,6 +49,14 @@ class ModelArguments: default=None, metadata={"help": "The number of bits to quantize the model."} ) + quantization_type: Optional[Literal["fp4", "nf4"]] = field( + default="nf4", + metadata={"help": "Quantization data type to use."} + ) + double_quantization: Optional[bool] = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) checkpoint_dir: Optional[str] = field( default=None, metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."} @@ -206,14 +214,14 @@ class FinetuningArguments: assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method." def save_to_json(self, json_path: str): - """Save the content of this instance in JSON format inside `json_path`.""" + """Saves the content of this instance in JSON format inside `json_path`.""" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" with open(json_path, "w", encoding="utf-8") as f: f.write(json_string) @classmethod def load_from_json(cls, json_path: str): - """Create an instance from the content of `json_path`.""" + """Creates an instance from the content of `json_path`.""" with open(json_path, "r", encoding="utf-8") as f: text = f.read() return cls(**json.loads(text)) diff --git a/src/utils/ppo.py b/src/utils/ppo.py index 5e754e48..701d4b4b 100644 --- a/src/utils/ppo.py +++ b/src/utils/ppo.py @@ -9,7 +9,6 @@ from transformers.modeling_utils import PreTrainedModel from trl import PPOTrainer, AutoModelForCausalLMWithValueHead from trl.core import LengthSampler -from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits from .peft_trainer import PeftTrainer, LogCallback diff --git a/src/web_demo.py b/src/web_demo.py index 83ccdf9a..7445d0e1 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -7,7 +7,7 @@ import torch import mdtex2html import gradio as gr -from utils import ModelArguments, FinetuningArguments, load_pretrained +from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor from transformers import HfArgumentParser from transformers.utils.versions import require_version @@ -93,7 +93,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history): "temperature": temperature, "num_beams": 1, "max_length": max_length, - "repetition_penalty": 1.0 + "repetition_penalty": 1.5, + "logits_processor": get_logits_processor() } with torch.no_grad(): generation_output = model.generate(input_ids=input_ids, **gen_kwargs)