From 89f2bd8c8c035181927bd530a7ffc733407d674c Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Tue, 11 Jun 2024 15:38:38 +0800
Subject: [PATCH] fix #4198

---
 src/llamafactory/hparams/model_args.py  | 12 ++++++++++
 src/llamafactory/model/patcher.py       |  2 +-
 src/llamafactory/train/trainer_utils.py | 32 +++++++++++--------------
 3 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 6352a420..71467770 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -1,6 +1,8 @@
 from dataclasses import asdict, dataclass, field
 from typing import Any, Dict, Literal, Optional
 
+from typing_extensions import Self
+
 
 @dataclass
 class ModelArguments:
@@ -216,3 +218,13 @@ class ModelArguments:
 
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
+
+    @classmethod
+    def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
+        arg_dict = old_arg.to_dict()
+        arg_dict.update(**kwargs)
+        new_arg = cls(**arg_dict)
+        new_arg.compute_dtype = old_arg.compute_dtype
+        new_arg.device_map = old_arg.device_map
+        new_arg.model_max_length = old_arg.model_max_length
+        return new_arg
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 87c92315..18221a10 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -79,7 +79,7 @@ def patch_config(
     if "device_map" not in init_kwargs and model_args.device_map:
         init_kwargs["device_map"] = model_args.device_map
 
-    if init_kwargs["device_map"] == "auto":
+    if init_kwargs.get("device_map", None) == "auto":
         init_kwargs["offload_folder"] = model_args.offload_folder
 
 
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 0ddcdb11..7e9cc881 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -83,15 +83,12 @@ def create_ref_model(
     The valuehead parameter is randomly initialized since it is useless for PPO training.
     """
     if finetuning_args.ref_model is not None:
-        ref_model_args_dict = model_args.to_dict()
-        ref_model_args_dict.update(
-            dict(
-                model_name_or_path=finetuning_args.ref_model,
-                adapter_name_or_path=finetuning_args.ref_model_adapters,
-                quantization_bit=finetuning_args.ref_model_quantization_bit,
-            )
+        ref_model_args = ModelArguments.copyfrom(
+            model_args,
+            model_name_or_path=finetuning_args.ref_model,
+            adapter_name_or_path=finetuning_args.ref_model_adapters,
+            quantization_bit=finetuning_args.ref_model_quantization_bit,
         )
-        ref_model_args = ModelArguments(**ref_model_args_dict)
         ref_finetuning_args = FinetuningArguments()
         tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
         ref_model = load_model(
@@ -102,9 +99,11 @@ def create_ref_model(
         if finetuning_args.finetuning_type == "lora":
             ref_model = None
         else:
-            tokenizer = load_tokenizer(model_args)["tokenizer"]
+            ref_model_args = ModelArguments.copyfrom(model_args)
+            ref_finetuning_args = FinetuningArguments()
+            tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
             ref_model = load_model(
-                tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+                tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
             )
             logger.info("Created reference model from the model itself.")
 
@@ -139,15 +138,12 @@ def create_reward_model(
         logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
         return None
     else:
-        reward_model_args_dict = model_args.to_dict()
-        reward_model_args_dict.update(
-            dict(
-                model_name_or_path=finetuning_args.reward_model,
-                adapter_name_or_path=finetuning_args.reward_model_adapters,
-                quantization_bit=finetuning_args.reward_model_quantization_bit,
-            )
+        reward_model_args = ModelArguments.copyfrom(
+            model_args,
+            model_name_or_path=finetuning_args.reward_model,
+            adapter_name_or_path=finetuning_args.reward_model_adapters,
+            quantization_bit=finetuning_args.reward_model_quantization_bit,
         )
-        reward_model_args = ModelArguments(**reward_model_args_dict)
        reward_finetuning_args = FinetuningArguments()
         tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
         reward_model = load_model(