fix #4198

parent 90e14a960d
commit 89f2bd8c8c

@@ -1,6 +1,8 @@
 from dataclasses import asdict, dataclass, field
 from typing import Any, Dict, Literal, Optional
 
+from typing_extensions import Self
+
 
 @dataclass
 class ModelArguments:

@@ -216,3 +218,13 @@ class ModelArguments:
 
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
+
+    @classmethod
+    def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
+        arg_dict = old_arg.to_dict()
+        arg_dict.update(**kwargs)
+        new_arg = cls(**arg_dict)
+        new_arg.compute_dtype = old_arg.compute_dtype
+        new_arg.device_map = old_arg.device_map
+        new_arg.model_max_length = old_arg.model_max_length
+        return new_arg
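
Note: the explicit attribute copies at the end of copyfrom() appear to be needed because compute_dtype, device_map and model_max_length are assigned outside the declared dataclass fields (in __post_init__ or later, during patching), so a plain to_dict()/re-construction round trip would drop them. A minimal, self-contained sketch of that behaviour, using a hypothetical ToyArgs class rather than the real ModelArguments:

    from dataclasses import asdict, dataclass
    from typing import Any, Dict


    @dataclass
    class ToyArgs:
        model_name_or_path: str = "base-model"

        def __post_init__(self):
            self.compute_dtype = None  # filled in later, e.g. while patching the config

        def to_dict(self) -> Dict[str, Any]:
            return asdict(self)  # only declared fields survive


    old_arg = ToyArgs()
    old_arg.compute_dtype = "bfloat16"          # runtime value assigned after construction

    naive_copy = ToyArgs(**old_arg.to_dict())   # __post_init__ runs again on the copy
    assert naive_copy.compute_dtype is None     # the runtime value was lost

    fixed_copy = ToyArgs(**old_arg.to_dict())
    fixed_copy.compute_dtype = old_arg.compute_dtype  # what copyfrom() does by hand
    assert fixed_copy.compute_dtype == "bfloat16"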

@@ -79,7 +79,7 @@ def patch_config(
     if "device_map" not in init_kwargs and model_args.device_map:
         init_kwargs["device_map"] = model_args.device_map
 
-    if init_kwargs["device_map"] == "auto":
+    if init_kwargs.get("device_map", None) == "auto":
         init_kwargs["offload_folder"] = model_args.offload_folder
 
 
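
Note: a standalone reproduction of the failure mode behind this one-line change. When neither the caller nor model_args supplies a device_map, the key never lands in init_kwargs, so the old direct lookup raises KeyError; the dict below is a hypothetical stand-in for the real init_kwargs.

    init_kwargs = {}  # no device_map configured anywhere

    try:
        if init_kwargs["device_map"] == "auto":  # old check
            pass
    except KeyError:
        print("old check crashes when device_map is unset")

    if init_kwargs.get("device_map", None) == "auto":  # new check: simply evaluates to False
        init_kwargs["offload_folder"] = "offload"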

@@ -83,15 +83,12 @@ def create_ref_model(
     The valuehead parameter is randomly initialized since it is useless for PPO training.
     """
     if finetuning_args.ref_model is not None:
-        ref_model_args_dict = model_args.to_dict()
-        ref_model_args_dict.update(
-            dict(
-                model_name_or_path=finetuning_args.ref_model,
-                adapter_name_or_path=finetuning_args.ref_model_adapters,
-                quantization_bit=finetuning_args.ref_model_quantization_bit,
-            )
-        )
-        ref_model_args = ModelArguments(**ref_model_args_dict)
+        ref_model_args = ModelArguments.copyfrom(
+            model_args,
+            model_name_or_path=finetuning_args.ref_model,
+            adapter_name_or_path=finetuning_args.ref_model_adapters,
+            quantization_bit=finetuning_args.ref_model_quantization_bit,
+        )
         ref_finetuning_args = FinetuningArguments()
         tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
         ref_model = load_model(

@@ -102,9 +99,11 @@ def create_ref_model(
         if finetuning_args.finetuning_type == "lora":
             ref_model = None
         else:
-            tokenizer = load_tokenizer(model_args)["tokenizer"]
+            ref_model_args = ModelArguments.copyfrom(model_args)
+            ref_finetuning_args = FinetuningArguments()
+            tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
             ref_model = load_model(
-                tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+                tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
             )
             logger.info("Created reference model from the model itself.")
 
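
Note: a hedged usage sketch of the non-LoRA branch above; the import path and model id are assumptions based on the repo layout. The point of the change is that the reference model now gets its own argument objects, so anything mutated while loading it cannot leak back into the trainee's model_args.

    from llamafactory.hparams import FinetuningArguments, ModelArguments  # assumed import path

    model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")  # example id

    # What create_ref_model() now does when no separate ref_model is configured:
    ref_model_args = ModelArguments.copyfrom(model_args)
    ref_finetuning_args = FinetuningArguments()

    assert ref_model_args is not model_args  # independent copy
    assert ref_model_args.model_name_or_path == model_args.model_name_or_path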

@@ -139,15 +138,12 @@ def create_reward_model(
         logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
         return None
     else:
-        reward_model_args_dict = model_args.to_dict()
-        reward_model_args_dict.update(
-            dict(
-                model_name_or_path=finetuning_args.reward_model,
-                adapter_name_or_path=finetuning_args.reward_model_adapters,
-                quantization_bit=finetuning_args.reward_model_quantization_bit,
-            )
-        )
-        reward_model_args = ModelArguments(**reward_model_args_dict)
+        reward_model_args = ModelArguments.copyfrom(
+            model_args,
+            model_name_or_path=finetuning_args.reward_model,
+            adapter_name_or_path=finetuning_args.reward_model_adapters,
+            quantization_bit=finetuning_args.reward_model_quantization_bit,
+        )
         reward_finetuning_args = FinetuningArguments()
         tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
         reward_model = load_model(
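
Note: the reward-model branch uses the same helper, this time with keyword overrides. A hedged sketch (import path and checkpoint names are assumptions standing in for the finetuning_args.reward_model* values):

    from llamafactory.hparams import ModelArguments  # assumed import path

    base_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")  # example id
    reward_model_args = ModelArguments.copyfrom(
        base_args,
        model_name_or_path="my-org/reward-model",  # hypothetical reward checkpoint
        quantization_bit=None,
    )

    assert reward_model_args.model_name_or_path == "my-org/reward-model"  # override applied
    assert reward_model_args.compute_dtype == base_args.compute_dtype  # runtime attribute carried over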