add logits processor

parent ed6161fa6a
commit dca27b4412
@@ -4,7 +4,7 @@
 
 
 import torch
-from utils import ModelArguments, FinetuningArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
 from transformers import HfArgumentParser
 
 
@@ -35,12 +35,12 @@ def main():
         input_ids = input_ids.to(model.device)
         gen_kwargs = {
             "do_sample": True,
-            "top_p": 0.9,
-            "top_k": 40,
-            "temperature": 0.7,
+            "top_p": 0.7,
+            "temperature": 0.95,
             "num_beams": 1,
             "max_new_tokens": 256,
-            "repetition_penalty": 1.5
+            "repetition_penalty": 1.5,
+            "logits_processor": get_logits_processor()
         }
         with torch.no_grad():
             generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
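Note on the new kwarg: passing "logits_processor" lets model.generate() rewrite the raw scores at every decoding step. The body of get_logits_processor() is not part of this diff; the sketch below is only an illustration, assuming the stock InfNanRemoveLogitsProcessor from transformers is enough (the repository may install a custom processor instead, e.g. to reset invalid scores).

    from transformers import LogitsProcessorList, InfNanRemoveLogitsProcessor

    def get_logits_processor() -> LogitsProcessorList:
        # Build the list of processors that generate() applies to the logits
        # at each step; here we only drop inf/NaN scores to keep sampling stable.
        processors = LogitsProcessorList()
        processors.append(InfNanRemoveLogitsProcessor())
        return processors

With this in place, model.generate(input_ids=input_ids, **gen_kwargs) runs the processor over the logits before each token is sampled.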
@@ -46,7 +46,8 @@ from .other import (
 )
 
 check_min_version("4.29.1")
-require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")
+require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
+require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
 require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
 require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
 
@@ -84,8 +85,7 @@ def init_adapter(
             param.data = param.data.to(torch.float32)
 
     if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
-        if len(model_args.checkpoint_dir) > 1:
-            logger.warning("Only LoRA tuning accepts multiple checkpoints.")
+        assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
         load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
 
     if finetuning_args.finetuning_type == "lora":
@@ -154,8 +154,7 @@ def load_pretrained(
     config_kwargs = {}
     if model_args.quantization_bit is not None:
         assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
-        require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
-
+        require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1")
         from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
         cuda = get_cuda_lib_handle()
         cc = get_compute_capability(cuda)
@@ -179,7 +178,6 @@ def load_pretrained(
 
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        model = model.half() # cast all params to float16 for inference
 
     if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
@@ -49,6 +49,14 @@ class ModelArguments:
         default=None,
         metadata={"help": "The number of bits to quantize the model."}
     )
+    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
+        default="nf4",
+        metadata={"help": "Quantization data type to use."}
+    )
+    double_quantization: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Compress the quantization statistics through double quantization."}
+    )
     checkpoint_dir: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
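The two new fields only declare options; this diff does not show where they are read. As a rough sketch, assuming they feed the usual bitsandbytes path via BitsAndBytesConfig (the 4-bit options require transformers >= 4.30 and bitsandbytes >= 0.39.0), the mapping might look like this:

    import torch
    from transformers import BitsAndBytesConfig

    def build_quantization_config(quantization_bit, quantization_type="nf4", double_quantization=True):
        # Hypothetical translation of the ModelArguments fields above into a
        # quantization_config for from_pretrained(); not taken from this diff.
        if quantization_bit == 8:
            return BitsAndBytesConfig(load_in_8bit=True)
        if quantization_bit == 4:
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=quantization_type,          # "fp4" or "nf4"
                bnb_4bit_use_double_quant=double_quantization,  # compress the quantization statistics
                bnb_4bit_compute_dtype=torch.float16,
            )
        return None

A config built this way would typically be passed as config_kwargs["quantization_config"] together with device_map="auto" when loading the model.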
@@ -206,14 +214,14 @@ class FinetuningArguments:
         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
 
     def save_to_json(self, json_path: str):
-        """Save the content of this instance in JSON format inside `json_path`."""
+        """Saves the content of this instance in JSON format inside `json_path`."""
         json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
         with open(json_path, "w", encoding="utf-8") as f:
             f.write(json_string)
 
     @classmethod
     def load_from_json(cls, json_path: str):
-        """Create an instance from the content of `json_path`."""
+        """Creates an instance from the content of `json_path`."""
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
         return cls(**json.loads(text))
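The docstring changes above are cosmetic. For completeness, the two methods round-trip the dataclass through JSON; a minimal usage sketch (only finetuning_type is visible in this hunk, the remaining fields are assumed to have defaults):

    args = FinetuningArguments(finetuning_type="lora")
    args.save_to_json("finetuning_args.json")    # sorted, indented JSON plus a trailing newline
    restored = FinetuningArguments.load_from_json("finetuning_args.json")
    assert restored.finetuning_type == "lora"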
@@ -9,7 +9,6 @@ from transformers.modeling_utils import PreTrainedModel
 
 from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
 from trl.core import LengthSampler
-from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
 
 from .peft_trainer import PeftTrainer, LogCallback
 
@@ -7,7 +7,7 @@ import torch
 import mdtex2html
 import gradio as gr
 
-from utils import ModelArguments, FinetuningArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
 from transformers import HfArgumentParser
 from transformers.utils.versions import require_version
 
|
@ -93,7 +93,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"num_beams": 1,
|
"num_beams": 1,
|
||||||
"max_length": max_length,
|
"max_length": max_length,
|
||||||
"repetition_penalty": 1.0
|
"repetition_penalty": 1.5,
|
||||||
|
"logits_processor": get_logits_processor()
|
||||||
}
|
}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
|
generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
|
||||||
|
|