forked from p83651209/CPM-9G-8B
106 lines
3.1 KiB
Python
106 lines
3.1 KiB
Python
import torch
|
|
from transformers.configuration_utils import PretrainedConfig
|
|
|
|
|
|
class DragonflyConfig(PretrainedConfig):
    """Configuration object for the ``fm9g_dragonfly`` model type.

    Each constructor argument is stored on the instance under its own name,
    with two exceptions: ``dtype`` is kept as the string attribute ``_dtype``
    (the read-only ``dtype`` property converts it to a ``torch`` dtype), and
    ``use_bfloat16`` is derived from that string. Any additional keyword
    arguments are forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "fm9g_dragonfly"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Keys are the common Hugging Face attribute names; values are the
    # attribute names this config actually stores the data under.
    attribute_map = {
        "num_key_value_heads": "num_kv_heads",
        "hidden_act": "activate_fn",
        "hidden_size": "dim_model",
        "num_attention_heads": "num_heads",
        "intermediate_size": "dim_ff",
        "num_hidden_layers": "num_layers",
        "vocab_size": "vocab_size",
        "rms_norm_eps": "eps",
        "scale_emb": "scale_emb",
        "scale_depth": "scale_depth",
        "scale": "scale",
        "attention_scale": "attention_scale",
        "qk_norm": "qk_norm",
        "ffn_gated": "ffn_gated",
    }

    def __init__(
        self,
        vocab_size=122753,  # TODO: do we need to change to 122880 = 960 * 128?
        dim_model=4096,
        num_heads=32,
        num_kv_heads=32,
        dim_head=128,
        dim_ff=11008,
        num_layers=32,
        dropout_p=0.0,
        activate_fn="silu",
        scale=False,
        scale_emb: float = 1.0,
        scale_depth: float = -1,
        dim_model_base: int = 256,
        eps=1e-5,
        init_std=0.02,
        dtype="bf16",
        base=10000,
        qk_norm=False,
        tie_lm_head=False,
        max_length=8192,
        pose_prob=0.0,
        pose_scaling_factor=1,
        rope_scaling_type="",
        rope_scaling_factor=1,
        orig_max_length=8192,
        tp=0,
        use_checkpoint=True,
        **kwargs,
    ):
        """Store the model hyper-parameters and initialise the base config.

        Args:
            vocab_size: Stored as ``vocab_size``.
            dim_model: Stored as ``dim_model`` (also reachable as
                ``hidden_size`` through ``attribute_map``).
            num_heads: Stored as ``num_heads`` (alias ``num_attention_heads``).
            num_kv_heads: Stored as ``num_kv_heads`` (alias
                ``num_key_value_heads``).
            dim_head: Stored as ``dim_head``.
            dim_ff: Stored as ``dim_ff`` (alias ``intermediate_size``).
            num_layers: Stored as ``num_layers`` (alias ``num_hidden_layers``).
            dropout_p: Stored as ``dropout_p``.
            activate_fn: Stored as ``activate_fn`` (alias ``hidden_act``).
            scale: When truthy, ``scale_width`` returns
                ``dim_model / dim_model_base``; otherwise it returns ``1.0``.
            scale_emb: Stored as ``scale_emb``.
            scale_depth: Stored as ``scale_depth``.
            dim_model_base: Divisor used by ``scale_width``.
            eps: Stored as ``eps`` (alias ``rms_norm_eps``).
            init_std: Stored as ``init_std``.
            dtype: Precision name stored as ``_dtype``. The ``dtype`` property
                recognises ``"bf16"``, ``"fp16"`` and ``"float32"``.
            base: Stored as ``base``.
            qk_norm: Stored as ``qk_norm``.
            tie_lm_head: Stored as ``tie_lm_head``.
            max_length: Stored as ``max_length``.
            pose_prob: Stored as ``pose_prob``.
            pose_scaling_factor: Stored as ``pose_scaling_factor``.
            rope_scaling_type: Stored as ``rope_scaling_type``.
            rope_scaling_factor: Stored as ``rope_scaling_factor``.
            orig_max_length: Stored as ``orig_max_length``.
            tp: Stored as ``tp``.
            use_checkpoint: Stored as ``use_checkpoint`` and echoed to stdout.
            **kwargs: Extra fields forwarded to ``PretrainedConfig.__init__``.
                An ``architectures`` entry, if present, is replaced by
                ``["fm9gDragonflyForCausalLM"]``.
        """
        self.vocab_size = vocab_size
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.dim_head = dim_head
        self.dim_ff = dim_ff
        self.num_layers = num_layers
        self.dropout_p = dropout_p
        self.activate_fn = activate_fn
        self.scale = scale
        self.scale_emb = scale_emb
        # Kept under a private name because ``dtype`` is a read-only property.
        self._dtype = dtype
        self.dim_model_base = dim_model_base
        self.scale_depth = scale_depth
        self.eps = eps
        self.init_std = init_std
        self.base = base
        self.qk_norm = qk_norm
        self.tie_lm_head = tie_lm_head
        self.use_bfloat16 = self._dtype == "bf16"
        self.pose_prob = pose_prob
        self.pose_scaling_factor = pose_scaling_factor
        self.rope_scaling_type = rope_scaling_type
        self.rope_scaling_factor = rope_scaling_factor
        self.max_length = max_length
        self.orig_max_length = orig_max_length
        self.use_checkpoint = use_checkpoint
        print("use_checkpoint", self.use_checkpoint)
        self.tp = tp

        # Forward the remaining keyword arguments instead of discarding them,
        # so fields handled by the base class are not silently lost. The
        # architecture list is always forced to this model's class; assigning
        # it into kwargs also avoids a duplicate-keyword error when the caller
        # (e.g. a loaded config dict) already supplies ``architectures``.
        # NOTE(review): the base constructor runs after the assignments above,
        # so any attribute it assigns under the same name (possibly
        # ``max_length`` / ``use_bfloat16``, depending on the installed
        # transformers version) ends up with the base class's value - confirm.
        kwargs["architectures"] = ["fm9gDragonflyForCausalLM"]
        super().__init__(**kwargs)

    @property
    def scale_width(
        self,
    ):
        """Return ``dim_model / dim_model_base`` if ``scale`` is set, else 1.0."""
        return self.dim_model / self.dim_model_base if self.scale else 1.0

    @property
    def dtype(
        self,
    ):  # -> Any | None:
        """Return the torch dtype named by ``_dtype``.

        ``"bf16"`` maps to ``torch.bfloat16``, ``"fp16"`` to ``torch.half`` and
        ``"float32"`` to ``torch.float``. Any other value yields ``None``.
        """
        known_dtypes = {
            "bf16": torch.bfloat16,
            "fp16": torch.half,
            "float32": torch.float,
        }
        return known_dtypes.get(self._dtype)
|