refactor pissa, improve llamaboard

This commit is contained in:
hiyouga 2024-06-28 01:04:24 +08:00
parent ef38daa0a4
commit 8baf3b22b0
16 changed files with 219 additions and 216 deletions

View File

@ -1,4 +1,7 @@
# Copyright 2024 the LlamaFactory team. # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,15 +17,11 @@
import gc import gc
import os import os
from typing import TYPE_CHECKING, Dict, Tuple from typing import TYPE_CHECKING, Tuple
import torch import torch
from peft import PeftModel from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
from transformers.utils import ( from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_safetensors_available,
is_torch_bf16_gpu_available, is_torch_bf16_gpu_available,
is_torch_cuda_available, is_torch_cuda_available,
is_torch_mps_available, is_torch_mps_available,
@ -31,15 +30,9 @@ from transformers.utils import (
) )
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger from .logging import get_logger
if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import save_file
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try: try:
_is_bf16_available = is_torch_bf16_gpu_available() _is_bf16_available = is_torch_bf16_gpu_available()
@ -48,8 +41,6 @@ except Exception:
if TYPE_CHECKING: if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments from ..hparams import ModelArguments
@ -99,7 +90,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
if num_params == 0 and hasattr(param, "ds_numel"): if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel num_params = param.ds_numel
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2 # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
if param.__class__.__name__ == "Params4bit": if param.__class__.__name__ == "Params4bit":
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"): if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize num_bytes = param.quant_storage.itemsize
@ -117,51 +108,6 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param return trainable_params, all_param
def fix_valuehead_checkpoint(
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
r"""
The model is already unwrapped.
There are three cases:
1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
We assume `stage3_gather_16bit_weights_on_model_save=true`.
"""
if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
return
if safe_serialization:
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
else:
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
decoder_state_dict = {}
v_head_state_dict = {}
for name, param in state_dict.items():
if name.startswith("v_head."):
v_head_state_dict[name] = param
else:
decoder_state_dict[name.replace("pretrained_model.", "")] = param
os.remove(path_to_checkpoint)
model.pretrained_model.save_pretrained(
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
)
if safe_serialization:
save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
else:
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
logger.info("Value head model saved at: {}".format(output_dir))
def get_current_device() -> torch.device: def get_current_device() -> torch.device:
r""" r"""
Gets the current available device. Gets the current available device.
@ -201,7 +147,7 @@ def get_logits_processor() -> "LogitsProcessorList":
return logits_processor return logits_processor
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r""" r"""
Infers the optimal dtype according to the model_dtype and device compatibility. Infers the optimal dtype according to the model_dtype and device compatibility.
""" """
@ -220,7 +166,7 @@ def is_gpu_or_npu_available() -> bool:
return is_torch_npu_available() or is_torch_cuda_available() return is_torch_npu_available() or is_torch_cuda_available()
def has_tokenized_data(path: os.PathLike) -> bool: def has_tokenized_data(path: "os.PathLike") -> bool:
r""" r"""
Checks if the path has a tokenized dataset. Checks if the path has a tokenized dataset.
""" """

View File

@ -379,10 +379,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora": if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.") raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
if self.pissa_convert and self.finetuning_type != "lora": if self.pissa_init and self.finetuning_type != "lora":
raise ValueError("`pissa_convert` is only valid for LoRA training.") raise ValueError("`pissa_init` is only valid for LoRA training.")
if self.pissa_convert and (self.stage in ["rm", "ppo", "kto"] or self.use_ref_model): if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
raise ValueError("Cannot use PiSSA for current training stage.") raise ValueError("Cannot use PiSSA for current training stage.")
if self.train_mm_proj_only and self.finetuning_type != "full": if self.train_mm_proj_only and self.finetuning_type != "full":

View File

