Merge pull request #2746 from stephen-nju/main

fix DeepSpeed PPO RuntimeError

commit 516d0ddc66
@@ -60,7 +60,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, finetuning_args, init_kwargs, is_trainable)

     model = None
     if is_trainable and model_args.use_unsloth:
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import ModelArguments
+    from ..hparams import ModelArguments, FinetuningArguments


 logger = get_logger(__name__)
@@ -265,6 +265,7 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -289,6 +290,7 @@ def patch_config(
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
-        if "device_map" not in init_kwargs:  # quant models cannot use auto device map
-            init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
+        if "device_map" not in init_kwargs:  # quant models cannot use auto device map
+            if finetuning_args.stage not in ["ppo"]:  # ppo stage should not set device map
+                init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}

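Net effect: when finetuning_args.stage is "ppo", patch_config no longer injects a device_map into the from_pretrained kwargs, presumably because an explicit device map collides with DeepSpeed's own model placement during PPO training (the RuntimeError this PR fixes). Below is a minimal Python sketch of the resulting branch, not the repository's actual patch_config: resolve_init_kwargs is a hypothetical name, and the two helpers are stubs standing in for transformers' is_deepspeed_zero3_enabled and the repo's get_current_device.

# Minimal sketch of the device-map logic after this commit; hypothetical
# helper name, stubbed dependencies.
from typing import Any, Dict, Optional


def is_deepspeed_zero3_enabled() -> bool:
    return False  # stub: pretend ZeRO-3 is off


def get_current_device() -> str:
    return "cuda:0"  # stub: the repo resolves the local rank's device


def resolve_init_kwargs(stage: str, device_map: Optional[Any] = None) -> Dict[str, Any]:
    init_kwargs: Dict[str, Any] = {}
    if not is_deepspeed_zero3_enabled():
        init_kwargs["low_cpu_mem_usage"] = True
        if "device_map" not in init_kwargs:  # quant models cannot use auto device map
            if stage not in ["ppo"]:  # ppo stage should not set device map
                init_kwargs["device_map"] = device_map or {"": get_current_device()}
    return init_kwargs


# PPO leaves placement to DeepSpeed; other stages pin the current device.
assert "device_map" not in resolve_init_kwargs("ppo")
assert resolve_init_kwargs("sft")["device_map"] == {"": "cuda:0"}

Threading finetuning_args through load_model and patch_config (the first and third hunks) is what makes the stage check possible at this point in the code.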