tiny update

This commit is contained in:
hiyouga 2023-12-25 18:29:34 +08:00
parent e4bb846c43
commit 5b93d545e2
5 changed files with 59 additions and 22 deletions

View File

@ -173,6 +173,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)

View File

@ -173,6 +173,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)

View File

@ -141,6 +141,9 @@
"nectar_sft": {
"hf_hub_url": "mlinmg/SFT-Nectar"
},
"deepctrl": {
"ms_hub_url": "deepctrl/deepctrl-sft-data"
},
"adgen": {
"hf_hub_url": "HasturOfficial/adgen",
"ms_hub_url": "AI-ModelScope/adgen",

View File

@ -8,9 +8,17 @@ import os
import fire
import json
import torch
from tqdm import tqdm
from collections import OrderedDict
from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from typing import Any, Dict
from safetensors.torch import save_file
from transformers.modeling_utils import (
shard_checkpoint,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME
)
from typing import Any, Dict, Optional
CONFIG_NAME = "config.json"
@ -19,7 +27,8 @@ CONFIG_NAME = "config.json"
def save_weight(
input_dir: str,
output_dir: str,
shard_size: str
shard_size: str,
save_safetensors: bool
):
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in os.listdir(input_dir):
@ -28,7 +37,7 @@ def save_weight(
baichuan2_state_dict.update(shard_weight)
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for key, value in baichuan2_state_dict.items():
for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
if "W_pack" in key:
proj_size = value.size(0) // 3
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
@ -39,14 +48,20 @@ def save_weight(
else:
llama2_state_dict[key] = value
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
for shard_file, shard in shards.items():
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
if save_safetensors:
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
else:
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
else:
with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir))
@ -71,14 +86,15 @@ def save_config(
def llamafy_baichuan2(
input_dir: str,
output_dir: str,
shard_size: str
shard_size: str,
save_safetensors: Optional[bool] = False
):
try:
os.makedirs(output_dir, exist_ok=False)
except Exception as e:
raise print("Output dir already exists", e)
save_weight(input_dir, output_dir, shard_size)
save_weight(input_dir, output_dir, shard_size, save_safetensors)
save_config(input_dir, output_dir)

View File

@ -6,11 +6,19 @@ import os
import fire
import json
import torch
from tqdm import tqdm
from collections import OrderedDict
from safetensors import safe_open
from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from safetensors.torch import save_file
from transformers.modeling_utils import (
shard_checkpoint,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME
)
from transformers.utils import check_min_version
from typing import Any, Dict
from typing import Any, Dict, Optional
try:
check_min_version("4.34.0")
@ -24,7 +32,8 @@ CONFIG_NAME = "config.json"
def save_weight(
input_dir: str,
output_dir: str,
shard_size: str
shard_size: str,
save_safetensors: bool
) -> str:
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in os.listdir(input_dir):
@ -35,7 +44,7 @@ def save_weight(
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
torch_dtype = None
for key, value in qwen_state_dict.items():
for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
if torch_dtype is None:
torch_dtype = value.dtype
if "wte" in key:
@ -69,14 +78,20 @@ def save_weight(
else:
raise KeyError("Unable to process key {}".format(key))
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
for shard_file, shard in shards.items():
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
if save_safetensors:
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
else:
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
else:
with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir))
@ -120,14 +135,15 @@ def save_config(
def llamafy_qwen(
input_dir: str,
output_dir: str,
shard_size: str
shard_size: str,
save_safetensors: Optional[bool] = False
):
try:
os.makedirs(output_dir, exist_ok=False)
except Exception as e:
raise print("Output dir already exists", e)
torch_dtype = save_weight(input_dir, output_dir, shard_size)
torch_dtype = save_weight(input_dir, output_dir, shard_size, save_safetensors)
save_config(input_dir, output_dir, torch_dtype)