diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 6a906c74..9cfdd84d 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -66,8 +66,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
 
 def get_current_device() -> str:
     import accelerate
-    from accelerate import Accelerator
-    dummy_accelerator = Accelerator()
+    dummy_accelerator = accelerate.Accelerator()
     if accelerate.utils.is_xpu_available():
         return "xpu:{}".format(dummy_accelerator.local_process_index)
     else:
diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py
index fb9a05e7..f12acb58 100644
--- a/src/llmtuner/model/__init__.py
+++ b/src/llmtuner/model/__init__.py
@@ -2,4 +2,4 @@
 
 from llmtuner.model.loader import load_model_and_tokenizer
 from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
-from llmtuner.model.utils import dispatch_model, generate_model_card, load_valuehead_params
+from llmtuner.model.utils import dispatch_model, get_modelcard_args, load_valuehead_params
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 5c0b1655..20c35b06 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -71,7 +71,7 @@ def find_all_linear_modules(
     return list(module_names)
 
 
-def generate_model_card(
+def get_modelcard_args(
     model_args: "ModelArguments",
     data_args: "DataArguments",
     finetuning_args: "FinetuningArguments"
diff --git a/src/llmtuner/train/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py
index 5281f4e4..6b5a222d 100644
--- a/src/llmtuner/train/dpo/workflow.py
+++ b/src/llmtuner/train/dpo/workflow.py
@@ -1,6 +1,5 @@
 # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
 
-from peft import PeftModel
 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
@@ -8,10 +7,10 @@ from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.hparams import ModelArguments
-from llmtuner.model import generate_model_card, load_model_and_tokenizer
-from llmtuner.train.utils import create_ref_model
+from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
 from llmtuner.train.dpo.trainer import CustomDPOTrainer
+from llmtuner.train.utils import create_modelcard_and_push, create_ref_model
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -78,8 +77,4 @@ def run_dpo(
         trainer.save_metrics("eval", metrics)
 
     # Create model card
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/src/llmtuner/train/pt/workflow.py b/src/llmtuner/train/pt/workflow.py
index 41bf31ba..eadfa001 100644
--- a/src/llmtuner/train/pt/workflow.py
+++ b/src/llmtuner/train/pt/workflow.py
@@ -6,7 +6,8 @@ from transformers import DataCollatorForLanguageModeling, Trainer
 
 from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.model import load_model_and_tokenizer
+from llmtuner.train.utils import create_modelcard_and_push
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -58,8 +59,4 @@ def run_pt(
         trainer.save_metrics("eval", metrics)
 
     # Create model card
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/src/llmtuner/train/rm/workflow.py b/src/llmtuner/train/rm/workflow.py
index 06f39702..a2e97e8d 100644
--- a/src/llmtuner/train/rm/workflow.py
+++ b/src/llmtuner/train/rm/workflow.py
@@ -6,10 +6,11 @@ from transformers import Seq2SeqTrainingArguments
 from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
 from llmtuner.train.rm.metric import compute_accuracy
 from llmtuner.train.rm.trainer import PairwiseTrainer
+from llmtuner.train.utils import create_modelcard_and_push
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -68,8 +69,4 @@ def run_rm(
         trainer.save_predictions(predict_results)
 
     # Create model card
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py
index 8a802c9b..4e504903 100644
--- a/src/llmtuner/train/sft/workflow.py
+++ b/src/llmtuner/train/sft/workflow.py
@@ -7,9 +7,10 @@ from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.sft.metric import ComputeMetrics
 from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
+from llmtuner.train.utils import create_modelcard_and_push
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -90,8 +91,4 @@ def run_sft(
         trainer.save_predictions(predict_results)
 
     # Create model card
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
-        else:
-            trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index f41c7cc7..6b40f33b 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -3,16 +3,35 @@ from typing import TYPE_CHECKING, Literal, Union
 
 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import ModelArguments, FinetuningArguments
-from llmtuner.model import load_model_and_tokenizer, load_valuehead_params
+from llmtuner.model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
 
 if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, Trainer
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
+    from llmtuner.hparams import DataArguments
 
 
 logger = get_logger(__name__)
 
 
+def create_modelcard_and_push(
+    trainer: "Trainer",
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments"
+) -> None:
+    if training_args.do_train:
+        if training_args.push_to_hub:
+            trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args))
+            return
+        try:
+            trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args))
+        except Exception as err:
+            logger.warning("Failed to create model card: {}".format(str(err)))
+
+
 def create_ref_model(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",