# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig

from llamafactory.train.tuner import run_exp


BASE = 2  # gemm (add + mul)
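

# MFU (model FLOPs utilization) = achieved FLOPs per second / peak device FLOPs per second,
# i.e. train_steps_per_second * compute_model_flops(...) / compute_device_flops(...) below.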
def compute_model_flops(
    model_name_or_path: str,
    total_batch_size: int,
    seq_length: int,
    include_backward: bool = True,
    include_recompute: bool = False,
    include_flashattn: bool = False,
) -> int:
    r"""
    Calculates the FLOPs of the model per forward/backward pass.
    """
    config = AutoConfig.from_pretrained(model_name_or_path)
    hidden_size = getattr(config, "hidden_size", None)
    vocab_size = getattr(config, "vocab_size", None)
    intermediate_size = getattr(config, "intermediate_size", None)
    num_attention_heads = getattr(config, "num_attention_heads", None)
    num_key_value_heads = getattr(config, "num_key_value_heads", None)
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
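
    # Per token, a GEMM that maps d_in features to d_out features costs BASE * d_in * d_out FLOPs.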
    # mlp module
    mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
    mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token

    # attn projector module
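    # With grouped-query attention, K and V have only num_key_value_heads heads, so their
    # projections are scaled by num_key_value_heads / num_attention_heads relative to Q and O.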
    q_flops_per_token = BASE * hidden_size * hidden_size
    o_flops_per_token = BASE * hidden_size * hidden_size
    k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
    attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token

    # attn sdpa module
    sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
    sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer

    # embedding module
    embedding_flops_per_token = hidden_size * vocab_size
    embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
    if tie_word_embeddings is False:
        embedding_flops *= 2

    non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
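    # The backward pass costs roughly twice the forward FLOPs (hence the +2 coefficient),
    # and activation recomputation replays the forward pass once more (the +1 coefficient).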
    non_embedding_coeff, embedding_coeff = 1, 1
    if include_backward:
        non_embedding_coeff += 2
        embedding_coeff += 2

    if include_recompute:
        non_embedding_coeff += 1

    total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops

    if include_flashattn:
        total_flops += sdpa_flops

    return total_flops


def compute_device_flops(world_size: int) -> float:
    r"""
    Calculates the peak FLOPs per second across all devices.
    """
    device_name = torch.cuda.get_device_name()
    if "H100" in device_name or "H800" in device_name:
        return 989 * 1e12 * world_size
    elif "A100" in device_name or "A800" in device_name:
        return 312 * 1e12 * world_size
    elif "V100" in device_name:
        return 125 * 1e12 * world_size
    elif "4090" in device_name:
        return 98 * 1e12 * world_size
    else:
        raise NotImplementedError("Device not supported: {}.".format(device_name))


def calculate_mfu(
    model_name_or_path: str,
    batch_size: int = 1,
    seq_length: int = 1024,
    num_steps: int = 100,
    finetuning_type: str = "lora",
    flash_attn: str = "auto",
    deepspeed_stage: int = 0,
    disable_gc: bool = False,
    liger_kernel: bool = False,
    unsloth_gc: bool = False,
) -> float:
    r"""
    Calculates the MFU for a given model and set of hyper-parameters.
    Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
    """
    args = {
        "model_name_or_path": model_name_or_path,
        "flash_attn": flash_attn,
        "disable_gradient_checkpointing": disable_gc,
        "enable_liger_kernel": liger_kernel,
        "use_unsloth_gc": unsloth_gc,
        "stage": "pt",
        "do_train": True,
        "finetuning_type": finetuning_type,
        "dataset": "c4_demo",
        "cutoff_len": seq_length,
        "output_dir": os.path.join("saves", "test_mfu"),
        "logging_strategy": "no",
        "save_strategy": "no",
        "save_only_model": True,
        "overwrite_output_dir": True,
        "per_device_train_batch_size": batch_size,
        "max_steps": num_steps,
        "bf16": True,
    }
    if deepspeed_stage in [2, 3]:
        args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)

    run_exp(args)
    with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
        result = json.load(f)

    if dist.is_initialized():
        world_size = dist.get_world_size()
    else:
        world_size = 1
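
    # Global batch size across all processes (gradient accumulation is not configured in `args` above).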
    total_batch_size = batch_size * world_size
    mfu_value = (
        result["train_steps_per_second"]
        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
        / compute_device_flops(world_size)
    )
    print("MFU: {:.2f}%".format(mfu_value * 100))


if __name__ == "__main__":
    fire.Fire(calculate_mfu)