fix #1696
This commit is contained in:
parent
a0fde6e421
commit
bf6f6aeefe
|
@ -156,6 +156,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
|
|||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
|
@ -171,6 +172,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
|
|||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -156,6 +156,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
|
@ -171,6 +172,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -134,6 +134,9 @@
|
|||
"webnovel": {
|
||||
"hf_hub_url": "zxbsmk/webnovel_cn"
|
||||
},
|
||||
"nectar_sft": {
|
||||
"hf_hub_url": "mlinmg/SFT-Nectar"
|
||||
},
|
||||
"adgen": {
|
||||
"hf_hub_url": "HasturOfficial/adgen",
|
||||
"columns": {
|
||||
|
@ -216,6 +219,10 @@
|
|||
"file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd",
|
||||
"ranking": true
|
||||
},
|
||||
"nectar_rm": {
|
||||
"hf_hub_url": "mlinmg/RLAIF-Nectar",
|
||||
"ranking": true
|
||||
},
|
||||
"wiki_demo": {
|
||||
"file_name": "wiki_demo.txt",
|
||||
"file_sha1": "e70375e28eda542a90c68213640cc371898ce181",
|
||||
|
@ -266,12 +273,6 @@
|
|||
"columns": {
|
||||
"prompt": "content"
|
||||
}
|
||||
"nectar_rlaif": {
|
||||
"hf_hub_url": "mlinmg/RLAIF-Nectar",
|
||||
"ranking": true
|
||||
},
|
||||
"nectar_sft": {
|
||||
"hf_hub_url": "mlinmg/SFT-Nectar"
|
||||
},
|
||||
"starcoder": {
|
||||
"hf_hub_url": "bigcode/starcoderdata",
|
||||
|
|
|
@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
|
|||
from datetime import timedelta
|
||||
|
||||
from transformers import TrainerCallback
|
||||
from transformers.modeling_utils import custom_object_save, unwrap_model
|
||||
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from llmtuner.extras.constants import LOG_FILE_NAME
|
||||
|
@ -18,6 +19,16 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
    r"""
    Writes the files of the model wrapped by a value-head model into `output_dir`.

    The wrapped model is read from `model.pretrained_model`. Its config is always
    saved; its generation config is saved only when `can_generate()` returns true.
    After that, exactly one of the following happens:

    - `model.is_peft_model` is truthy: `save_pretrained` is called on the wrapped model.
    - otherwise, if the wrapped model has a truthy `_auto_class`: the wrapped model is
      handed to `custom_object_save` together with its config.
    - otherwise: nothing further is written.

    Args:
        model: value-head wrapper exposing `pretrained_model`.
        output_dir: directory passed to every save call.
    """
    backbone = model.pretrained_model

    backbone.config.save_pretrained(output_dir)
    if backbone.can_generate():
        backbone.generation_config.save_pretrained(output_dir)

    if getattr(model, "is_peft_model", False):
        backbone.save_pretrained(output_dir)
        return

    # Reaching this point means the wrapper is not a peft model.
    # NOTE(review): `_auto_class` presumably marks custom (remote-code) models - confirm.
    if getattr(backbone, "_auto_class", None):
        custom_object_save(backbone, output_dir, config=backbone.config)
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback):
|
||||
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
|
@ -25,25 +36,17 @@ class SavePeftModelCallback(TrainerCallback):
|
|||
Event called after a checkpoint save.
|
||||
"""
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
|
||||
model.pretrained_model.config.save_pretrained(output_dir)
|
||||
if model.pretrained_model.can_generate():
|
||||
model.pretrained_model.generation_config.save_pretrained(output_dir)
|
||||
if getattr(model, "is_peft_model", False):
|
||||
model.pretrained_model.save_pretrained(output_dir)
|
||||
_save_model_with_valuehead(
|
||||
model=unwrap_model(kwargs.pop("model")),
|
||||
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
)
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
|
||||
model.pretrained_model.config.save_pretrained(args.output_dir)
|
||||
if model.pretrained_model.can_generate():
|
||||
model.pretrained_model.generation_config.save_pretrained(args.output_dir)
|
||||
if getattr(model, "is_peft_model", False):
|
||||
model.pretrained_model.save_pretrained(args.output_dir)
|
||||
_save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
|
Loading…
Reference in New Issue