From cdb7f82869b07d9d5d31b7b2aaf6b033bd00e32e Mon Sep 17 00:00:00 2001
From: stephen
Date: Fri, 8 Mar 2024 11:48:26 +0800
Subject: [PATCH 1/2] fix ppo runtime error

---
 src/llmtuner/model/loader.py  | 2 +-
 src/llmtuner/model/patcher.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index e5b3bdd1..521e9e2f 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -61,7 +61,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, finetuning_args, init_kwargs, is_trainable)
 
     model = None
     if is_trainable and model_args.use_unsloth:
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 4ecfcc86..a5d9e3b2 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -264,6 +264,7 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -288,7 +289,8 @@ def patch_config(
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
         if "device_map" not in init_kwargs:  # quant models cannot use auto device map
-            init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
+            if finetuning_args.stage not in ["ppo"]:  # ppo stage should not set device map
+                init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
 
 
 def patch_model(

From aa71571b773c5dc527b17219ec87828e4455b330 Mon Sep 17 00:00:00 2001
From: stephen_zhu
Date: Fri, 8 Mar 2024 12:47:44 +0800
Subject: [PATCH 2/2] update

---
 src/llmtuner/model/patcher.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index a5d9e3b2..00e42147 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
 
-    from ..hparams import ModelArguments
+    from ..hparams import ModelArguments, FinetuningArguments
 
 
 logger = get_logger(__name__)
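
Illustrative note (not part of the patch itself): the change above gates device_map injection on the training stage, presumably because in the ppo stage the trainer places the model itself and a pre-set device_map triggered the reported runtime error. The stand-alone Python sketch below mirrors only that gating logic; the helper name build_init_kwargs and its parameters are hypothetical and used here for illustration, not taken from the project.

from typing import Any, Dict, Optional


def build_init_kwargs(stage: str, device_map: Optional[Any], current_device: str) -> Dict[str, Any]:
    # Mirrors the patched patch_config() behavior: only inject a device_map
    # outside the ppo stage, falling back to the current device if none given.
    init_kwargs: Dict[str, Any] = {}
    if "device_map" not in init_kwargs:  # quant models cannot use auto device map
        if stage not in ["ppo"]:  # ppo stage should not set device map
            init_kwargs["device_map"] = device_map or {"": current_device}
    return init_kwargs


# ppo stage: leave placement to the trainer (no device_map key is set).
assert "device_map" not in build_init_kwargs("ppo", None, "cuda:0")
# other stages (e.g. sft): fall back to the current device.
assert build_init_kwargs("sft", None, "cuda:0") == {"device_map": {"": "cuda:0"}}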