fix llamafy scripts

This commit is contained in:
hiyouga 2024-01-18 00:37:37 +08:00
parent 7ff4c874d2
commit f99140d5e8
3 changed files with 5 additions and 4 deletions

View File

@ -31,7 +31,7 @@ def save_weight(
save_safetensors: bool save_safetensors: bool
): ):
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict() baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in os.listdir(input_dir): for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"): if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu") shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
baichuan2_state_dict.update(shard_weight) baichuan2_state_dict.update(shard_weight)

View File

@ -32,7 +32,7 @@ def save_weight(
internlm2_config_dict: Dict[str, Any] = json.load(f) internlm2_config_dict: Dict[str, Any] = json.load(f)
internlm2_state_dict: Dict[str, torch.Tensor] = OrderedDict() internlm2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in os.listdir(input_dir): for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"): if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu") shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
internlm2_state_dict.update(shard_weight) internlm2_state_dict.update(shard_weight)
@ -46,7 +46,7 @@ def save_weight(
elif "attention_norm" in key: elif "attention_norm" in key:
llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
elif "wqkv" in key: elif "wqkv" in key:
proj_size = value.size(0) // 3 proj_size = value.size(0)
num_q_heads = internlm2_config_dict["num_attention_heads"] num_q_heads = internlm2_config_dict["num_attention_heads"]
num_kv_heads = internlm2_config_dict["num_key_value_heads"] num_kv_heads = internlm2_config_dict["num_key_value_heads"]
q_size = proj_size // (num_q_heads + num_kv_heads) * num_q_heads q_size = proj_size // (num_q_heads + num_kv_heads) * num_q_heads
@ -95,6 +95,7 @@ def save_config(
llama2_config_dict["architectures"] = ["LlamaForCausalLM"] llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
llama2_config_dict.pop("auto_map", None) llama2_config_dict.pop("auto_map", None)
llama2_config_dict.pop("bias", None) llama2_config_dict.pop("bias", None)
llama2_config_dict.pop("rope_scaling", None)
llama2_config_dict["model_type"] = "llama" llama2_config_dict["model_type"] = "llama"
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:

View File

@ -37,7 +37,7 @@ def save_weight(
save_safetensors: bool save_safetensors: bool
) -> str: ) -> str:
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict() qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in os.listdir(input_dir): for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"): if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f: with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
for key in f.keys(): for key in f.keys():