forked from p04798526/LLaMA-Factory-Mirror
refactor pissa, improve llamaboard
parent ef38daa0a4
commit 8baf3b22b0
@@ -1,4 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,15 +17,11 @@
 
 import gc
 import os
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING, Tuple
 
 import torch
-from peft import PeftModel
-from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 from transformers.utils import (
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    is_safetensors_available,
     is_torch_bf16_gpu_available,
     is_torch_cuda_available,
     is_torch_mps_available,
@@ -31,15 +30,9 @@ from transformers.utils import (
 )
 from transformers.utils.versions import require_version
 
-from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from .logging import get_logger
 
 
-if is_safetensors_available():
-    from safetensors import safe_open
-    from safetensors.torch import save_file
-
-
 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
 try:
     _is_bf16_available = is_torch_bf16_gpu_available()
@@ -48,8 +41,6 @@ except Exception:
 
 
 if TYPE_CHECKING:
-    from trl import AutoModelForCausalLMWithValueHead
-
     from ..hparams import ModelArguments
 
 
@@ -99,7 +90,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
         if num_params == 0 and hasattr(param, "ds_numel"):
             num_params = param.ds_numel
 
-        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
+        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
         if param.__class__.__name__ == "Params4bit":
             if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
                 num_bytes = param.quant_storage.itemsize
@@ -117,51 +108,6 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
 
 
-def fix_valuehead_checkpoint(
-    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
-) -> None:
-    r"""
-    The model is already unwrapped.
-
-    There are three cases:
-    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
-    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
-    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
-
-    We assume `stage3_gather_16bit_weights_on_model_save=true`.
-    """
-    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
-        return
-
-    if safe_serialization:
-        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
-        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
-            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
-    else:
-        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
-
-    decoder_state_dict = {}
-    v_head_state_dict = {}
-    for name, param in state_dict.items():
-        if name.startswith("v_head."):
-            v_head_state_dict[name] = param
-        else:
-            decoder_state_dict[name.replace("pretrained_model.", "")] = param
-
-    os.remove(path_to_checkpoint)
-    model.pretrained_model.save_pretrained(
-        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
-    )
-
-    if safe_serialization:
-        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
-    else:
-        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
-
-    logger.info("Value head model saved at: {}".format(output_dir))
-
-
 def get_current_device() -> torch.device:
     r"""
     Gets the current available device.
@@ -201,7 +147,7 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor
 
 
-def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
+def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
    r"""
    Infers the optimal dtype according to the model_dtype and device compatibility.
    """
@@ -220,7 +166,7 @@ def is_gpu_or_npu_available() -> bool:
     return is_torch_npu_available() or is_torch_cuda_available()
 
 
-def has_tokenized_data(path: os.PathLike) -> bool:
+def has_tokenized_data(path: "os.PathLike") -> bool:
     r"""
     Checks if the path has a tokenized dataset.
     """
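The count_parameters change above drops the hard-coded "multiply by 2" for bitsandbytes Params4bit tensors and multiplies by the quantization storage dtype's itemsize instead. A minimal sketch of the resulting arithmetic, using invented element counts rather than a real quantized model:

```python
# Hypothetical illustration of the 4-bit parameter-count correction above.
# `numel` and `quant_storage_itemsize` are made-up example values, not read from a real model.

def effective_param_count(numel: int, quant_storage_itemsize: int) -> int:
    # Each stored element holds `itemsize` bytes and every byte packs two 4-bit weights,
    # so the true parameter count is numel * 2 * itemsize.
    num_bytes = quant_storage_itemsize
    return numel * 2 * num_bytes

# Params4bit stored as uint8 (itemsize=1): 1_000 elements -> 2_000 parameters.
print(effective_param_count(1_000, 1))
# The same tensor stored in float32 blocks (itemsize=4) counts 8_000 parameters,
# which the old fixed "* 2" under-reported.
print(effective_param_count(1_000, 4))
```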
@@ -379,10 +379,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
             raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
 
-        if self.pissa_convert and self.finetuning_type != "lora":
-            raise ValueError("`pissa_convert` is only valid for LoRA training.")
+        if self.pissa_init and self.finetuning_type != "lora":
+            raise ValueError("`pissa_init` is only valid for LoRA training.")
 
-        if self.pissa_convert and (self.stage in ["rm", "ppo", "kto"] or self.use_ref_model):
+        if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")
 
         if self.train_mm_proj_only and self.finetuning_type != "full":
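For reference, a hypothetical training call that exercises the renamed pissa_init flag together with LoRA fine-tuning. The model, dataset and output paths are placeholders, and the argument names beyond pissa_init / pissa_convert are assumed to follow the usual LLaMA-Factory hparams; the run_exp entry point and its dict-style arguments are taken from the tuner hunk later in this commit.

```python
from llamafactory.train.tuner import run_exp  # module path assumed from the imports shown in this diff

run_exp(
    dict(
        stage="sft",
        do_train=True,
        model_name_or_path="meta-llama/Llama-2-7b-hf",  # placeholder model
        dataset="alpaca_en_demo",                       # placeholder dataset
        template="default",                             # placeholder template
        finetuning_type="lora",
        pissa_init=True,     # initialize the LoRA adapter with PiSSA (validated above to require LoRA)
        pissa_convert=True,  # convert the PiSSA adapter back to a plain LoRA adapter after training
        output_dir="saves/llama2-7b-pissa",             # placeholder output directory
    )
)
```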
@@ -83,9 +83,6 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
     if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
         raise ValueError("Adapter is only valid for the LoRA method.")
 
-    if model_args.use_unsloth and is_deepspeed_zero3_enabled():
-        raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
-
     if model_args.quantization_bit is not None:
         if finetuning_args.finetuning_type != "lora":
             raise ValueError("Quantization is only compatible with the LoRA method.")
@@ -186,6 +183,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
         raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
 
+    if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
+        raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")
+
     if training_args.max_steps == -1 and data_args.streaming:
         raise ValueError("Please specify `max_steps` in streaming mode.")
 
@@ -195,6 +195,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if training_args.do_train and model_args.quantization_device_map == "auto":
         raise ValueError("Cannot use device map for quantized models in training.")
 
+    if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
+        raise ValueError("PiSSA is incompatible with DeepSpeed ZeRO-3.")
+
     if finetuning_args.pure_bf16:
         if not is_torch_bf16_gpu_available():
             raise ValueError("This device does not support `pure_bf16`.")
@@ -224,6 +227,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if model_args.visual_inputs and data_args.packing:
         raise ValueError("Cannot use packing in MLLM fine-tuning.")
 
+    if model_args.use_unsloth and is_deepspeed_zero3_enabled():
+        raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
+
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
 
@@ -1,4 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,22 +25,78 @@ from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
+import torch
 import transformers
-from transformers import TrainerCallback
+from peft import PeftModel
+from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
+from transformers.utils import (
+    SAFE_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    is_safetensors_available,
+)
 
-from .constants import TRAINER_LOG
-from .logging import LoggerHandler, get_logger
-from .misc import fix_valuehead_checkpoint
+from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.logging import LoggerHandler, get_logger
 
 
+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.torch import save_file
+
 if TYPE_CHECKING:
     from transformers import TrainerControl, TrainerState, TrainingArguments
+    from trl import AutoModelForCausalLMWithValueHead
 
 
 logger = get_logger(__name__)
 
 
+def fix_valuehead_checkpoint(
+    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
+) -> None:
+    r"""
+    The model is already unwrapped.
+
+    There are three cases:
+    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
+    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
+    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
+
+    We assume `stage3_gather_16bit_weights_on_model_save=true`.
+    """
+    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+        return
+
+    if safe_serialization:
+        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
+        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
+            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+    else:
+        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
+        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+
+    decoder_state_dict = {}
+    v_head_state_dict = {}
+    for name, param in state_dict.items():
+        if name.startswith("v_head."):
+            v_head_state_dict[name] = param
+        else:
+            decoder_state_dict[name.replace("pretrained_model.", "")] = param
+
+    os.remove(path_to_checkpoint)
+    model.pretrained_model.save_pretrained(
+        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
+    )
+
+    if safe_serialization:
+        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+    else:
+        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
+
+    logger.info("Value head model saved at: {}".format(output_dir))
+
+
 class FixValueHeadModelCallback(TrainerCallback):
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
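The docstring above lists three possible key layouts; the splitting loop treats them uniformly. A toy illustration of that key handling with fabricated tensor names (no real checkpoint involved):

```python
import torch

# Fabricated example of a ZeRO-3-style value-head state dict (case 3 in the docstring above).
state_dict = {
    "pretrained_model.model.layers.0.self_attn.q_proj.weight": torch.zeros(2, 2),
    "pretrained_model.lm_head.weight": torch.zeros(2, 2),
    "v_head.summary.weight": torch.zeros(1, 2),
    "v_head.summary.bias": torch.zeros(1),
}

decoder_state_dict, v_head_state_dict = {}, {}
for name, param in state_dict.items():
    if name.startswith("v_head."):
        v_head_state_dict[name] = param  # written to the separate value-head file
    else:
        # strip the wrapper prefix so the keys match the underlying decoder
        decoder_state_dict[name.replace("pretrained_model.", "")] = param

print(sorted(decoder_state_dict))  # ['lm_head.weight', 'model.layers.0.self_attn.q_proj.weight']
print(sorted(v_head_state_dict))   # ['v_head.summary.bias', 'v_head.summary.weight']
```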
@@ -51,8 +110,70 @@ class FixValueHeadModelCallback(TrainerCallback):
             )
 
 
+class SaveProcessorCallback(TrainerCallback):
+    def __init__(self, processor: "ProcessorMixin") -> None:
+        r"""
+        Initializes a callback for saving the processor.
+        """
+        self.processor = processor
+
+    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called at the end of training.
+        """
+        if args.should_save:
+            getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
+
+
+class PissaConvertCallback(TrainerCallback):
+    r"""
+    Initializes a callback for converting the PiSSA adapter to a normal one.
+    """
+
+    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called at the beginning of training.
+        """
+        if args.should_save:
+            model = kwargs.pop("model")
+            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
+            logger.info("Initial PiSSA adatper will be saved at: {}.".format(pissa_init_dir))
+            if isinstance(model, PeftModel):
+                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
+                setattr(model.peft_config["default"], "init_lora_weights", True)
+                model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
+                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+
+    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called at the end of training.
+        """
+        if args.should_save:
+            model = kwargs.pop("model")
+            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
+            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
+            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
+            logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
+            # 1. save a pissa backup with init_lora_weights: True
+            # 2. save a converted lora with init_lora_weights: pissa
+            # 3. load the pissa backup with init_lora_weights: True
+            # 4. delete the initial adapter and change init_lora_weights to pissa
+            if isinstance(model, PeftModel):
+                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
+                setattr(model.peft_config["default"], "init_lora_weights", True)
+                model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
+                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+                model.save_pretrained(
+                    pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
+                )
+                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
+                model.set_adapter("default")
+                model.delete_adapter("pissa_init")
+                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+
+
 class LogCallback(TrainerCallback):
-    def __init__(self, output_dir: str) -> None:
+    def __init__(self) -> None:
         r"""
         Initializes a callback for logging training and evaluation status.
         """
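These new callbacks are meant to be registered on a Hugging Face Trainer instead of overriding _save in every trainer subclass, which is what the trainer hunks later in this commit do. A minimal sketch of that registration, assuming a trainer, an optional processor and the finetuning arguments are already constructed:

```python
# Sketch only: `trainer`, `processor` and `finetuning_args` are assumed to exist already.
trainer.add_callback(FixValueHeadModelCallback)  # a class is accepted; the Trainer instantiates it

if processor is not None:
    trainer.add_callback(SaveProcessorCallback(processor))  # instances work as well

if finetuning_args.pissa_convert:
    # saves pissa_init/ when training starts and pissa_converted/ when training ends
    trainer.add_callback(PissaConvertCallback)
```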
@@ -70,7 +191,7 @@ class LogCallback(TrainerCallback):
         self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
         if self.webui_mode:
             signal.signal(signal.SIGABRT, self._set_abort)
-            self.logger_handler = LoggerHandler(output_dir)
+            self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
             logging.root.addHandler(self.logger_handler)
             transformers.logging.add_handler(self.logger_handler)
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import warnings
 from collections import defaultdict
 from contextlib import nullcontext
@@ -29,7 +28,8 @@ from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler, get_batch_logps
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
 
 
 if TYPE_CHECKING:
@@ -54,7 +54,6 @@ class CustomDPOTrainer(DPOTrainer):
                 disable_dropout_in_model(ref_model)
 
         self.finetuning_args = finetuning_args
-        self.processor = processor
         self.reference_free = False
         self.use_dpo_data_collator = True  # hack to avoid warning
         self.generate_during_eval = False  # disable at evaluation
@@ -92,14 +91,17 @@ class CustomDPOTrainer(DPOTrainer):
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                 self.ref_model.eval()
 
+        if processor is not None:
+            self.add_callback(SaveProcessorCallback(processor))
+
         if finetuning_args.pissa_convert:
-            self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
+            self.callback_handler.add_callback(PissaConvertCallback)
 
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
-            self.callback_handler.add_callback(BAdamCallback)
+            self.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
@@ -112,15 +114,6 @@ class CustomDPOTrainer(DPOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
-    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
-        super()._save(output_dir, state_dict)
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
-        if self.finetuning_args.pissa_convert:
-            convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args)
-
-        if self.processor is not None:
-            getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
     def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
         r"""
         Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
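The BAdam integration kept in the hunk above patches accelerator.clip_grad_norm_ via types.MethodType. A generic, self-contained illustration of that instance-binding pattern, using toy classes unrelated to BAdam or Accelerate:

```python
from types import MethodType


class Accumulator:
    def __init__(self) -> None:
        self.total = 0

    def add(self, value: int) -> None:
        self.total += value


def add_twice(self: "Accumulator", value: int) -> None:
    # replacement behaviour bound onto a single existing instance
    self.total += 2 * value


acc = Accumulator()
acc.add = MethodType(add_twice, acc)  # same trick as clip_grad_norm_ = MethodType(..., accelerator)
acc.add(3)
print(acc.total)  # 6
```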
@@ -27,6 +27,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
+from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
 
 
@@ -53,7 +54,6 @@ class CustomKTOTrainer(KTOTrainer):
                 disable_dropout_in_model(ref_model)
 
         self.finetuning_args = finetuning_args
-        self.processor = processor
         self.reference_free = False
         self.use_dpo_data_collator = True  # hack to avoid warning
         self.generate_during_eval = False  # disable at evaluation
@@ -90,11 +90,14 @@ class CustomKTOTrainer(KTOTrainer):
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                 self.ref_model.eval()
 
+        if processor is not None:
+            self.add_callback(SaveProcessorCallback(processor))
+
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
-            self.callback_handler.add_callback(BAdamCallback)
+            self.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
@@ -113,12 +116,6 @@ class CustomKTOTrainer(KTOTrainer):
         """
         return Trainer._get_train_sampler(self)
 
-    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
-        super()._save(output_dir, state_dict)
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
-        if self.processor is not None:
-            getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
     ) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -27,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs
 from tqdm import tqdm
 from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
 from transformers.optimization import get_scheduler
+from transformers.trainer_callback import CallbackHandler
 from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
@@ -34,9 +35,9 @@ from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 from trl.models.utils import unwrap_model_for_generation
 
-from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
 from ...extras.logging import get_logger
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
+from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
 
@@ -131,7 +132,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.finetuning_args = finetuning_args
         self.reward_model = reward_model
         self.current_device = get_current_device()  # patch for deepspeed training
-        self.processor = processor
 
         self.generation_config = GenerationConfig(
             pad_token_id=self.tokenizer.pad_token_id,
@@ -143,8 +143,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.control = TrainerControl()
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
-        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
-        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
+        self.callback_handler = CallbackHandler(
+            [callbacks], self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
+        )
 
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -165,11 +166,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             else:
                 self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
 
+        self.add_callback(FixValueHeadModelCallback)
+
+        if processor is not None:
+            self.add_callback(SaveProcessorCallback(processor))
+
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
-            self.callback_handler.add_callback(BAdamCallback)
+            self.add_callback(BAdamCallback)
 
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
@@ -219,7 +225,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         dataiter = iter(self.dataloader)
         loss_meter = AverageMeter()
         reward_meter = AverageMeter()
-        self.log_callback.on_train_begin(self.args, self.state, self.control)
+        self.callback_handler.on_train_begin(self.args, self.state, self.control)
 
         for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
             try:
@@ -257,7 +263,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 logger.warning("Failed to save stats due to unknown errors.")
 
             self.state.global_step += 1
-            self.log_callback.on_step_end(self.args, self.state, self.control)
+            self.callback_handler.on_step_end(self.args, self.state, self.control)
 
             if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
                 logs = dict(
@@ -269,7 +275,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 tqdm.write(str(logs))
                 logs["step"] = step
                 self.state.log_history.append(logs)
-                self.log_callback.on_log(self.args, self.state, self.control)
+                self.callback_handler.on_log(self.args, self.state, self.control, logs)
                 loss_meter.reset()
                 reward_meter.reset()
 
@@ -277,17 +283,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 self.save_model(
                     os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
                 )
-                self.save_callback.on_save(
-                    self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
-                )
+                self.callback_handler.on_save(self.args, self.state, self.control)
 
             if self.control.should_epoch_stop or self.control.should_training_stop:
                 break
 
-        self.log_callback.on_train_end(self.args, self.state, self.control)
-        self.save_callback.on_train_end(
-            self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
-        )
+        self.callback_handler.on_train_end(self.args, self.state, self.control)
 
     def create_optimizer(
         self,
@@ -505,7 +506,3 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
         elif self.args.should_save:
             self._save(output_dir)
-
-        if self.processor is not None and self.args.should_save:
-            output_dir = output_dir if output_dir is not None else self.args.output_dir
-            getattr(self.processor, "image_processor").save_pretrained(output_dir)
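The PPO trainer now drives a transformers CallbackHandler directly instead of hand-holding two specific callbacks. A condensed sketch of that event flow with stand-in arguments; the handler's positional signature (callbacks, model, tokenizer, optimizer, lr_scheduler) follows transformers.trainer_callback as of the 4.40-era release this diff references, and everything else is illustrative:

```python
from transformers import TrainingArguments
from transformers.trainer_callback import (
    CallbackHandler,
    PrinterCallback,
    TrainerControl,
    TrainerState,
)

# Stand-ins for the real model/tokenizer/optimizer the PPO trainer would pass in.
handler = CallbackHandler([PrinterCallback()], None, None, None, None)

args = TrainingArguments(output_dir="tmp")  # placeholder output dir
state, control = TrainerState(), TrainerControl()

handler.on_train_begin(args, state, control)
for _ in range(3):
    state.global_step += 1
    handler.on_step_end(args, state, control)
    handler.on_log(args, state, control, logs={"step": state.global_step})
handler.on_train_end(args, state, control)
```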
@@ -20,10 +20,9 @@ from typing import TYPE_CHECKING, List, Optional
 from transformers import DataCollatorWithPadding
 
 from ...data import get_dataset
-from ...extras.callbacks import FixValueHeadModelCallback
-from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
+from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
 from ..trainer_utils import create_ref_model, create_reward_model
 from .trainer import CustomPPOTrainer
 
@@ -75,6 +74,7 @@ def run_ppo(
         ppo_trainer.save_model()
         if training_args.should_save:
             fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+
         ppo_trainer.save_state()  # must be called after save_model to have a folder
         if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "reward"])
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Optional
 
 from transformers import Trainer
 
 from ...extras.logging import get_logger
-from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -42,16 +42,18 @@ class CustomTrainer(Trainer):
     ) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
-        self.processor = processor
+
+        if processor is not None:
+            self.add_callback(SaveProcessorCallback(processor))
 
         if finetuning_args.pissa_convert:
-            self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
+            self.add_callback(PissaConvertCallback)
 
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
-            self.callback_handler.add_callback(BAdamCallback)
+            self.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
@@ -63,12 +65,3 @@ class CustomTrainer(Trainer):
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
-
-    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
-        super()._save(output_dir, state_dict)
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
-        if self.finetuning_args.pissa_convert:
-            convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args)
-
-        if self.processor is not None:
-            getattr(self.processor, "image_processor").save_pretrained(output_dir)
@@ -46,6 +46,7 @@ import torch
 from transformers import Trainer
 
 from ...extras.logging import get_logger
+from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 
 
@@ -69,13 +70,20 @@ class PairwiseTrainer(Trainer):
     ) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
-        self.processor = processor
         self.can_return_loss = True  # override property to return eval_loss
+        self.add_callback(FixValueHeadModelCallback)
+
+        if processor is not None:
+            self.add_callback(SaveProcessorCallback(processor))
+
+        if finetuning_args.pissa_convert:
+            self.add_callback(PissaConvertCallback)
+
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
-            self.callback_handler.add_callback(BAdamCallback)
+            self.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
@@ -88,12 +96,6 @@ class PairwiseTrainer(Trainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
-    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
-        super()._save(output_dir, state_dict)
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
-        if self.processor is not None:
-            getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -164,4 +166,5 @@ class PairwiseTrainer(Trainer):
             res: List[str] = []
             for c_score, r_score in zip(chosen_scores, rejected_scores):
                 res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
+
             writer.write("\n".join(res))
@@ -40,10 +40,9 @@
 from typing import TYPE_CHECKING, List, Optional
 
 from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
-from ...extras.callbacks import FixValueHeadModelCallback
-from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
+from ..callbacks import fix_valuehead_checkpoint
 from ..trainer_utils import create_modelcard_and_push
 from .metric import compute_accuracy
 from .trainer import PairwiseTrainer
@@ -77,7 +76,7 @@ def run_rm(
         args=training_args,
         finetuning_args=finetuning_args,
         data_collator=data_collator,
-        callbacks=callbacks + [FixValueHeadModelCallback()],
+        callbacks=callbacks,
         compute_metrics=compute_accuracy,
         **tokenizer_module,
         **split_dataset(dataset, data_args, training_args),
@@ -89,6 +88,7 @@ def run_rm(
         trainer.save_model()
         if training_args.should_save:
             fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
@@ -26,7 +26,8 @@ from transformers import Seq2SeqTrainer
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -50,19 +51,18 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     ) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
-        self.processor = processor
+
+        if processor is not None:
+            self.add_callback(SaveProcessorCallback(processor))
 
         if finetuning_args.pissa_convert:
-            if self.is_deepspeed_enabled:
-                self.accelerator.deepspeed_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
-                self.deepspeed = self._wrap_model(self.model_wrapped)
-            self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
+            self.add_callback(PissaConvertCallback)
 
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
-            self.callback_handler.add_callback(BAdamCallback)
+            self.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
@@ -75,15 +75,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
-    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
-        super()._save(output_dir, state_dict)
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
-        if self.finetuning_args.pissa_convert:
-            convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args)
-
-        if self.processor is not None:
-            getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
     def prediction_step(
         self,
         model: "torch.nn.Module",
@@ -17,11 +17,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from peft import PeftModel
 from transformers import Trainer
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.optimization import get_scheduler
@@ -40,7 +38,6 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
-    from accelerate import Accelerator
     from transformers import PreTrainedModel, Seq2SeqTrainingArguments
     from trl import AutoModelForCausalLMWithValueHead
 
@@ -175,51 +172,6 @@ def create_reward_model(
         return reward_model
 
 
-def convert_pissa_adapter(
-    output_dir: str,
-    state_dict: Dict[str, "torch.Tensor"],
-    accelerator: "Accelerator",
-    model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
-) -> None:
-    r"""
-    Converts the PiSSA adapter to a LoRA adapter.
-    """
-    pissa_init_dir = os.path.join(training_args.output_dir, "pissa_init")
-    pissa_backup_dir = os.path.join(output_dir, "pissa_backup")
-    if output_dir == pissa_init_dir:
-        logger.info("Initial PiSSA adatper will be saved at: {}.".format(pissa_init_dir))
-        unwrapped_model = accelerator.unwrap_model(model)
-        if isinstance(unwrapped_model, PeftModel):
-            init_lora_weights = getattr(unwrapped_model.peft_config["default"], "init_lora_weights")
-            setattr(unwrapped_model.peft_config["default"], "init_lora_weights", True)
-            unwrapped_model.save_pretrained(
-                output_dir,
-                state_dict=state_dict,
-                safe_serialization=training_args.save_safetensors,
-            )
-            setattr(unwrapped_model.peft_config["default"], "init_lora_weights", init_lora_weights)
-
-    elif output_dir == training_args.output_dir:  # at the end of training
-        logger.info("Converted PiSSA adapter will be saved at: {}.".format(output_dir))
-        unwrapped_model = accelerator.unwrap_model(model)
-        if isinstance(unwrapped_model, PeftModel):  # backup the pissa adapter for further use
-            unwrapped_model.save_pretrained(
-                pissa_backup_dir,
-                state_dict=state_dict,
-                safe_serialization=training_args.save_safetensors,
-            )
-            unwrapped_model.save_pretrained(
-                output_dir,
-                state_dict=state_dict,
-                safe_serialization=training_args.save_safetensors,
-                convert_pissa_to_lora=pissa_init_dir,
-            )
-            # TODO: the model is applied pissa again unexpectedly
-            unwrapped_model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
-            unwrapped_model.set_adapter("default")
-
-
 def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     r"""
     Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
@@ -20,11 +20,11 @@ import torch
 from transformers import PreTrainedModel
 
 from ..data import get_template_and_fix_tokenizer
-from ..extras.callbacks import LogCallback
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import get_logger
 from ..hparams import get_infer_args, get_train_args
 from ..model import load_model, load_tokenizer
+from .callbacks import LogCallback
 from .dpo import run_dpo
 from .kto import run_kto
 from .ppo import run_ppo
@@ -41,8 +41,8 @@ logger = get_logger(__name__)
 
 
 def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
+    callbacks.append(LogCallback())
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
-    callbacks.append(LogCallback(training_args.output_dir))
 
     if finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
@@ -310,6 +310,7 @@ class Runner:
 
         env = deepcopy(os.environ)
         env["LLAMABOARD_ENABLED"] = "1"
+        env["LLAMABOARD_WORKDIR"] = args["output_dir"]
         if args.get("deepspeed", None) is not None:
             env["FORCE_TORCHRUN"] = "1"
 
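The web UI runner above hands the working directory to the training subprocess through the LLAMABOARD_WORKDIR environment variable, which the refactored LogCallback reads whenever LLAMABOARD_ENABLED is set. A rough sketch of that handoff; the llamafactory-cli command and the config path are placeholders:

```python
import os
import subprocess
from copy import deepcopy

env = deepcopy(os.environ)
env["LLAMABOARD_ENABLED"] = "1"               # tells LogCallback it is running under the board
env["LLAMABOARD_WORKDIR"] = "saves/demo-run"  # placeholder; LoggerHandler writes its log file here

# Placeholder command: the board launches training as a separate process with this environment.
proc = subprocess.Popen(["llamafactory-cli", "train", "demo_config.yaml"], env=env)
```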
@@ -38,12 +38,15 @@ def abort_process(pid: int) -> None:
     r"""
     Aborts the processes recursively in a bottom-up way.
     """
-    children = psutil.Process(pid).children()
-    if children:
-        for child in children:
-            abort_process(child.pid)
+    try:
+        children = psutil.Process(pid).children()
+        if children:
+            for child in children:
+                abort_process(child.pid)
 
-    os.kill(pid, signal.SIGABRT)
+        os.kill(pid, signal.SIGABRT)
+    except Exception:
+        pass
 
 
 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
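abort_process is now wrapped in try/except, so a process that has already exited no longer raises. A small hypothetical usage sketch, assuming abort_process is imported from the web UI utils module shown above and standing in a child process for a training run:

```python
import subprocess
import time

# Hypothetical child process standing in for a training run.
proc = subprocess.Popen(["sleep", "60"])
time.sleep(1)
abort_process(proc.pid)  # SIGABRTs any children bottom-up, then the parent; failures are swallowed
```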