
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig

from llamafactory.train.tuner import run_exp


BASE = 2  # gemm (add + mul)
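# a GEMM of shape (m, k) x (k, n) costs m * n * k multiply-accumulates, i.e.
# BASE * m * n * k FLOPs, hence the factor of BASE in every projection below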


def compute_model_flops(
    model_name_or_path: str,
    total_batch_size: int,
    seq_length: int,
    include_backward: bool = True,
    include_recompute: bool = False,
    include_flashattn: bool = False,
) -> int:
    r"""
    Calculates the FLOPs of the model per forward/backward pass.
    """
    config = AutoConfig.from_pretrained(model_name_or_path)
    hidden_size = getattr(config, "hidden_size", None)
    vocab_size = getattr(config, "vocab_size", None)
    intermediate_size = getattr(config, "intermediate_size", None)
    num_attention_heads = getattr(config, "num_attention_heads", None)
    num_key_value_heads = getattr(config, "num_key_value_heads", None)
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)

    # mlp module
    mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
    mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
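    # e.g. a Llama-2-7B-like config (hidden 4096, intermediate 11008) costs
    # 3 * 2 * 4096 * 11008 ≈ 2.7e8 mlp FLOPs per token per layer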

    # attn projector module
    q_flops_per_token = BASE * hidden_size * hidden_size
    o_flops_per_token = BASE * hidden_size * hidden_size
    k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
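    # with grouped-query attention, the k/v projections are smaller than q/o by a
    # factor of num_key_value_heads / num_attention_heads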
    attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
    attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token

    # attn sdpa module
    sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
    sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
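    # q @ k^T and scores @ v each cost seq_length**2 * hidden_size multiply-accumulates
    # over the full (causally unmasked) score matrix, hence 2 * BASE * hidden_size * seq_length**2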

    # embedding module
    embedding_flops_per_token = hidden_size * vocab_size
    embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
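    # untied models keep a distinct lm_head in addition to the input embedding,
    # so the hidden_size * vocab_size cost is counted twice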
    if tie_word_embeddings is False:
        embedding_flops *= 2

    non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
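    # coefficients: the forward pass counts once, the backward pass roughly twice
    # (gradients w.r.t. activations and w.r.t. weights); full activation
    # recomputation replays the forward once more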
    non_embedding_coeff, embedding_coeff = 1, 1
    if include_backward:
        non_embedding_coeff += 2
        embedding_coeff += 2

    if include_recompute:
        non_embedding_coeff += 1

    total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops
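    # flash attention recomputes the attention matrix in the backward pass,
    # which adds one extra sdpa forward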
    if include_flashattn:
        total_flops += sdpa_flops

    return total_flops
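

# Optional cross-check (hypothetical helper, not part of the original script): dense
# decoder-only models should land near the common 6 * N * D rule of thumb, i.e.
# 2N forward + 4N backward FLOPs per token, with N = parameter count and D = tokens.
def approximate_flops_6nd(num_params: int, total_batch_size: int, seq_length: int) -> int:
    r"""
    Approximates per-step training FLOPs via the 6 * N * D heuristic.
    """
    return 6 * num_params * total_batch_size * seq_length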


def compute_device_flops(world_size: int) -> float:
    r"""
    Calculates the peak FLOPs per second of the available devices.
    """
    device_name = torch.cuda.get_device_name()
    if "H100" in device_name or "H800" in device_name:
        return 989 * 1e12 * world_size
    elif "A100" in device_name or "A800" in device_name:
        return 312 * 1e12 * world_size
    elif "V100" in device_name:
        return 125 * 1e12 * world_size
    elif "4090" in device_name:
        return 98 * 1e12 * world_size
    else:
        raise NotImplementedError("Device not supported: {}.".format(device_name))
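
# The branches above return (approximately) vendor peak tensor-core throughput
# (BF16/FP16, dense) scaled by the number of ranks. To benchmark an unlisted GPU,
# add a branch with its peak throughput, e.g. (hypothetical device and value):
#     elif "MyGPU" in device_name:
#         return my_gpu_peak_tflops * 1e12 * world_size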


def calculate_mfu(
    model_name_or_path: str,
    batch_size: int = 1,
    seq_length: int = 1024,
    num_steps: int = 100,
    finetuning_type: str = "lora",
    flash_attn: str = "auto",
    deepspeed_stage: int = 0,
    disable_gc: bool = False,
    liger_kernel: bool = False,
    unsloth_gc: bool = False,
) -> None:
    r"""
    Calculates MFU for the given model and hyper-parameters.
    Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
    """
    args = {
        "model_name_or_path": model_name_or_path,
        "flash_attn": flash_attn,
        "disable_gradient_checkpointing": disable_gc,
        "enable_liger_kernel": liger_kernel,
        "use_unsloth_gc": unsloth_gc,
        "stage": "pt",
        "do_train": True,
        "finetuning_type": finetuning_type,
        "dataset": "c4_demo",
        "cutoff_len": seq_length,
        "output_dir": os.path.join("saves", "test_mfu"),
        "logging_strategy": "no",
        "save_strategy": "no",
        "save_only_model": True,
        "overwrite_output_dir": True,
        "per_device_train_batch_size": batch_size,
        "max_steps": num_steps,
        "bf16": True,
    }
    if deepspeed_stage in [2, 3]:
        args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)

    run_exp(args)
    with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
        result = json.load(f)

    if dist.is_initialized():
        world_size = dist.get_world_size()
    else:
        world_size = 1

    total_batch_size = batch_size * world_size
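    # MFU = achieved FLOPs/s divided by peak FLOPs/s, where achieved FLOPs/s is the
    # model's per-step FLOPs times the measured train_steps_per_second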
    mfu_value = (
        result["train_steps_per_second"]
        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
        / compute_device_flops(world_size)
    )
    print("MFU: {:.2f}%".format(mfu_value * 100))


if __name__ == "__main__":
    fire.Fire(calculate_mfu)
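

# Example invocations (model path and launcher are assumptions; adjust to your setup):
#   python cal_mfu.py --model_name_or_path meta-llama/Llama-2-7b-hf --batch_size 1 --seq_length 1024
#   torchrun --nproc_per_node 8 cal_mfu.py --model_name_or_path meta-llama/Llama-2-7b-hf \
#       --batch_size 2 --finetuning_type full --deepspeed_stage 2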