@ -83,9 +83,6 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora": if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Adapter is only valid for the LoRA method.") raise ValueError("Adapter is only valid for the LoRA method.")
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora": if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.") raise ValueError("Quantization is only compatible with the LoRA method.")
@ -186,6 +183,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED: if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.") raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")
if training_args.max_steps == -1 and data_args.streaming: if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.") raise ValueError("Please specify `max_steps` in streaming mode.")
@ -195,6 +195,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and model_args.quantization_device_map == "auto": if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.") raise ValueError("Cannot use device map for quantized models in training.")
if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
raise ValueError("PiSSA is incompatible with DeepSpeed ZeRO-3.")
if finetuning_args.pure_bf16: if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available(): if not is_torch_bf16_gpu_available():
raise ValueError("This device does not support `pure_bf16`.") raise ValueError("This device does not support `pure_bf16`.")
@ -224,6 +227,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.visual_inputs and data_args.packing: if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.") raise ValueError("Cannot use packing in MLLM fine-tuning.")
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
_verify_model_args(model_args, finetuning_args) _verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args) _check_extra_dependencies(model_args, finetuning_args, training_args)

View File

@ -1,4 +1,7 @@
# Copyright 2024 the LlamaFactory team. # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -22,22 +25,78 @@ from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
import torch
import transformers import transformers
from transformers import TrainerCallback from peft import PeftModel
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_safetensors_available,
)
from .constants import TRAINER_LOG from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import LoggerHandler, get_logger from ..extras.logging import LoggerHandler, get_logger
from .misc import fix_valuehead_checkpoint
if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import save_file
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerControl, TrainerState, TrainingArguments from transformers import TrainerControl, TrainerState, TrainingArguments
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__) logger = get_logger(__name__)
def fix_valuehead_checkpoint(
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
r"""
The model is already unwrapped.
There are three cases:
1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
We assume `stage3_gather_16bit_weights_on_model_save=true`.
"""
if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
return
if safe_serialization:
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
else:
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
decoder_state_dict = {}
v_head_state_dict = {}
for name, param in state_dict.items():
if name.startswith("v_head."):
v_head_state_dict[name] = param
else:
decoder_state_dict[name.replace("pretrained_model.", "")] = param
os.remove(path_to_checkpoint)
model.pretrained_model.save_pretrained(
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
)
if safe_serialization:
save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
else:
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
logger.info("Value head model saved at: {}".format(output_dir))
class FixValueHeadModelCallback(TrainerCallback): class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
@ -51,8 +110,70 @@ class FixValueHeadModelCallback(TrainerCallback):
) )
class SaveProcessorCallback(TrainerCallback):
def __init__(self, processor: "ProcessorMixin") -> None:
r"""
Initializes a callback for saving the processor.
"""
self.processor = processor
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
class PissaConvertCallback(TrainerCallback):
r"""
Initializes a callback for converting the PiSSA adapter to a normal one.
"""
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if args.should_save:
model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
logger.info("Initial PiSSA adatper will be saved at: {}.".format(pissa_init_dir))
if isinstance(model, PeftModel):
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
setattr(model.peft_config["default"], "init_lora_weights", True)
model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
# 1. save a pissa backup with init_lora_weights: True
# 2. save a converted lora with init_lora_weights: pissa
# 3. load the pissa backup with init_lora_weights: True
# 4. delete the initial adapter and change init_lora_weights to pissa
if isinstance(model, PeftModel):
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
setattr(model.peft_config["default"], "init_lora_weights", True)
model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
model.save_pretrained(
pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
)
model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
model.set_adapter("default")
model.delete_adapter("pissa_init")
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
class LogCallback(TrainerCallback): class LogCallback(TrainerCallback):
def __init__(self, output_dir: str) -> None: def __init__(self) -> None:
r""" r"""
Initializes a callback for logging training and evaluation status. Initializes a callback for logging training and evaluation status.
""" """
@ -70,7 +191,7 @@ class LogCallback(TrainerCallback):
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode: if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort) signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(output_dir) self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
logging.root.addHandler(self.logger_handler) logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler) transformers.logging.add_handler(self.logger_handler)

View File

