add logits processor
This commit is contained in:
parent
ed6161fa6a
commit
dca27b4412
|
@ -4,7 +4,7 @@
|
|||
|
||||
|
||||
import torch
|
||||
from utils import ModelArguments, FinetuningArguments, load_pretrained
|
||||
from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
|
@ -35,12 +35,12 @@ def main():
|
|||
input_ids = input_ids.to(model.device)
|
||||
gen_kwargs = {
|
||||
"do_sample": True,
|
||||
"top_p": 0.9,
|
||||
"top_k": 40,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.7,
|
||||
"temperature": 0.95,
|
||||
"num_beams": 1,
|
||||
"max_new_tokens": 256,
|
||||
"repetition_penalty": 1.5
|
||||
"repetition_penalty": 1.5,
|
||||
"logits_processor": get_logits_processor()
|
||||
}
|
||||
with torch.no_grad():
|
||||
generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
|
||||
|
|
|
@ -46,7 +46,8 @@ from .other import (
|
|||
)
|
||||
|
||||
check_min_version("4.29.1")
|
||||
require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
|
||||
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
|
||||
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
|
||||
|
||||
|
@ -84,8 +85,7 @@ def init_adapter(
|
|||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
|
||||
if len(model_args.checkpoint_dir) > 1:
|
||||
logger.warning("Only LoRA tuning accepts multiple checkpoints.")
|
||||
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
|
@ -154,8 +154,7 @@ def load_pretrained(
|
|||
config_kwargs = {}
|
||||
if model_args.quantization_bit is not None:
|
||||
assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
|
||||
|
||||
require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1")
|
||||
from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
|
||||
cuda = get_cuda_lib_handle()
|
||||
cc = get_compute_capability(cuda)
|
||||
|
@ -179,7 +178,6 @@ def load_pretrained(
|
|||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.half() # cast all params to float16 for inference
|
||||
|
||||
if stage == "rm" or stage == "ppo": # add value head
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
|
|
|
@ -49,6 +49,14 @@ class ModelArguments:
|
|||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model."}
|
||||
)
|
||||
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
|
||||
default="nf4",
|
||||
metadata={"help": "Quantization data type to use."}
|
||||
)
|
||||
double_quantization: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Compress the quantization statistics through double quantization."}
|
||||
)
|
||||
checkpoint_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
|
||||
|
@ -206,14 +214,14 @@ class FinetuningArguments:
|
|||
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
"""Save the content of this instance in JSON format inside `json_path`."""
|
||||
"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
f.write(json_string)
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_path: str):
|
||||
"""Create an instance from the content of `json_path`."""
|
||||
"""Creates an instance from the content of `json_path`."""
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
return cls(**json.loads(text))
|
||||
|
|
|
@ -9,7 +9,6 @@ from transformers.modeling_utils import PreTrainedModel
|
|||
|
||||
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
|
||||
from trl.core import LengthSampler
|
||||
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
|
||||
|
||||
from .peft_trainer import PeftTrainer, LogCallback
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
import mdtex2html
|
||||
import gradio as gr
|
||||
|
||||
from utils import ModelArguments, FinetuningArguments, load_pretrained
|
||||
from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
@ -93,7 +93,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
|
|||
"temperature": temperature,
|
||||
"num_beams": 1,
|
||||
"max_length": max_length,
|
||||
"repetition_penalty": 1.0
|
||||
"repetition_penalty": 1.5,
|
||||
"logits_processor": get_logits_processor()
|
||||
}
|
||||
with torch.no_grad():
|
||||
generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
|
||||
|
|
Loading…
Reference in New Issue