hiyouga 2023-12-01 15:34:50 +08:00
parent a0fde6e421
commit bf6f6aeefe
4 changed files with 27 additions and 19 deletions

README.md View File

@@ -156,6 +156,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
@@ -171,6 +172,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
</details>

README_zh.md View File

@@ -156,6 +156,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
@@ -171,6 +172,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
</details>

data/dataset_info.json View File

@@ -134,6 +134,9 @@
  "webnovel": {
    "hf_hub_url": "zxbsmk/webnovel_cn"
  },
+  "nectar_sft": {
+    "hf_hub_url": "mlinmg/SFT-Nectar"
+  },
  "adgen": {
    "hf_hub_url": "HasturOfficial/adgen",
    "columns": {
@@ -216,6 +219,10 @@
    "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd",
    "ranking": true
  },
+  "nectar_rm": {
+    "hf_hub_url": "mlinmg/RLAIF-Nectar",
+    "ranking": true
+  },
  "wiki_demo": {
    "file_name": "wiki_demo.txt",
    "file_sha1": "e70375e28eda542a90c68213640cc371898ce181",
@@ -266,12 +273,6 @@
    "columns": {
      "prompt": "content"
    }
  },
-  "nectar_rlaif": {
-    "hf_hub_url": "mlinmg/RLAIF-Nectar",
-    "ranking": true
-  },
-  "nectar_sft": {
-    "hf_hub_url": "mlinmg/SFT-Nectar"
-  },
  "starcoder": {
    "hf_hub_url": "bigcode/starcoderdata",

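The two entries added above follow the registry's convention: "hf_hub_url" points at a Hugging Face Hub dataset, and "ranking": true marks pairwise-preference data for the reward-modeling stage. A minimal sketch of how such an entry resolves to a download — not how llmtuner loads data internally, just an illustration of the fields, assuming the Hugging Face datasets library and the final data/dataset_info.json from this commit (the load_by_name helper is hypothetical):

import json

from datasets import load_dataset  # Hugging Face `datasets` library

with open("data/dataset_info.json") as f:
    dataset_info = json.load(f)  # the registry edited above

def load_by_name(name: str):
    # Hypothetical helper: resolve a registry entry to a Hub download.
    entry = dataset_info[name]
    return load_dataset(entry["hf_hub_url"], split="train")

sft_data = load_by_name("nectar_sft")  # mlinmg/SFT-Nectar, plain SFT data
rm_data = load_by_name("nectar_rm")    # mlinmg/RLAIF-Nectar, "ranking": true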
src/llmtuner/extras/callbacks.py View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
from datetime import timedelta
from transformers import TrainerCallback
+from transformers.modeling_utils import custom_object_save, unwrap_model
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
from llmtuner.extras.constants import LOG_FILE_NAME
@@ -18,6 +19,16 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
+def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
+    model.pretrained_model.config.save_pretrained(output_dir)
+    if model.pretrained_model.can_generate():
+        model.pretrained_model.generation_config.save_pretrained(output_dir)
+    if getattr(model, "is_peft_model", False):
+        model.pretrained_model.save_pretrained(output_dir)
+    elif getattr(model.pretrained_model, "_auto_class", None):  # must not be a peft model
+        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
class SavePeftModelCallback(TrainerCallback):
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@@ -25,25 +36,17 @@ class SavePeftModelCallback(TrainerCallback):
        Event called after a checkpoint save.
        """
        if args.should_save:
-            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
-            model.pretrained_model.config.save_pretrained(output_dir)
-            if model.pretrained_model.can_generate():
-                model.pretrained_model.generation_config.save_pretrained(output_dir)
-            if getattr(model, "is_peft_model", False):
-                model.pretrained_model.save_pretrained(output_dir)
+            _save_model_with_valuehead(
+                model=unwrap_model(kwargs.pop("model")),
+                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+            )
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        if args.should_save:
-            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
-            model.pretrained_model.config.save_pretrained(args.output_dir)
-            if model.pretrained_model.can_generate():
-                model.pretrained_model.generation_config.save_pretrained(args.output_dir)
-            if getattr(model, "is_peft_model", False):
-                model.pretrained_model.save_pretrained(args.output_dir)
+            _save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
class LogCallback(TrainerCallback):
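The refactor above deduplicates the saving logic from on_save and on_train_end into _save_model_with_valuehead: it always writes the base model's config (plus the generation config when the model can generate), then saves PEFT adapter weights for adapter models, or registers custom-code classes via custom_object_save otherwise. A minimal sketch of the helper in isolation — not part of this commit, assuming trl is installed and using sshleifer/tiny-gpt2 purely as a small public checkpoint:

from trl import AutoModelForCausalLMWithValueHead

from llmtuner.extras.callbacks import _save_model_with_valuehead

# Tiny model purely for illustration; any causal LM checkpoint works.
model = AutoModelForCausalLMWithValueHead.from_pretrained("sshleifer/tiny-gpt2")
_save_model_with_valuehead(model, output_dir="outputs/demo")
# Writes config.json and generation_config.json under outputs/demo. This
# model is neither a PEFT model nor a custom-code class, so no weights are
# written: the helper's last two branches cover adapters and models
# registered with register_for_auto_class, respectively.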