@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
@ -29,7 +28,8 @@ from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler, get_batch_logps from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
if TYPE_CHECKING: if TYPE_CHECKING:
@ -54,7 +54,6 @@ class CustomDPOTrainer(DPOTrainer):
disable_dropout_in_model(ref_model) disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
self.processor = processor
self.reference_free = False self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation self.generate_during_eval = False # disable at evaluation
@ -92,14 +91,17 @@ class CustomDPOTrainer(DPOTrainer):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval() self.ref_model.eval()
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.pissa_convert: if finetuning_args.pissa_convert:
self.save_model(os.path.join(self.args.output_dir, "pissa_init")) self.callback_handler.add_callback(PissaConvertCallback)
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
@ -112,15 +114,6 @@ class CustomDPOTrainer(DPOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
output_dir = output_dir if output_dir is not None else self.args.output_dir
if self.finetuning_args.pissa_convert:
convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args)
if self.processor is not None:
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r""" r"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model. Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.

View File

@ -27,6 +27,7 @@ from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
@ -53,7 +54,6 @@ class CustomKTOTrainer(KTOTrainer):
disable_dropout_in_model(ref_model) disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
self.processor = processor
self.reference_free = False self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation self.generate_during_eval = False # disable at evaluation
@ -90,11 +90,14 @@ class CustomKTOTrainer(KTOTrainer):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval() self.ref_model.eval()
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
@ -113,12 +116,6 @@ class CustomKTOTrainer(KTOTrainer):
""" """
return Trainer._get_train_sampler(self) return Trainer._get_train_sampler(self)
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
output_dir = output_dir if output_dir is not None else self.args.output_dir
if self.processor is not None:
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def forward( def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor"]: ) -> Tuple["torch.Tensor", "torch.Tensor"]:

View File

@ -27,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs
from tqdm import tqdm from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
from transformers.trainer_callback import CallbackHandler
from transformers.trainer_pt_utils import remove_dummy_checkpoint from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
@ -34,9 +35,9 @@ from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation from trl.models.utils import unwrap_model_for_generation
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@ -131,7 +132,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
self.reward_model = reward_model self.reward_model = reward_model
self.current_device = get_current_device() # patch for deepspeed training self.current_device = get_current_device() # patch for deepspeed training
self.processor = processor
self.generation_config = GenerationConfig( self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id, pad_token_id=self.tokenizer.pad_token_id,
@ -143,8 +143,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.control = TrainerControl() self.control = TrainerControl()
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
self.log_callback, self.save_callback = callbacks[0], callbacks[1] self.callback_handler = CallbackHandler(
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback) [callbacks], self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
)
if self.args.max_steps > 0: if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs") logger.info("max_steps is given, it will override any value given in num_train_epochs")
@ -165,11 +166,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
else: else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True) self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
self.add_callback(FixValueHeadModelCallback)
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r""" r"""
@ -219,7 +225,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
dataiter = iter(self.dataloader) dataiter = iter(self.dataloader)
loss_meter = AverageMeter() loss_meter = AverageMeter()
reward_meter = AverageMeter() reward_meter = AverageMeter()
self.log_callback.on_train_begin(self.args, self.state, self.control) self.callback_handler.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()): for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
try: try:
@ -257,7 +263,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.warning("Failed to save stats due to unknown errors.") logger.warning("Failed to save stats due to unknown errors.")
self.state.global_step += 1 self.state.global_step += 1
self.log_callback.on_step_end(self.args, self.state, self.control) self.callback_handler.on_step_end(self.args, self.state, self.control)
if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0: if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
logs = dict( logs = dict(
@ -269,7 +275,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
tqdm.write(str(logs)) tqdm.write(str(logs))
logs["step"] = step logs["step"] = step
self.state.log_history.append(logs) self.state.log_history.append(logs)
self.log_callback.on_log(self.args, self.state, self.control) self.callback_handler.on_log(self.args, self.state, self.control, logs)
loss_meter.reset() loss_meter.reset()
reward_meter.reset() reward_meter.reset()
@ -277,17 +283,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.save_model( self.save_model(
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)) os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
) )
self.save_callback.on_save( self.callback_handler.on_save(self.args, self.state, self.control)
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
if self.control.should_epoch_stop or self.control.should_training_stop: if self.control.should_epoch_stop or self.control.should_training_stop:
break break
self.log_callback.on_train_end(self.args, self.state, self.control) self.callback_handler.on_train_end(self.args, self.state, self.control)
self.save_callback.on_train_end(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
def create_optimizer( def create_optimizer(
self, self,
@ -505,7 +506,3 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
elif self.args.should_save: elif self.args.should_save:
self._save(output_dir) self._save(output_dir)
if self.processor is not None and self.args.should_save:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)

View File

@ -20,10 +20,9 @@ from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorWithPadding from transformers import DataCollatorWithPadding
from ...data import get_dataset from ...data import get_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer from ...model import load_model, load_tokenizer
from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
from ..trainer_utils import create_ref_model, create_reward_model from ..trainer_utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer from .trainer import CustomPPOTrainer
@ -75,6 +74,7 @@ def run_ppo(
ppo_trainer.save_model() ppo_trainer.save_model()
if training_args.should_save: if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors) fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
ppo_trainer.save_state() # must be called after save_model to have a folder ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss: if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"]) plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Dict, Optional from typing import TYPE_CHECKING, Optional
from transformers import Trainer from transformers import Trainer
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING: if TYPE_CHECKING:
@ -42,16 +42,18 @@ class CustomTrainer(Trainer):
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
self.processor = processor
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.pissa_convert: if finetuning_args.pissa_convert:
self.save_model(os.path.join(self.args.output_dir, "pissa_init")) self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
@ -63,12 +65,3 @@ class CustomTrainer(Trainer):
) -> "torch.optim.lr_scheduler.LRScheduler": ) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
output_dir = output_dir if output_dir is not None else self.args.output_dir
if self.finetuning_args.pissa_convert:
convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args)
if self.processor is not None:
getattr(self.processor, "image_processor").save_pretrained(output_dir)

View File

@ -46,6 +46,7 @@ import torch
from transformers import Trainer from transformers import Trainer
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
@ -69,13 +70,20 @@ class PairwiseTrainer(Trainer):
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
self.processor = processor
self.can_return_loss = True # override property to return eval_loss self.can_return_loss = True # override property to return eval_loss
self.add_callback(FixValueHeadModelCallback)
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.pissa_convert:
self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
@ -88,12 +96,6 @@ class PairwiseTrainer(Trainer):
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
output_dir = output_dir if output_dir is not None else self.args.output_dir
if self.processor is not None:
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def compute_loss( def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -164,4 +166,5 @@ class PairwiseTrainer(Trainer):
res: List[str] = [] res: List[str] = []
for c_score, r_score in zip(chosen_scores, rejected_scores): for c_score, r_score in zip(chosen_scores, rejected_scores):
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)})) res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
writer.write("\n".join(res)) writer.write("\n".join(res))

View File

@ -40,10 +40,9 @@
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer from ...model import load_model, load_tokenizer
from ..callbacks import fix_valuehead_checkpoint
from ..trainer_utils import create_modelcard_and_push from ..trainer_utils import create_modelcard_and_push
from .metric import compute_accuracy from .metric import compute_accuracy
from .trainer import PairwiseTrainer from .trainer import PairwiseTrainer
@ -77,7 +76,7 @@ def run_rm(
args=training_args, args=training_args,
finetuning_args=finetuning_args, finetuning_args=finetuning_args,
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks + [FixValueHeadModelCallback()], callbacks=callbacks,
compute_metrics=compute_accuracy, compute_metrics=compute_accuracy,
**tokenizer_module, **tokenizer_module,
**split_dataset(dataset, data_args, training_args), **split_dataset(dataset, data_args, training_args),
@ -89,6 +88,7 @@ def run_rm(
trainer.save_model() trainer.save_model()
if training_args.should_save: if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors) fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
trainer.log_metrics("train", train_result.metrics) trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics)
trainer.save_state() trainer.save_state()

View File

@ -26,7 +26,8 @@ from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING: if TYPE_CHECKING:
@ -50,19 +51,18 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
self.processor = processor
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.pissa_convert: if finetuning_args.pissa_convert:
if self.is_deepspeed_enabled: self.add_callback(PissaConvertCallback)
self.accelerator.deepspeed_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
self.deepspeed = self._wrap_model(self.model_wrapped)
self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.callback_handler.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
@ -75,15 +75,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
output_dir = output_dir if output_dir is not None else self.args.output_dir
if self.finetuning_args.pissa_convert:
convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args)
if self.processor is not None:
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def prediction_step( def prediction_step(
self, self,
model: "torch.nn.Module", model: "torch.nn.Module",

View File

@ -17,11 +17,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import torch import torch
from peft import PeftModel
from transformers import Trainer from transformers import Trainer
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
@ -40,7 +38,6 @@ if is_galore_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from accelerate import Accelerator
from transformers import PreTrainedModel, Seq2SeqTrainingArguments from transformers import PreTrainedModel, Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
@ -175,51 +172,6 @@ def create_reward_model(
return reward_model return reward_model
def convert_pissa_adapter(
output_dir: str,
state_dict: Dict[str, "torch.Tensor"],
accelerator: "Accelerator",
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
) -> None:
r"""
Converts the PiSSA adapter to a LoRA adapter.
"""
pissa_init_dir = os.path.join(training_args.output_dir, "pissa_init")
pissa_backup_dir = os.path.join(output_dir, "pissa_backup")
if output_dir == pissa_init_dir:
logger.info("Initial PiSSA adatper will be saved at: {}.".format(pissa_init_dir))
unwrapped_model = accelerator.unwrap_model(model)
if isinstance(unwrapped_model, PeftModel):
init_lora_weights = getattr(unwrapped_model.peft_config["default"], "init_lora_weights")
setattr(unwrapped_model.peft_config["default"], "init_lora_weights", True)
unwrapped_model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=training_args.save_safetensors,
)
setattr(unwrapped_model.peft_config["default"], "init_lora_weights", init_lora_weights)
elif output_dir == training_args.output_dir: # at the end of training
logger.info("Converted PiSSA adapter will be saved at: {}.".format(output_dir))
unwrapped_model = accelerator.unwrap_model(model)
if isinstance(unwrapped_model, PeftModel): # backup the pissa adapter for further use
unwrapped_model.save_pretrained(
pissa_backup_dir,
state_dict=state_dict,
safe_serialization=training_args.save_safetensors,
)
unwrapped_model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=training_args.save_safetensors,
convert_pissa_to_lora=pissa_init_dir,
)
# TODO: the model is applied pissa again unexpectedly
unwrapped_model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
unwrapped_model.set_adapter("default")
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
r""" r"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers) Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)

