Merge pull request #2746 from stephen-nju/main

fix deepspeed ppo RuntimeError
This commit is contained in:
hoshi-hiyouga 2024-03-09 01:37:00 +08:00 committed by GitHub
commit 516d0ddc66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 3 deletions

View File

@ -60,7 +60,7 @@ def load_model(
"""
init_kwargs = _get_init_kwargs(model_args)
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
patch_config(config, tokenizer, model_args,finetuning_args, init_kwargs, is_trainable)
model = None
if is_trainable and model_args.use_unsloth:

View File

@ -24,7 +24,7 @@ if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments
from ..hparams import ModelArguments,FinetuningArguments
logger = get_logger(__name__)
@ -265,6 +265,7 @@ def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
init_kwargs: Dict[str, Any],
is_trainable: bool,
) -> None:
@ -289,7 +290,8 @@ def patch_config(
if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if "device_map" not in init_kwargs: # quant models cannot use auto device map
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
if finetuning_args.stage not in ["ppo"]: #ppo stage should not set device map
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
def patch_model(