fix safetensors

hiyouga 2024-02-18 18:12:16 +08:00
parent 22acab8aff
commit e4e86a73f1
1 changed file with 12 additions and 5 deletions


@@ -26,6 +26,10 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel
+def change_name(name: str, old_index: int, new_index: int) -> str:
+    return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
 def block_expansion(
     model_name_or_path: str,
     output_dir: str,
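
For reference, the new change_name helper only swaps the layer index embedded in a parameter name. A minimal sketch of its behavior (the parameter name below is an illustrative LLaMA-style key, not taken from this diff):

def change_name(name: str, old_index: int, new_index: int) -> str:
    return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))

# Rename a layer-0 weight so it lands in slot 1 of the expanded model:
print(change_name("model.layers.0.self_attn.o_proj.weight", 0, 1))
# -> model.layers.1.self_attn.o_proj.weight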
@@ -41,8 +45,13 @@ def block_expansion(
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     tokenizer.save_pretrained(output_dir)
     config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)  # load the original one
+    if save_safetensors:
+        setattr(config, "tie_word_embeddings", False)  # safetensors does not allow shared weights
     model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
+        config=config,
         torch_dtype="auto",
         trust_remote_code=True,
         low_cpu_mem_usage=True,
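
A note on why tie_word_embeddings is untied here: safetensors refuses to serialize tensors that share memory, and tied input/output embeddings are exactly that. A minimal sketch of the failure mode, assuming current safetensors behavior (the file names are illustrative, and the shared-tensor error surfaces as a RuntimeError in recent versions):

import torch
from safetensors.torch import save_file

emb = torch.randn(4, 8)
state = {"embed_tokens.weight": emb, "lm_head.weight": emb}  # tied weights alias one storage

try:
    save_file(state, "tied.safetensors")
except RuntimeError as err:
    print(err)  # safetensors rejects tensors that share memory

state["lm_head.weight"] = emb.clone()  # give lm_head its own storage
save_file(state, "untied.safetensors")  # now succeeds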
@@ -58,7 +67,7 @@ def block_expansion(
     for i in range(num_layers):
         for key, value in state_dict.items():
             if ".{:d}.".format(i) in key:
-                output_state_dict[key.replace(".{:d}.".format(i), ".{:d}.".format(layer_cnt))] = value
+                output_state_dict[change_name(key, i, layer_cnt)] = value
         print("Add layer {} copied from layer {}".format(layer_cnt, i))
         layer_cnt += 1
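
The change in this hunk is a pure refactor: the new call should produce exactly what the old inline replace produced. A quick equivalence check, reusing change_name from the hunk above (the key is illustrative):

key = "model.layers.3.self_attn.q_proj.weight"
i, layer_cnt = 3, 5
assert change_name(key, i, layer_cnt) == key.replace(".{:d}.".format(i), ".{:d}.".format(layer_cnt))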
@@ -66,11 +75,9 @@ def block_expansion(
             for key, value in state_dict.items():
                 if ".{:d}.".format(i) in key:
                     if "down_proj" in key or "o_proj" in key:
-                        output_state_dict[key.replace(".{:d}.".format(i), ".{:d}.".format(layer_cnt))] = (
-                            torch.zeros_like(value)
-                        )
+                        output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
                     else:
-                        output_state_dict[key.replace(".{:d}.".format(i), ".{:d}.".format(layer_cnt))] = value
+                        output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
             print("Add layer {} expanded from layer {}".format(layer_cnt, i))
             layer_cnt += 1
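
Two details of this hunk are easy to miss. Zeroing o_proj and down_proj makes the expanded block contribute nothing at initialization, so the residual connections let it start out as an identity mapping; and switching from value to torch.clone(value) gives the expanded copy its own storage, since storing the same tensor object for both the copied and the expanded layer would again be flagged by safetensors as shared memory. A minimal sketch of both points (the tensor shape is illustrative):

import torch

value = torch.randn(8, 8)

# Reusing the tensor object would alias the copied and the expanded layer:
aliased = value
print(aliased.data_ptr() == value.data_ptr())  # True -> shared storage, rejected by safetensors

# torch.clone allocates fresh storage for the expanded copy:
cloned = torch.clone(value)
print(cloned.data_ptr() == value.data_ptr())   # False -> safe to serialize

# o_proj / down_proj of the expanded block start at zero, so the block's output
# is zero and the residual stream passes the input through unchanged:
zeroed = torch.zeros_like(value)
print(torch.count_nonzero(zeroed).item())      # 0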