# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig

from llamafactory.train.tuner import run_exp


BASE = 2  # gemm: each multiply-accumulate counts as 2 FLOPs (mul + add)


def compute_model_flops(
    model_name_or_path: str,
    total_batch_size: int,
    seq_length: int,
    include_backward: bool = True,
    include_recompute: bool = False,
    include_flashattn: bool = False,
) -> int:
    r"""
    Calculates the FLOPs of the model per forward/backward pass.
    """
    config = AutoConfig.from_pretrained(model_name_or_path)
    hidden_size = getattr(config, "hidden_size", None)
    vocab_size = getattr(config, "vocab_size", None)
    intermediate_size = getattr(config, "intermediate_size", None)
    num_attention_heads = getattr(config, "num_attention_heads", None)
    num_key_value_heads = getattr(config, "num_key_value_heads", None)
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)

    # mlp module
    mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down projections
    mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token

    # attn projector module (k and v are scaled down for grouped-query attention)
    q_flops_per_token = BASE * hidden_size * hidden_size
    o_flops_per_token = BASE * hidden_size * hidden_size
    k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
    attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token

    # attn sdpa module
    sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
    sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer

    # embedding module
    embedding_flops_per_token = hidden_size * vocab_size
    embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
    if tie_word_embeddings is False:  # an untied lm_head adds a second vocab-sized projection
        embedding_flops *= 2

    non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
    non_embedding_coeff, embedding_coeff = 1, 1
    if include_backward:  # the backward pass costs roughly 2x the forward pass
        non_embedding_coeff += 2
        embedding_coeff += 2

    if include_recompute:  # full activation recomputation repeats the forward pass
        non_embedding_coeff += 1

    total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops

    if include_flashattn:  # flash attention recomputes attention scores in the backward pass
        total_flops += sdpa_flops

    return total_flops
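
# For reference, with include_backward=True the non-embedding part above reduces
# approximately to the common 6 * N * T estimate (N = non-embedding parameters,
# T = tokens processed), plus the sequence-length-quadratic attention term.
# A minimal sketch with hypothetical arguments:
#
#   flops = compute_model_flops("path_to_model", total_batch_size=8, seq_length=1024)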
""" device_name = torch.cuda.get_device_name() if "H100" in device_name or "H800" in device_name: return 989 * 1e12 * world_size elif "A100" in device_name or "A800" in device_name: return 312 * 1e12 * world_size elif "V100" in device_name: return 125 * 1e12 * world_size elif "4090" in device_name: return 98 * 1e12 * world_size else: raise NotImplementedError("Device not supported: {}.".format(device_name)) def calculate_mfu( model_name_or_path: str, batch_size: int = 1, seq_length: int = 1024, num_steps: int = 100, finetuning_type: str = "lora", flash_attn: str = "auto", deepspeed_stage: int = 0, disable_gc: bool = False, liger_kernel: bool = False, unsloth_gc: bool = False, ) -> float: r""" Calculates MFU for given model and hyper-params. Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024 """ args = { "model_name_or_path": model_name_or_path, "flash_attn": flash_attn, "disable_gradient_checkpointing": disable_gc, "enable_liger_kernel": liger_kernel, "use_unsloth_gc": unsloth_gc, "stage": "pt", "do_train": True, "finetuning_type": finetuning_type, "dataset": "c4_demo", "cutoff_len": seq_length, "output_dir": os.path.join("saves", "test_mfu"), "logging_strategy": "no", "save_strategy": "no", "save_only_model": True, "overwrite_output_dir": True, "per_device_train_batch_size": batch_size, "max_steps": num_steps, "bf16": True, } if deepspeed_stage in [2, 3]: args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage) run_exp(args) with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f: result = json.load(f) if dist.is_initialized(): world_size = dist.get_world_size() else: world_size = 1 total_batch_size = batch_size * world_size mfu_value = ( result["train_steps_per_second"] * compute_model_flops(model_name_or_path, total_batch_size, seq_length) / compute_device_flops(world_size) ) print("MFU: {:.2f}%".format(mfu_value * 100)) if __name__ == "__main__": fire.Fire(calculate_mfu)