hiyouga 2024-06-11 15:38:38 +08:00
parent 90e14a960d
commit 89f2bd8c8c
3 changed files with 27 additions and 19 deletions

View File

@@ -1,6 +1,8 @@
 from dataclasses import asdict, dataclass, field
 from typing import Any, Dict, Literal, Optional
 
+from typing_extensions import Self
+
 
 @dataclass
 class ModelArguments:
@@ -216,3 +218,13 @@ class ModelArguments:
 
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
+
+    @classmethod
+    def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
+        arg_dict = old_arg.to_dict()
+        arg_dict.update(**kwargs)
+        new_arg = cls(**arg_dict)
+        new_arg.compute_dtype = old_arg.compute_dtype
+        new_arg.device_map = old_arg.device_map
+        new_arg.model_max_length = old_arg.model_max_length
+        return new_arg
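
A minimal usage sketch of the new helper (the import path and checkpoint names below are illustrative assumptions, not part of this commit):

from llamafactory.hparams import ModelArguments  # assumed import path

base = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")  # illustrative checkpoint
# copyfrom() rebuilds the dataclass from to_dict(), applies keyword overrides,
# and then re-copies compute_dtype / device_map / model_max_length, which are
# assigned outside the dataclass fields and therefore not carried by asdict().
ref = ModelArguments.copyfrom(base, model_name_or_path="path/to/ref-model")
assert ref.adapter_name_or_path == base.adapter_name_or_path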

View File

@@ -79,7 +79,7 @@ def patch_config(
     if "device_map" not in init_kwargs and model_args.device_map:
         init_kwargs["device_map"] = model_args.device_map
 
-    if init_kwargs["device_map"] == "auto":
+    if init_kwargs.get("device_map", None) == "auto":
         init_kwargs["offload_folder"] = model_args.offload_folder

View File

@@ -83,15 +83,12 @@ def create_ref_model(
     The valuehead parameter is randomly initialized since it is useless for PPO training.
     """
     if finetuning_args.ref_model is not None:
-        ref_model_args_dict = model_args.to_dict()
-        ref_model_args_dict.update(
-            dict(
-                model_name_or_path=finetuning_args.ref_model,
-                adapter_name_or_path=finetuning_args.ref_model_adapters,
-                quantization_bit=finetuning_args.ref_model_quantization_bit,
-            )
+        ref_model_args = ModelArguments.copyfrom(
+            model_args,
+            model_name_or_path=finetuning_args.ref_model,
+            adapter_name_or_path=finetuning_args.ref_model_adapters,
+            quantization_bit=finetuning_args.ref_model_quantization_bit,
         )
-        ref_model_args = ModelArguments(**ref_model_args_dict)
         ref_finetuning_args = FinetuningArguments()
         tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
         ref_model = load_model(
@@ -102,9 +99,11 @@ def create_ref_model(
         if finetuning_args.finetuning_type == "lora":
            ref_model = None
         else:
-            tokenizer = load_tokenizer(model_args)["tokenizer"]
+            ref_model_args = ModelArguments.copyfrom(model_args)
+            ref_finetuning_args = FinetuningArguments()
+            tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
             ref_model = load_model(
-                tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+                tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
             )
             logger.info("Created reference model from the model itself.")
@@ -139,15 +138,12 @@ def create_reward_model(
         logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
         return None
     else:
-        reward_model_args_dict = model_args.to_dict()
-        reward_model_args_dict.update(
-            dict(
-                model_name_or_path=finetuning_args.reward_model,
-                adapter_name_or_path=finetuning_args.reward_model_adapters,
-                quantization_bit=finetuning_args.reward_model_quantization_bit,
-            )
+        reward_model_args = ModelArguments.copyfrom(
+            model_args,
+            model_name_or_path=finetuning_args.reward_model,
+            adapter_name_or_path=finetuning_args.reward_model_adapters,
+            quantization_bit=finetuning_args.reward_model_quantization_bit,
         )
-        reward_model_args = ModelArguments(**reward_model_args_dict)
         reward_finetuning_args = FinetuningArguments()
         tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
         reward_model = load_model(