From 5b93d545e2090d8d6db2cee3a047565f834e87f1 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 25 Dec 2023 18:29:34 +0800
Subject: [PATCH] tiny update

---
 README.md                  |  1 +
 README_zh.md               |  1 +
 data/dataset_info.json     |  3 +++
 tests/llamafy_baichuan2.py | 36 ++++++++++++++++++++++++++++++++++----------
 tests/llamafy_qwen.py      | 40 ++++++++++++++++++++++++++++++++++++------------
 5 files changed, 59 insertions(+), 22 deletions(-)

diff --git a/README.md b/README.md
index 3cb26ae8..cdcb2046 100644
--- a/README.md
+++ b/README.md
@@ -173,6 +173,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
 - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
 - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
 - [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
 - [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
diff --git a/README_zh.md b/README_zh.md
index ac47fbec..2adef5af 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -173,6 +173,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
 - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
 - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
 - [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
 - [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
diff --git a/data/dataset_info.json b/data/dataset_info.json
index 9f0273c3..bc031d76 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -141,6 +141,9 @@
   "nectar_sft": {
     "hf_hub_url": "mlinmg/SFT-Nectar"
   },
+  "deepctrl": {
+    "ms_hub_url": "deepctrl/deepctrl-sft-data"
+  },
   "adgen": {
     "hf_hub_url": "HasturOfficial/adgen",
     "ms_hub_url": "AI-ModelScope/adgen",
diff --git a/tests/llamafy_baichuan2.py b/tests/llamafy_baichuan2.py
index d08eee1c..c7625128 100644
--- a/tests/llamafy_baichuan2.py
+++ b/tests/llamafy_baichuan2.py
@@ -8,9 +8,17 @@ import os
 import fire
 import json
 import torch
+from tqdm import tqdm
 from collections import OrderedDict
-from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
-from typing import Any, Dict
+from safetensors.torch import save_file
+from transformers.modeling_utils import (
+    shard_checkpoint,
+    SAFE_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME
+)
+from typing import Any, Dict, Optional
 
 
 CONFIG_NAME = "config.json"
@@ -19,7 +27,8 @@ CONFIG_NAME = "config.json"
 def save_weight(
     input_dir: str,
     output_dir: str,
-    shard_size: str
+    shard_size: str,
+    save_safetensors: bool
 ):
     baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
     for filepath in os.listdir(input_dir):
@@ -28,7 +37,7 @@ def save_weight(
             baichuan2_state_dict.update(shard_weight)
 
     llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
-    for key, value in baichuan2_state_dict.items():
+    for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
         if "W_pack" in key:
             proj_size = value.size(0) // 3
             llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
@@ -39,14 +48,20 @@ def save_weight(
         else:
             llama2_state_dict[key] = value
 
-    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
-    for shard_file, shard in shards.items():
-        torch.save(shard, os.path.join(output_dir, shard_file))
+    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
+    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
+
+    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
+        if save_safetensors:
+            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
+        else:
+            torch.save(shard, os.path.join(output_dir, shard_file))
 
     if index is None:
         print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
     else:
-        with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
+        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
+        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)
         print("Model weights saved in {}".format(output_dir))
 
@@ -71,14 +86,15 @@ def save_config(
 def llamafy_baichuan2(
     input_dir: str,
     output_dir: str,
-    shard_size: str
+    shard_size: str,
+    save_safetensors: Optional[bool] = False
 ):
     try:
         os.makedirs(output_dir, exist_ok=False)
     except Exception as e:
         raise print("Output dir already exists", e)
 
-    save_weight(input_dir, output_dir, shard_size)
+    save_weight(input_dir, output_dir, shard_size, save_safetensors)
     save_config(input_dir, output_dir)
 
 
diff --git a/tests/llamafy_qwen.py b/tests/llamafy_qwen.py
index 8b9fc395..354028e8 100644
--- a/tests/llamafy_qwen.py
+++ b/tests/llamafy_qwen.py
@@ -6,11 +6,19 @@ import os
 import fire
 import json
 import torch
+from tqdm import tqdm
 from collections import OrderedDict
 from safetensors import safe_open
-from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from safetensors.torch import save_file
+from transformers.modeling_utils import (
+    shard_checkpoint,
+    SAFE_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME
+)
 from transformers.utils import check_min_version
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 try:
     check_min_version("4.34.0")
@@ -24,7 +32,8 @@ CONFIG_NAME = "config.json"
 def save_weight(
     input_dir: str,
     output_dir: str,
-    shard_size: str
+    shard_size: str,
+    save_safetensors: bool
 ) -> str:
     qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
     for filepath in os.listdir(input_dir):
@@ -35,7 +44,7 @@ def save_weight(
 
     llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
     torch_dtype = None
-    for key, value in qwen_state_dict.items():
+    for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
         if torch_dtype is None:
             torch_dtype = value.dtype
         if "wte" in key:
@@ -69,14 +78,20 @@ def save_weight(
         else:
             raise KeyError("Unable to process key {}".format(key))
 
-    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
-    for shard_file, shard in shards.items():
-        torch.save(shard, os.path.join(output_dir, shard_file))
+    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
+    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
+
+    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
+        if save_safetensors:
+            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
+        else:
+            torch.save(shard, os.path.join(output_dir, shard_file))
 
     if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
     else:
-        with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
+        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
+        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)
         print("Model weights saved in {}".format(output_dir))
 
@@ -120,15 +135,16 @@ def save_config(
 def llamafy_qwen(
     input_dir: str,
     output_dir: str,
-    shard_size: str
+    shard_size: str,
+    save_safetensors: Optional[bool] = False
 ):
     try:
         os.makedirs(output_dir, exist_ok=False)
     except Exception as e:
         raise print("Output dir already exists", e)
 
-    torch_dtype = save_weight(input_dir, output_dir, shard_size)
-    save_config(input_dir, output_dir, torch_dtype)
+    torch_dtype = save_weight(input_dir, output_dir, shard_size, save_safetensors)
+    save_config(input_dir, output_dir, torch_dtype)
 
 
 if __name__ == "__main__":
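
Editor's note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the safetensors-aware save path that both converted scripts now share: pick the format-specific weights filename, shard the state dict with transformers' shard_checkpoint, write each shard with safetensors.torch.save_file or torch.save, and emit the weight index only when the model was split into multiple shards. The function name save_state_dict, the toy state dict, and the /tmp output path are placeholders, not part of the repository; real usage would pass the converted LLaMA-style weights produced by llamafy_baichuan2.py or llamafy_qwen.py, whose new save_safetensors argument is presumably exposed on the command line through fire.

# Standalone sketch (assumptions noted above); mirrors the save path added in this patch.
import json
import os
from typing import Dict

import torch
from safetensors.torch import save_file
from transformers.modeling_utils import (
    shard_checkpoint,
    SAFE_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME
)


def save_state_dict(
    state_dict: Dict[str, torch.Tensor],
    output_dir: str,
    shard_size: str = "2GB",
    save_safetensors: bool = True
):
    os.makedirs(output_dir, exist_ok=True)
    # Choose the canonical filename for the selected format, then let
    # transformers split the state dict into shards no larger than shard_size.
    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(state_dict, max_shard_size=shard_size, weights_name=weights_name)

    for shard_file, shard in shards.items():
        if save_safetensors:
            # safetensors stores raw tensor data plus a small JSON header;
            # the {"format": "pt"} metadata marks the shard as PyTorch weights.
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is not None:
        # More than one shard: also write the index that maps each tensor
        # name to the shard file that contains it.
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)


if __name__ == "__main__":
    # Tiny placeholder state dict: fits in one shard, so no index file is written.
    save_state_dict({"linear.weight": torch.zeros(8, 8)}, "/tmp/llamafy_sketch")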