From eb178eaff390a1dc342cc35ab8c7820d654f3717 Mon Sep 17 00:00:00 2001
From: marko1616
Date: Sat, 30 Mar 2024 23:45:04 +0800
Subject: [PATCH 1/2] Fix Llama model save for full-parameter training

---
 src/llmtuner/model/patcher.py | 9 +++++++++
 src/llmtuner/train/tuner.py   | 8 --------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index cb55f5ed..e3d7539f 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -312,6 +312,15 @@ def patch_config(
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
+    # Check and fix the generation config
+    gen_config = model.generation_config
+    if not gen_config.do_sample and (
+        (gen_config.temperature is not None and gen_config.temperature != 1.0)
+        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
+        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
+    ):
+        gen_config.do_sample = True
+
     if "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
 
diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py
index a03955d5..1b8e3cb7 100644
--- a/src/llmtuner/train/tuner.py
+++ b/src/llmtuner/train/tuner.py
@@ -64,14 +64,6 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     for param in model.parameters():
         param.data = param.data.to(output_dtype)
 
-    gen_config = model.generation_config  # check and fix generation config
-    if not gen_config.do_sample and (
-        (gen_config.temperature is not None and gen_config.temperature != 1.0)
-        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
-        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
-    ):
-        gen_config.do_sample = True
-
     model.save_pretrained(
         save_directory=model_args.export_dir,
         max_shard_size="{}GB".format(model_args.export_size),

From d9a5134617d494ef13ba73f9c540123e89a8c29c Mon Sep 17 00:00:00 2001
From: marko1616
Date: Sat, 30 Mar 2024 23:46:55 +0800
Subject: [PATCH 2/2] Fix blank line containing whitespace

---
 src/llmtuner/model/patcher.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index e3d7539f..03ca0096 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -320,7 +320,7 @@ def patch_model(
         or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
     ):
         gen_config.do_sample = True
-
+
     if "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
 
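
Context on why the check is needed at all: recent `transformers` releases validate a GenerationConfig before writing it to disk and reject configs that set sampling parameters (temperature, top_p, typical_p) while do_sample=False, which is the state a Llama checkpoint can end up in after full-parameter training. A minimal repro sketch, assuming a reasonably recent `transformers` release (older ones only warn rather than raise, and the parameter values below are illustrative):

from transformers import GenerationConfig

# Constructing an inconsistent config only emits a warning; the strict
# validation runs when the config is written to disk.
cfg = GenerationConfig(do_sample=False, temperature=0.6, top_p=0.9)
try:
    cfg.save_pretrained("/tmp/gen_config_demo")
except ValueError as err:
    # Rejected because sampling parameters are set while do_sample=False.
    print("save rejected:", err)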
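
The check the first patch adds to patch_model can also be exercised on its own. A standalone sketch of the same logic (fix_generation_config is an illustrative name, not an llmtuner API):

from transformers import GenerationConfig

def fix_generation_config(gen_config: GenerationConfig) -> None:
    # Sampling parameters only take effect in sampling mode, so turn
    # do_sample on whenever any of them deviates from its 1.0 default.
    if not gen_config.do_sample and (
        (gen_config.temperature is not None and gen_config.temperature != 1.0)
        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
    ):
        gen_config.do_sample = True

cfg = GenerationConfig(do_sample=False, temperature=0.6, top_p=0.9)
fix_generation_config(cfg)
assert cfg.do_sample  # config is consistent again, so saving succeeds

Moving the check out of export_model and into patch_model makes it run on every model load, presumably so that checkpoints saved during full-parameter training pass the same validation as exported models, rather than only fixing the export path.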