View File

@ -20,11 +20,11 @@ import torch
from transformers import PreTrainedModel from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.callbacks import LogCallback
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
from .callbacks import LogCallback
from .dpo import run_dpo from .dpo import run_dpo
from .kto import run_kto from .kto import run_kto
from .ppo import run_ppo from .ppo import run_ppo
@ -41,8 +41,8 @@ logger = get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
callbacks.append(LogCallback())
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
callbacks.append(LogCallback(training_args.output_dir))
if finetuning_args.stage == "pt": if finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks) run_pt(model_args, data_args, training_args, finetuning_args, callbacks)

View File

@ -310,6 +310,7 @@ class Runner:
env = deepcopy(os.environ) env = deepcopy(os.environ)
env["LLAMABOARD_ENABLED"] = "1" env["LLAMABOARD_ENABLED"] = "1"
env["LLAMABOARD_WORKDIR"] = args["output_dir"]
if args.get("deepspeed", None) is not None: if args.get("deepspeed", None) is not None:
env["FORCE_TORCHRUN"] = "1" env["FORCE_TORCHRUN"] = "1"

View File

@ -38,12 +38,15 @@ def abort_process(pid: int) -> None:
r""" r"""
Aborts the processes recursively in a bottom-up way. Aborts the processes recursively in a bottom-up way.
""" """
children = psutil.Process(pid).children() try:
if children: children = psutil.Process(pid).children()
for child in children: if children:
abort_process(child.pid) for child in children:
abort_process(child.pid)
os.kill(pid, signal.SIGABRT) os.kill(pid, signal.SIGABRT)
except Exception:
pass
def can_quantize(finetuning_type: str) -> "gr.Dropdown": def can_quantize(finetuning_type: str) -> "gr.Dropdown":