hiyouga 2023-11-19 14:15:47 +08:00
parent ff6056405d
commit 1740131d63
8 changed files with 35 additions and 31 deletions

View File

@@ -66,8 +66,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:

 def get_current_device() -> str:
     import accelerate
-    from accelerate import Accelerator
-    dummy_accelerator = Accelerator()
+    dummy_accelerator = accelerate.Accelerator()
     if accelerate.utils.is_xpu_available():
         return "xpu:{}".format(dummy_accelerator.local_process_index)
     else:
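For reference, a minimal sketch of how the patched helper might read as a whole, since the hunk cuts off right after the else:. The accelerate calls are taken from the diff; the CUDA/CPU fallback in the else branch is an assumption for illustration, not shown in this commit.

import torch


def get_current_device() -> str:
    import accelerate
    # The refactor drops `from accelerate import Accelerator` and goes through the module instead.
    dummy_accelerator = accelerate.Accelerator()
    if accelerate.utils.is_xpu_available():
        return "xpu:{}".format(dummy_accelerator.local_process_index)
    else:
        # Assumed fallback (not part of the hunk): local CUDA device if available, otherwise CPU.
        return "cuda:{}".format(dummy_accelerator.local_process_index) if torch.cuda.is_available() else "cpu"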

View File

@@ -2,4 +2,4 @@
 from llmtuner.model.loader import load_model_and_tokenizer
 from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
-from llmtuner.model.utils import dispatch_model, generate_model_card, load_valuehead_params
+from llmtuner.model.utils import dispatch_model, get_modelcard_args, load_valuehead_params

View File

@@ -71,7 +71,7 @@ def find_all_linear_modules(
     return list(module_names)


-def generate_model_card(
+def get_modelcard_args(
     model_args: "ModelArguments",
     data_args: "DataArguments",
     finetuning_args: "FinetuningArguments"
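The hunk above only shows the rename from generate_model_card to get_modelcard_args; the function body is not part of this diff. Judging from how the callers unpack its result into trainer.push_to_hub(**...) and trainer.create_model_card(**...), it presumably returns keyword arguments that those Trainer methods accept. A hedged sketch of such a body follows; the specific keys and values are illustrative assumptions, not a copy of the repository code.

from typing import Any, Dict


def get_modelcard_args(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    finetuning_args: "FinetuningArguments"
) -> Dict[str, Any]:
    # Keyword arguments understood by transformers' Trainer.create_model_card / push_to_hub.
    return {
        "tasks": "text-generation",
        "finetuned_from": model_args.model_name_or_path,
        "dataset": [name.strip() for name in data_args.dataset.split(",")],
        "tags": ["llama-factory", finetuning_args.finetuning_type],  # illustrative tags
    }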

View File

@@ -1,6 +1,5 @@
 # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py

-from peft import PeftModel
 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
@@ -8,10 +7,10 @@ from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.hparams import ModelArguments
-from llmtuner.model import generate_model_card, load_model_and_tokenizer
-from llmtuner.train.utils import create_ref_model
+from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
 from llmtuner.train.dpo.trainer import CustomDPOTrainer
+from llmtuner.train.utils import create_modelcard_and_push, create_ref_model

 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -78,8 +77,4 @@ def run_dpo(
         trainer.save_metrics("eval", metrics)

     # Create model card
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -6,7 +6,8 @@ from transformers import DataCollatorForLanguageModeling, Trainer

 from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.model import load_model_and_tokenizer
+from llmtuner.train.utils import create_modelcard_and_push

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -58,8 +59,4 @@ def run_pt(
         trainer.save_metrics("eval", metrics)

     # Create model card
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -6,10 +6,11 @@ from transformers import Seq2SeqTrainingArguments
 from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
 from llmtuner.train.rm.metric import compute_accuracy
 from llmtuner.train.rm.trainer import PairwiseTrainer
+from llmtuner.train.utils import create_modelcard_and_push

 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -68,8 +69,4 @@ def run_rm(
         trainer.save_predictions(predict_results)

     # Create model card
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -7,9 +7,10 @@ from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.sft.metric import ComputeMetrics
 from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
+from llmtuner.train.utils import create_modelcard_and_push

 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -90,8 +91,4 @@ def run_sft(
         trainer.save_predictions(predict_results)

     # Create model card
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -3,16 +3,35 @@ from typing import TYPE_CHECKING, Literal, Union

 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import ModelArguments, FinetuningArguments
-from llmtuner.model import load_model_and_tokenizer, load_valuehead_params
+from llmtuner.model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params

 if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, Trainer
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
+    from llmtuner.hparams import DataArguments


 logger = get_logger(__name__)


+def create_modelcard_and_push(
+    trainer: "Trainer",
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments"
+) -> None:
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args))
+            return
+        try:
+            trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args))
+        except Exception as err:
+            logger.warning("Failed to create model card: {}".format(str(err)))
+
+
 def create_ref_model(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
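To see what the consolidation buys, here is a small hedged check that drives the new helper with mocked stand-ins (it assumes the llmtuner package is importable; the MagicMock objects and the patched return value are illustration only, not part of the commit).

from unittest.mock import MagicMock, patch

from llmtuner.train.utils import create_modelcard_and_push

trainer = MagicMock()
model_args, data_args, finetuning_args = MagicMock(), MagicMock(), MagicMock()
card_args = {"tasks": "text-generation"}  # stand-in for whatever get_modelcard_args returns

# Training ran and push_to_hub is set: the card kwargs go straight to trainer.push_to_hub.
training_args = MagicMock(do_train=True, push_to_hub=True)
with patch("llmtuner.train.utils.get_modelcard_args", return_value=card_args):
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
trainer.push_to_hub.assert_called_once_with(tasks="text-generation")

# Training ran without push_to_hub: a local model card is written instead.
training_args = MagicMock(do_train=True, push_to_hub=False)
with patch("llmtuner.train.utils.get_modelcard_args", return_value=card_args):
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
trainer.create_model_card.assert_called_once_with(tasks="text-generation")

Because create_model_card is wrapped in try/except, a Hub or metadata hiccup now only logs a warning instead of aborting a finished training run.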