forked from p04798526/LLaMA-Factory-Mirror
fix llamafy scripts
This commit is contained in:
parent 7ff4c874d2
commit f99140d5e8
@@ -31,7 +31,7 @@ def save_weight(
     save_safetensors: bool
 ):
     baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
-    for filepath in os.listdir(input_dir):
+    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
         if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
            baichuan2_state_dict.update(shard_weight)
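The only change in this hunk is wrapping `os.listdir(input_dir)` in `tqdm` so shard loading reports progress. A minimal self-contained sketch of the loop, assuming the script imports `tqdm` at the top (the input path here is hypothetical):

import os
from collections import OrderedDict
from typing import Dict

import torch
from tqdm import tqdm

input_dir = "path/to/Baichuan2-7B"  # hypothetical checkpoint directory
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
    if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
        # load each shard on CPU and merge its tensors into the full state dict
        shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
        baichuan2_state_dict.update(shard_weight)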
@@ -32,7 +32,7 @@ def save_weight(
         internlm2_config_dict: Dict[str, Any] = json.load(f)

     internlm2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
-    for filepath in os.listdir(input_dir):
+    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
         if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
            internlm2_state_dict.update(shard_weight)
@@ -46,7 +46,7 @@ def save_weight(
         elif "attention_norm" in key:
             llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
         elif "wqkv" in key:
-            proj_size = value.size(0) // 3
+            proj_size = value.size(0)
             num_q_heads = internlm2_config_dict["num_attention_heads"]
             num_kv_heads = internlm2_config_dict["num_key_value_heads"]
             q_size = proj_size // (num_q_heads + num_kv_heads) * num_q_heads
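Removing the `// 3` is the substantive fix: the old code treated the fused `wqkv` weight as three equal q/k/v thirds, which only holds when `num_key_value_heads` equals `num_attention_heads`; with grouped-query attention the key/value slices are smaller, so the split sizes have to come from the head counts in the config. A rough, illustrative sketch of splitting a fused qkv weight by explicit sizes (toy shapes, not the script's exact arithmetic):

import torch

# toy shapes: 4 query heads sharing 2 key/value heads, head_dim 8, hidden 32
num_q_heads, num_kv_heads, head_dim, hidden_size = 4, 2, 8, 32
fused_wqkv = torch.randn((num_q_heads + 2 * num_kv_heads) * head_dim, hidden_size)

q_size = num_q_heads * head_dim    # 32 rows for the query projection
kv_size = num_kv_heads * head_dim  # 16 rows each for key and value
q_proj, k_proj, v_proj = fused_wqkv.split([q_size, kv_size, kv_size], dim=0)

# splitting the 64 rows into equal thirds (the old `// 3` assumption) would
# misalign every projection whenever num_kv_heads < num_q_heads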
@@ -95,6 +95,7 @@ def save_config(
     llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
     llama2_config_dict.pop("auto_map", None)
     llama2_config_dict.pop("bias", None)
+    llama2_config_dict.pop("rope_scaling", None)
     llama2_config_dict["model_type"] = "llama"

     with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
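The exported config now also drops `rope_scaling`, alongside `auto_map` and `bias`, presumably so the converted config.json carries no source-model-specific keys; `dict.pop(key, None)` is a safe no-op when the key was never set. A minimal sketch of the cleanup step with a hypothetical source config:

from typing import Any, Dict

# hypothetical config loaded from the source model's config.json
llama2_config_dict: Dict[str, Any] = {
    "model_type": "internlm2",
    "bias": False,
    "rope_scaling": None,
    "num_attention_heads": 32,
}
for unused_key in ("auto_map", "bias", "rope_scaling"):
    llama2_config_dict.pop(unused_key, None)  # no-op if the key is absent
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
llama2_config_dict["model_type"] = "llama"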
@@ -37,7 +37,7 @@ def save_weight(
     save_safetensors: bool
 ) -> str:
     qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
-    for filepath in os.listdir(input_dir):
+    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
         if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
             with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
                 for key in f.keys():
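The Qwen script reads `.safetensors` shards instead of `.bin` files, so it goes through `safe_open` and pulls tensors key by key. A minimal sketch of that loading pattern, assuming `from safetensors import safe_open` at the top of the script and a hypothetical shard name:

import os
from collections import OrderedDict
from typing import Dict

import torch
from safetensors import safe_open

input_dir = "path/to/Qwen-7B"                    # hypothetical checkpoint directory
shard_name = "model-00001-of-00002.safetensors"  # hypothetical shard file
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
with safe_open(os.path.join(input_dir, shard_name), framework="pt", device="cpu") as f:
    for key in f.keys():
        qwen_state_dict[key] = f.get_tensor(key)  # materialize each tensor on CPU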