#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: hsd9026 <shengdinghu@gmail.com>
# @date: 2023/07/07
#
from .log import logger


def num_parameters(model):
    """Return the total number of parameters of a model."""
    total_params = 0
    for param in model.state_dict().values():
        total_params += param.numel()
    return total_params


def num_non_embedding_parameters(model):
    """Return the number of parameters of a model, excluding the
    embedding and output-head (lm_head) parameters."""
    total_params = 0
    for param_name, param in model.state_dict().items():
        if ("embed" in param_name) or ("lm_head" in param_name):
            continue
        total_params += param.numel()
    return total_params


def estimate_parameters(config):
    """Estimate the number of parameters of a model given its config;
    should be equal to `num_parameters(model)`."""
    # token embedding table
    embedding_params = config.vocab_size * config.dim_model
    # Q/K/V/O projections per layer
    self_attn_params = 4 * config.dim_model * config.dim_head * config.num_heads * config.num_layers
    # gated feed-forward (three weight matrices) per layer
    ff_params = 3 * config.dim_model * config.dim_ff * config.num_layers
    # two layernorms per layer
    layernorm_parameters = 2 * config.dim_model * config.num_layers
    # final output layernorm
    output_layernorm_parameters = config.dim_model
    # relative positional bias (presumably 32 bias buckets)
    positional_bias = 32
    total_params = (
        embedding_params
        + self_attn_params
        + ff_params
        + layernorm_parameters
        + output_layernorm_parameters
        + positional_bias
    )
    params_without_embeddings = total_params - embedding_params
    return total_params, params_without_embeddings
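

# Worked example for `estimate_parameters` (illustrative numbers, not taken
# from any shipped 9G config): with vocab_size=32000, dim_model=4096,
# dim_head=128, num_heads=32, num_layers=32, dim_ff=11008 the terms above
# evaluate to
#   embedding:    32000 * 4096             = 131,072,000
#   attention:    4 * 4096 * 128 * 32 * 32 = 2,147,483,648
#   feed-forward: 3 * 4096 * 11008 * 32    = 4,328,521,728
#   layernorm:    2 * 4096 * 32 + 4096     = 266,240
#   positional:   32
# for a total of ~6.61e9 parameters, of which ~6.48e9 are non-embedding.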


def get_flops_per_token(config):
    """Estimate the training FLOPs per token, i.e. the 6N factor in the
    scaling-law equation Compute = 6 * N * D, where N is the number of
    non-embedding parameters and D the number of training tokens."""
    _, N = estimate_parameters(config)
    logger.info(">>>>>> non-embedding parameters N >>>> {:,.0f}".format(N))
    # forward + backward pass; ignores the attention context term
    # (+ 2 * n_layer * n_ctx * d_model per token, cf. Kaplan et al., 2020)
    C_forward = 6 * N
    # to convert to PF-days: C_forward / (10**15 * 3600 * 24)
    return C_forward
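

# Minimal usage sketch (illustrative, not part of the original module): the
# helpers above only read plain attributes off `config`, so a SimpleNamespace
# can stand in for the project's config class. Run from a context where the
# package import succeeds:
#
#     from types import SimpleNamespace
#     from fm9g.utils.training_stats import estimate_parameters, get_flops_per_token
#
#     cfg = SimpleNamespace(vocab_size=32000, dim_model=4096, dim_head=128,
#                           num_heads=32, num_layers=32, dim_ff=11008)
#     total, non_embedding = estimate_parameters(cfg)
#     flops_per_token = get_flops_per_token(cfg)  # == 6 * non_embedding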