diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index ef759140..2570c5d7 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -86,10 +86,8 @@ def load_model_and_tokenizer(
 
     # Fix config (for Qwen)
     if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"):
-        if model_args.compute_dtype == torch.bfloat16:
-            setattr(config, "bf16", True)
-        else:
-            setattr(config, "fp16", True)
+        setattr(config, "fp16", model_args.compute_dtype == torch.float16)
+        setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
 
     # Set RoPE scaling
     if model_args.rope_scaling is not None:
diff --git a/tests/cal_flops.py b/tests/cal_flops.py
new file mode 100644
index 00000000..58ca6cae
--- /dev/null
+++ b/tests/cal_flops.py
@@ -0,0 +1,44 @@
+# coding=utf-8
+# Calculates the flops of pre-trained models.
+# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
+# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
+
+import fire
+import torch
+from typing import Optional
+from deepspeed.accelerator import get_accelerator
+from deepspeed.profiling.flops_profiler import get_model_profile
+
+from llmtuner import ChatModel
+
+
+def calculate(
+    model_name_or_path: str,
+    batch_size: Optional[int] = 1,
+    seq_length: Optional[int] = 256,
+    flash_attn: Optional[bool] = False
+):
+    with get_accelerator().device(0):
+        chat_model = ChatModel(dict(
+            model_name_or_path=model_name_or_path,
+            template="vanilla",
+            flash_attn=flash_attn
+        ))
+        fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
+        input_dict = {
+            "input_ids": fake_input,
+            "labels": fake_input.clone()
+        }
+        flops, macs, params = get_model_profile(
+            chat_model.model,
+            kwargs=input_dict,
+            print_profile=True,
+            detailed=True
+        )
+        print("FLOPS:", flops)
+        print("MACs:", macs)
+        print("Params:", params)
+
+
+if __name__ == "__main__":
+    fire.Fire(calculate)
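
For context on the loader.py change: the old logic set `fp16=True` whenever the compute dtype was anything other than bfloat16, which also enabled fp16 for plain float32 training. Below is a minimal sketch, not part of this diff, that illustrates the revised behavior; `SimpleNamespace` stands in for the real Qwen-style config object with boolean `fp16`/`bf16` attributes.

```python
# Sketch of the revised dtype logic (assumption: a Qwen-style config exposing
# boolean fp16/bf16 attributes, emulated here with SimpleNamespace).
import torch
from types import SimpleNamespace


def fix_config(config, compute_dtype):
    # Each flag now mirrors the compute dtype; float32 leaves both disabled,
    # whereas the old branch would have forced fp16=True in that case.
    setattr(config, "fp16", compute_dtype == torch.float16)
    setattr(config, "bf16", compute_dtype == torch.bfloat16)


for dtype in (torch.float16, torch.bfloat16, torch.float32):
    cfg = SimpleNamespace(fp16=False, bf16=False)
    fix_config(cfg, dtype)
    print(dtype, "-> fp16:", cfg.fp16, "bf16:", cfg.bf16)
```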