upgrade peft, fix #1088 #1411

hiyouga 2023-11-07 16:13:36 +08:00
parent 66a91e1fe3
commit b2a60905f3
15 changed files with 133 additions and 99 deletions


@@ -2,7 +2,7 @@ torch>=1.13.1
transformers>=4.31.0,<4.35.0
datasets>=2.12.0
accelerate>=0.21.0
peft>=0.4.0
peft>=0.6.0
trl>=0.7.2
gradio>=3.38.0,<4.0.0
scipy


@@ -59,8 +59,8 @@ def get_dataset(
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
streaming=data_args.streaming,
use_auth_token=True if model_args.use_auth_token else None
token=model_args.hf_hub_token,
streaming=data_args.streaming
)
if max_samples is not None: # truncate dataset
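Note: the `use_auth_token` → `token` rename tracks the Hugging Face deprecation cycle; recent `datasets` releases (2.14+) accept `token` directly on `load_dataset`. A minimal sketch of the new call shape, with a hypothetical data file and token:

```python
# Sketch only: `token` replaces the deprecated `use_auth_token` argument.
from datasets import load_dataset

dataset = load_dataset(
    "json",
    data_files="data/example.json",  # hypothetical file
    split="train",
    streaming=False,
    token="hf_xxx",  # hypothetical Hugging Face Hub token
)
```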


@@ -257,7 +257,7 @@ def preprocess_dataset(
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
raise SystemExit("Dataset saved, rerun this script with the same `--cache_file`.")
raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
if training_args.should_log:
try:
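The corrected message describes a save-then-rerun workflow: the first pass preprocesses and saves the dataset, then exits; a rerun with the same `--cache_path` loads the saved copy. A sketch of the reload side, with a hypothetical path:

```python
# Sketch of the rerun path implied by the SystemExit message above.
import os

from datasets import load_from_disk

cache_path = "cache/processed"  # hypothetical --cache_path value
if os.path.exists(cache_path):
    dataset = load_from_disk(cache_path)  # skips preprocessing on reruns
```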


@@ -2,7 +2,7 @@ IGNORE_INDEX = -100
LOG_FILE_NAME = "trainer_log.jsonl"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2", "ln1", "ln2"]
METHODS = ["full", "freeze", "lora"]
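Adding `ln1`/`ln2` widens the set of parameter names matched when layernorm weights are upcast for newly supported models. Roughly how the constant is consumed, assuming a loaded `model`:

```python
# Illustrative: upcast any 1-D parameter whose name matches LAYERNORM_NAMES.
import torch

from llmtuner.extras.constants import LAYERNORM_NAMES

for name, param in model.named_parameters():  # `model` assumed to exist
    if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
        param.data = param.data.to(torch.float32)
```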


@@ -24,10 +24,10 @@ class FinetuningArguments:
default="mlp",
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
LLaMA choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
Qwen choices: [\"mlp\", \"attn\"], \
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."}
LLaMA-2, BlueLM, Baichuan, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
)
lora_rank: Optional[int] = field(
default=8,
@@ -45,11 +45,11 @@ class FinetuningArguments:
default=None,
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
LLaMA-2, BlueLM, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
)
additional_target: Optional[str] = field(
default=None,
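For context, the comma-separated `lora_target` string ultimately becomes `target_modules` in peft's `LoraConfig`. A minimal sketch with illustrative LLaMA-style names (peft>=0.6 assumed):

```python
# Sketch: how a comma-separated lora_target maps onto a peft LoraConfig.
from peft import LoraConfig, TaskType

lora_target = "q_proj,v_proj"  # illustrative
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32.0,
    lora_dropout=0.1,
    target_modules=[name.strip() for name in lora_target.split(",")],
)
```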


@@ -22,10 +22,6 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
)
use_auth_token: Optional[bool] = field(
default=False,
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
)
model_revision: Optional[str] = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
@@ -66,7 +62,7 @@ class ModelArguments:
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
hf_auth_token: Optional[str] = field(
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
@@ -87,7 +83,3 @@ class ModelArguments:
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
if self.use_auth_token == True and self.hf_auth_token is not None:
from huggingface_hub.hf_api import HfFolder # lazy load
HfFolder.save_token(self.hf_auth_token)
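With `use_auth_token` and the `HfFolder.save_token` side effect removed, the token is no longer written to the global credential cache; it travels explicitly with each hub call. A sketch of the new pattern (repo name hypothetical):

```python
# Sketch: hf_hub_token is now passed per call instead of being saved globally.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # hypothetical gated repo
    token=model_args.hf_hub_token,  # `model_args` assumed in scope
)
```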


@@ -1,2 +1,3 @@
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
from llmtuner.tuner.core.loader import load_model_and_tokenizer
from llmtuner.tuner.core.utils import generate_model_card


@@ -1,6 +1,9 @@
import os
import torch
from typing import TYPE_CHECKING
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from peft import (
PeftModel,
TaskType,
@@ -23,8 +26,7 @@ def init_adapter(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
is_mergeable: bool
is_trainable: bool
) -> "PreTrainedModel":
r"""
Initializes the adapters.
@@ -61,7 +63,7 @@ def init_adapter(
latest_checkpoint = None
if model_args.checkpoint_dir is not None:
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
if is_trainable and finetuning_args.resume_lora_training: # continually fine-tuning
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir
@@ -92,10 +94,33 @@ def init_adapter(
modules_to_save=finetuning_args.additional_target
)
model = get_peft_model(model, lora_config)
if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
model.base_model.peft_config = model.peft_config
if model_args.checkpoint_dir is not None:
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
return model
def load_valuehead_params(
model: "PreTrainedModel",
model_args: "ModelArguments"
) -> None:
kwargs = {
"path_or_repo_id": model_args.reward_model,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token,
"revision": model_args.model_revision
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except Exception:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except Exception:
raise ValueError("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
vhead_params = torch.load(vhead_file, map_location="cpu")
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)


@@ -25,9 +25,8 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter
from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
from llmtuner.tuner.core.utils import prepare_model_for_training
if TYPE_CHECKING:
@@ -41,7 +40,7 @@ logger = get_logger(__name__)
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2")
@@ -64,7 +63,7 @@ def load_model_and_tokenizer(
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
"token": model_args.hf_hub_token
}
tokenizer = AutoTokenizer.from_pretrained(
@@ -99,15 +98,9 @@
# Set RoPE scaling
if model_args.rope_scaling is not None:
if hasattr(config, "use_dynamic_ntk"): # for Qwen models
if is_trainable:
logger.warning("Qwen model does not support RoPE scaling in training.")
else:
setattr(config, "use_dynamic_ntk", True)
setattr(config, "use_logn_attn", True)
logger.info("Using dynamic NTK scaling.")
elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
else:
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
@@ -129,9 +122,6 @@
model_args.rope_scaling, scaling_factor
))
else:
logger.warning("Current model does not support RoPE scaling.")
# Set FlashAttention-2
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
@@ -155,7 +145,6 @@
logger.warning("Current model does not support shift short attention.")
# Quantization configurations (using bitsandbytes library).
is_mergeable = True
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@@ -165,7 +154,7 @@
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
@@ -175,7 +164,6 @@
bnb_4bit_quant_type=model_args.quantization_type
)
is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@@ -207,7 +195,7 @@
# Initialize adapters
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
model = init_adapter(model, model_args, finetuning_args, is_trainable)
model = model.train() if is_trainable else model.eval()
# Prepare model with valuehead for RLHF
@@ -226,7 +214,7 @@
logger.info("Load reward model from {}".format(model_args.reward_model))
if getattr(model, "is_peft_model", False):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
load_valuehead_params(model, model_args)
# Prepare model for inference
if not is_trainable:
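For reference, the 4-bit branch shown in fragments above normally fills in a complete `BitsAndBytesConfig`; compute dtype and quant type come from `model_args`. A standalone sketch with illustrative values (bitsandbytes>=0.39.0 assumed):

```python
# Sketch of a full 4-bit quantization setup; values are illustrative.
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
```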


@@ -132,16 +132,12 @@ def get_train_args(
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
if model_args.quantization_bit is not None:
if len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
if not finetuning_args.resume_lora_training:
raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")
if (
model_args.checkpoint_dir is not None
and len(model_args.checkpoint_dir) != 1
and finetuning_args.finetuning_type != "lora"
):
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
@@ -216,11 +212,11 @@ def get_infer_args(
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
if (
model_args.checkpoint_dir is not None
and len(model_args.checkpoint_dir) != 1
and finetuning_args.finetuning_type != "lora"
):
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
return model_args, data_args, finetuning_args, generating_args


@@ -1,13 +1,12 @@
import torch
from types import MethodType
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import FinetuningArguments
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
logger = get_logger(__name__)
@@ -15,8 +14,7 @@ logger = get_logger(__name__)
def find_all_linear_modules(
model: "PreTrainedModel",
quantization_bit: Optional[int] = None,
output_layer_name: Optional[str] = "lm_head"
quantization_bit: Optional[int] = None
) -> List[str]:
if quantization_bit is not None:
import bitsandbytes as bnb
@@ -24,17 +22,35 @@ def find_all_linear_modules(
else:
linear_cls = torch.nn.Linear
output_layer_names = ["lm_head"]
if model.config.model_type == "chatglm":
output_layer_names.append("output_layer")
module_names = set()
for name, module in model.named_modules():
if output_layer_name not in name and isinstance(module, linear_cls):
if (
isinstance(module, linear_cls)
and not any([output_layer in name for output_layer in output_layer_names])
):
module_names.add(name.split(".")[-1])
if output_layer_name in module_names:
module_names.pop(output_layer_name)
logger.info("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
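This helper resolves every linear layer except the output head(s) into LoRA target names, which is handy when a user does not want to enumerate modules by hand. A hedged usage sketch (`quantization_bit` switches matching to the bitsandbytes linear classes):

```python
# Sketch: collect LoRA targets automatically instead of listing them manually.
from peft import LoraConfig, TaskType

target_modules = find_all_linear_modules(model, quantization_bit=None)  # `model` assumed
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, target_modules=target_modules)
```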
def generate_model_card(
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments"
) -> Dict[str, Any]:
return {
"tasks": "text-generation",
"finetuned_from": model_args.model_name_or_path,
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
}
def prepare_model_for_training(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",
@@ -56,26 +72,21 @@ def prepare_model_for_training(
logger.info("Upcasting weights in layernorm in float32.")
if finetuning_args.neft_alpha > 1e-6:
input_embed = model.get_input_embeddings()
if isinstance(input_embed, torch.nn.Embedding):
def noisy_forward(self: torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor:
embeddings = input_embed.__class__.forward(self, x)
if self.training:
dims = self.num_embeddings * self.embedding_dim
mag_norm = finetuning_args.neft_alpha / (dims ** 0.5)
embeddings += torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
return embeddings
def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
input_embed.forward = MethodType(noisy_forward, input_embed)
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
else:
logger.warning("Input embeddings are not normal nn.Embedding, cannot transform into noisy embedding.")
model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
if use_gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
@@ -86,9 +97,11 @@ def prepare_model_for_training(
if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear):
def forward_in_fp32(self, x: torch.Tensor) -> torch.Tensor:
return output_layer.__class__.forward(self, x.to(output_layer.weight.dtype)).to(torch.float32)
output_layer.forward = MethodType(forward_in_fp32, output_layer)
def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
return args[0].to(output_layer.weight.dtype)
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
output_layer.register_forward_hook(fp32_forward_post_hook)
return model
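Unlike the old `MethodType` patch, the hook-based NEFTune works for any embedding-like module and can be detached cleanly via its handle. A quick hedged sanity check, assuming a loaded `model` and a hook like `neftune_forward_hook` above in scope:

```python
# Hedged sketch: noise fires only in training mode, and the hook is removable.
import torch

embed = model.get_input_embeddings()
handle = embed.register_forward_hook(neftune_forward_hook)
ids = torch.tensor([[1, 2, 3]])

model.train()
assert not torch.equal(embed(ids), embed(ids))  # fresh noise on every call

model.eval()
assert torch.equal(embed(ids), embed(ids))  # hook passes output through unchanged
handle.remove()  # detach when done
```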


@@ -8,7 +8,7 @@ from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
@@ -52,13 +52,18 @@ def run_dpo(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")


@@ -1,4 +1,4 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
import math
from typing import TYPE_CHECKING, Optional, List
@@ -6,7 +6,7 @@ from transformers import DataCollatorForLanguageModeling, Trainer
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -38,13 +38,18 @@ def run_pt(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")


@@ -1,5 +1,4 @@
# Inspired by:
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
@@ -7,7 +6,7 @@ from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwiseTrainer
@@ -47,13 +46,18 @@ def run_rm(
# Training
if training_args.do_train:
train_result = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")


@@ -1,4 +1,4 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
@@ -7,7 +7,7 @@ from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
@@ -65,13 +65,18 @@ def run_sft(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)