profiler.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import time
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from functools import partial
  9. from typing import List, Optional
  10. from collections import OrderedDict
  11. import numpy as np
  12. from deepspeed.accelerator import get_accelerator
  13. from deepspeed.utils import logger
  14. from deepspeed.moe.layer import MoE
  15. from deepspeed.utils.timer import FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, STEP_GLOBAL_TIMER
  16. Tensor = torch.Tensor
  17. module_flop_count = []
  18. module_mac_count = []
  19. old_functions = {}
  20. DEFAULT_PRECISION = 2
  21. class FlopsProfiler(object):
  22. """Measures the latency, number of estimated floating-point operations and parameters of each module in a PyTorch model.
  23. The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how latency, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated latency, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input.
  24. The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package.
  25. When using DeepSpeed for model training, the flops profiler can be configured in the deepspeed_config file and no user code change is required.
  26. If using the profiler as a standalone package, one imports the flops_profiler package and use the APIs.
  27. Here is an example for usage in a typical training workflow:
  28. .. code-block:: python
  29. model = Model()
  30. prof = FlopsProfiler(model)
  31. for step, batch in enumerate(data_loader):
  32. if step == profile_step:
  33. prof.start_profile()
  34. loss = model(batch)
  35. if step == profile_step:
  36. flops = prof.get_total_flops(as_string=True)
  37. params = prof.get_total_params(as_string=True)
  38. prof.print_model_profile(profile_step=profile_step)
  39. prof.end_profile()
  40. loss.backward()
  41. optimizer.step()
  42. To profile a trained model in inference, use the `get_model_profile` API.
  43. Args:
  44. object (torch.nn.Module): The PyTorch model to profile.
  45. """
  46. def __init__(self, model, ds_engine=None, recompute_fwd_factor=0.0):
  47. self.model = model
  48. self.ds_engine = ds_engine
  49. self.recompute_fwd_factor = recompute_fwd_factor
  50. self.started = False
  51. self.func_patched = False
  52. def start_profile(self, ignore_list=None):
  53. """Starts profiling.
  54. Extra attributes are added recursively to all the modules and the profiled torch.nn.functionals are monkey patched.
  55. Args:
  56. ignore_list (list, optional): the list of modules to ignore while profiling. Defaults to None.
  57. """
  58. logger.info("Flops profiler started")
  59. self.reset_profile()
  60. _patch_functionals()
  61. _patch_tensor_methods()
  62. def register_module_hooks(module, ignore_list):
  63. if ignore_list and type(module) in ignore_list:
  64. return
  65. # if computing the flops of a module directly
  66. if type(module) in MODULE_HOOK_MAPPING:
  67. if not hasattr(module, "__flops_handle__"):
  68. module.__flops_handle__ = module.register_forward_hook(MODULE_HOOK_MAPPING[type(module)])
  69. return
  70. # if computing the flops of the functionals in a module
  71. def pre_hook(module, input):
  72. module_flop_count.append([])
  73. module_mac_count.append([])
  74. if not hasattr(module, "__pre_hook_handle__"):
  75. module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)
  76. def post_hook(module, input, output):
  77. if module_flop_count:
  78. module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]])
  79. module_flop_count.pop()
  80. module.__macs__ += sum([elem[1] for elem in module_mac_count[-1]])
  81. module_mac_count.pop()
  82. if not hasattr(module, "__post_hook_handle__"):
  83. module.__post_hook_handle__ = module.register_forward_hook(post_hook)
  84. def start_time_hook(module, input):
  85. get_accelerator().synchronize()
  86. module.__start_time__ = time.time()
  87. if not hasattr(module, "__start_time_hook_handle"):
  88. module.__start_time_hook_handle__ = module.register_forward_pre_hook(start_time_hook)
  89. def end_time_hook(module, input, output):
  90. get_accelerator().synchronize()
  91. module.__duration__ += time.time() - module.__start_time__
  92. if not hasattr(module, "__end_time_hook_handle__"):
  93. module.__end_time_hook_handle__ = module.register_forward_hook(end_time_hook)
  94. self.model.apply(partial(register_module_hooks, ignore_list=ignore_list))
  95. self.started = True
  96. self.func_patched = True
  97. def stop_profile(self):
  98. """Stop profiling.
  99. All torch.nn.functionals are restored to their originals.
  100. """
  101. if self.started and self.func_patched:
  102. _reload_functionals()
  103. _reload_tensor_methods()
  104. self.func_patched = False
  105. def remove_profile_attrs(module):
  106. if hasattr(module, "__pre_hook_handle__"):
  107. module.__pre_hook_handle__.remove()
  108. del module.__pre_hook_handle__
  109. if hasattr(module, "__post_hook_handle__"):
  110. module.__post_hook_handle__.remove()
  111. del module.__post_hook_handle__
  112. if hasattr(module, "__flops_handle__"):
  113. module.__flops_handle__.remove()
  114. del module.__flops_handle__
  115. if hasattr(module, "__start_time_hook_handle__"):
  116. module.__start_time_hook_handle__.remove()
  117. del module.__start_time_hook_handle__
  118. if hasattr(module, "__end_time_hook_handle__"):
  119. module.__end_time_hook_handle__.remove()
  120. del module.__end_time_hook_handle__
  121. self.model.apply(remove_profile_attrs)
  122. def reset_profile(self):
  123. """Resets the profiling.
  124. Adds or resets the extra attributes.
  125. """
  126. def get_param_count_and_ep(param):
  127. """
  128. Return the number of parameters in the layer, whether the layer is an MoE layer,
  129. and its expert parallelism size if so
  130. """
  131. prefix = 'ep_size_'
  132. offset = len(prefix)
  133. expert_parallelism = 0
  134. if getattr(param, "group_name", "").startswith(prefix):
  135. try:
  136. expert_parallelism = int(param.group_name[offset:])
  137. except ValueError:
  138. pass
  139. return param.numel(), expert_parallelism, param.element_size()
  140. def add_or_reset_attrs(module):
  141. module.__flops__ = 0
  142. module.__macs__ = 0
  143. module.__params__ = module.__expert_params__ = module.__model_expert_params__ = 0
  144. parameters = (get_param_count_and_ep(p) for p in module.parameters())
  145. for num_params, expert_parallelism, per_param_size in parameters:
  146. params = num_params if not expert_parallelism else 0
  147. expert_params = num_params if expert_parallelism else 0
  148. # number of expert parameters taking into account other expert parallel groups
  149. model_expert_params = num_params * expert_parallelism
  150. module.__params__ += params
  151. module.__expert_params__ += expert_params
  152. module.__model_expert_params__ += model_expert_params
  153. module.__start_time__ = 0
  154. module.__duration__ = 0
  155. self.model.apply(add_or_reset_attrs)
  156. def end_profile(self):
  157. """Ends profiling.
  158. The added attributes and handles are removed recursively on all the modules.
  159. """
  160. if not self.started:
  161. return
  162. self.stop_profile()
  163. self.started = False
  164. def remove_profile_attrs(module):
  165. if hasattr(module, "__flops__"):
  166. del module.__flops__
  167. if hasattr(module, "__macs__"):
  168. del module.__macs__
  169. if hasattr(module, "__params__"):
  170. del module.__params__
  171. if hasattr(module, "__expert_params__"):
  172. del module.__expert_params__
  173. if hasattr(module, "__model_expert_params__"):
  174. del module.__model_expert_params__
  175. if hasattr(module, "__start_time__"):
  176. del module.__start_time__
  177. if hasattr(module, "__duration__"):
  178. del module.__duration__
  179. self.model.apply(remove_profile_attrs)
  180. logger.info("Flops profiler finished")
  181. def get_total_flops(self, as_string=False):
  182. """Returns the total flops of the model.
  183. Args:
  184. as_string (bool, optional): whether to output the flops as string. Defaults to False.
  185. Returns:
  186. The number of multiply-accumulate operations of the model forward pass.
  187. """
  188. total_flops = get_module_flops(self.model)
  189. return number_to_string(total_flops) if as_string else total_flops
  190. def get_total_macs(self, as_string=False):
  191. """Returns the total MACs of the model.
  192. Args:
  193. as_string (bool, optional): whether to output the flops as string. Defaults to False.
  194. Returns:
  195. The number of multiply-accumulate operations of the model forward pass.
  196. """
  197. total_macs = get_module_macs(self.model)
  198. return macs_to_string(total_macs) if as_string else total_macs
  199. def get_total_duration(self, as_string=False):
  200. """Returns the total duration of the model forward pass.
  201. Args:
  202. as_string (bool, optional): whether to output the duration as string. Defaults to False.
  203. Returns:
  204. The latency of the model forward pass.
  205. """
  206. total_duration = get_module_duration(self.model)
  207. return duration_to_string(total_duration) if as_string else total_duration
  208. def get_total_params(self, as_string=False):
  209. """Returns the total number of parameters stored per rank.
  210. Args:
  211. as_string (bool, optional): whether to output the parameters as string. Defaults to False.
  212. Returns:
  213. The total number of parameters stored per rank.
  214. """
  215. total_params = self.model.__expert_params__ + self.model.__params__
  216. return params_to_string(total_params) if as_string else total_params
  217. def is_expert_tensor_parallelism_enabled(self):
  218. for _, module in self.model.named_modules():
  219. if isinstance(module, MoE) and hasattr(module, 'enable_expert_tensor_parallelism'):
  220. return module.enable_expert_tensor_parallelism
  221. return False
  222. def print_model_profile(self, profile_step=1, module_depth=-1, top_modules=1, detailed=True, output_file=None):
  223. """Prints the model graph with the measured profile attached to each module.
  224. Args:
  225. profile_step (int, optional): The global training step at which to profile. Note that warm up steps are needed for accurate time measurement.
  226. module_depth (int, optional): The depth of the model to which to print the aggregated module information. When set to -1, it prints information from the top to the innermost modules (the maximum depth).
  227. top_modules (int, optional): Limits the aggregated profile output to the number of top modules specified.
  228. detailed (bool, optional): Whether to print the detailed model profile.
  229. output_file (str, optional): Path to the output file. If None, the profiler prints to stdout.
  230. """
  231. if not self.started:
  232. return
  233. import sys
  234. import os.path
  235. original_stdout = None
  236. f = None
  237. if output_file and output_file != "":
  238. dir_path = os.path.dirname(os.path.abspath(output_file))
  239. if not os.path.exists(dir_path):
  240. os.makedirs(dir_path)
  241. original_stdout = sys.stdout
  242. f = open(output_file, "w")
  243. sys.stdout = f
  244. total_flops = self.get_total_flops()
  245. total_macs = self.get_total_macs()
  246. total_duration = self.get_total_duration()
  247. total_params = self.get_total_params()
  248. expert_tensor_parallelism = None # silence the linters
  249. total_model_expert_params = total_model_nonexpert_params = 0
  250. if self.ds_engine:
  251. total_model_nonexpert_params = self.model.__params__ * self.ds_engine.mp_world_size
  252. if self.ds_engine.has_moe_layers:
  253. expert_tensor_parallelism = self.ds_engine.mp_world_size if self.is_expert_tensor_parallelism_enabled(
  254. ) else 1
  255. total_model_expert_params = self.model.__model_expert_params__ * expert_tensor_parallelism
  256. self.flops = total_flops
  257. self.macs = total_macs
  258. self.params = total_params
  259. print("\n-------------------------- DeepSpeed Flops Profiler --------------------------")
  260. print(f'Profile Summary at step {profile_step}:')
  261. print("Notations:\n"
  262. "data parallel size (dp_size), model parallel size(mp_size),\n"
  263. "number of parameters (params), number of multiply-accumulate operations(MACs),\n"
  264. "number of floating-point operations (flops), floating-point operations per second (FLOPS),\n"
  265. "fwd latency (forward propagation latency), bwd latency (backward propagation latency),\n"
  266. "step (weights update latency), iter latency (sum of fwd, bwd and step latency)\n")
  267. line_fmt = '{:<70} {:<8}'
  268. if self.ds_engine:
  269. print(line_fmt.format('world size: ', self.ds_engine.world_size))
  270. print(line_fmt.format('data parallel size: ', self.ds_engine.dp_world_size))
  271. print(line_fmt.format('model parallel size: ', self.ds_engine.mp_world_size))
  272. print(line_fmt.format('batch size per GPU: ', self.ds_engine.train_micro_batch_size_per_gpu()))
  273. if self.ds_engine.has_moe_layers:
  274. print(line_fmt.format('expert tensor parallelism enabled: ', expert_tensor_parallelism > 1))
  275. print(line_fmt.format('params per GPU: ', params_to_string(total_params)))
  276. if total_model_expert_params > 0:
  277. print(
  278. line_fmt.format('params of model: ',
  279. params_to_string(total_model_nonexpert_params + total_model_expert_params)))
  280. print(line_fmt.format(' non-expert params of model: ', params_to_string(total_model_nonexpert_params)))
  281. print(line_fmt.format(' expert params of model: ', params_to_string(total_model_expert_params)))
  282. else:
  283. print(
  284. line_fmt.format('params of model = params per GPU * mp_size: ',
  285. params_to_string(total_model_nonexpert_params)))
  286. print(line_fmt.format('fwd MACs per GPU: ', macs_to_string(total_macs)))
  287. print(line_fmt.format('fwd flops per GPU: ', number_to_string(total_flops)))
  288. print(
  289. line_fmt.format('fwd flops of model = fwd flops per GPU * mp_size: ',
  290. number_to_string(total_flops * (self.ds_engine.mp_world_size if self.ds_engine else 1))))
  291. fwd_latency = self.get_total_duration()
  292. if self.ds_engine and self.ds_engine.wall_clock_breakdown():
  293. fwd_latency = self.ds_engine.timers(FORWARD_GLOBAL_TIMER).elapsed(False) / 1000.0
  294. print(line_fmt.format('fwd latency: ', duration_to_string(fwd_latency)))
  295. print(
  296. line_fmt.format('fwd FLOPS per GPU = fwd flops per GPU / fwd latency: ',
  297. flops_to_string(total_flops / fwd_latency)))
  298. if self.ds_engine and self.ds_engine.wall_clock_breakdown():
  299. bwd_factor = 2 + self.recompute_fwd_factor
  300. bwd_latency = self.ds_engine.timers(BACKWARD_GLOBAL_TIMER).elapsed(False) / 1000.0
  301. step_latency = self.ds_engine.timers(STEP_GLOBAL_TIMER).elapsed(False) / 1000.0
  302. print(line_fmt.format('bwd latency: ', duration_to_string(bwd_latency)))
  303. print(
  304. line_fmt.format(f'bwd FLOPS per GPU = {bwd_factor:g} * fwd flops per GPU / bwd latency: ',
  305. flops_to_string(bwd_factor * total_flops / bwd_latency)))
  306. print(
  307. line_fmt.format(
  308. f'fwd+bwd FLOPS per GPU = {bwd_factor + 1:g} * fwd flops per GPU / (fwd+bwd latency): ',
  309. flops_to_string((bwd_factor + 1) * total_flops / (fwd_latency + bwd_latency))))
  310. print(line_fmt.format('step latency: ', duration_to_string(step_latency)))
  311. iter_latency = fwd_latency + bwd_latency + step_latency
  312. print(line_fmt.format('iter latency: ', duration_to_string(iter_latency)))
  313. print(
  314. line_fmt.format(f'FLOPS per GPU = {bwd_factor + 1:g} * fwd flops per GPU / iter latency: ',
  315. flops_to_string((bwd_factor + 1) * total_flops / iter_latency)))
  316. samples_per_iter = self.ds_engine.train_micro_batch_size_per_gpu() * self.ds_engine.world_size
  317. print(line_fmt.format('samples/second: ', round(samples_per_iter / iter_latency, DEFAULT_PRECISION)))
  318. def flops_repr(module):
  319. params = module.__params__ + module.__expert_params__
  320. flops = get_module_flops(module)
  321. macs = get_module_macs(module)
  322. duration = get_module_duration(module)
  323. items = [
  324. "{} = {:g}% Params".format(
  325. params_to_string(params),
  326. round(100 * params / total_params, DEFAULT_PRECISION) if total_params else 0),
  327. "{} = {:g}% MACs".format(macs_to_string(macs),
  328. round(100 * macs / total_macs, DEFAULT_PRECISION) if total_macs else 0),
  329. "{} = {:g}% latency".format(
  330. duration_to_string(duration),
  331. round(100 * duration / total_duration, DEFAULT_PRECISION) if total_duration else 0),
  332. flops_to_string(round(flops / duration, DEFAULT_PRECISION) if duration else 0),
  333. ]
  334. original_extra_repr = module.original_extra_repr()
  335. if original_extra_repr:
  336. items.append(original_extra_repr)
  337. return ", ".join(items)
  338. def add_extra_repr(module):
  339. flops_extra_repr = flops_repr.__get__(module)
  340. if module.extra_repr != flops_extra_repr:
  341. module.original_extra_repr = module.extra_repr
  342. module.extra_repr = flops_extra_repr
  343. assert module.extra_repr != module.original_extra_repr
  344. def del_extra_repr(module):
  345. if hasattr(module, "original_extra_repr"):
  346. module.extra_repr = module.original_extra_repr
  347. del module.original_extra_repr
  348. self.model.apply(add_extra_repr)
  349. print("\n----------------------------- Aggregated Profile per GPU -----------------------------")
  350. self.print_model_aggregated_profile(module_depth=module_depth, top_modules=top_modules)
  351. if detailed:
  352. print("\n------------------------------ Detailed Profile per GPU ------------------------------")
  353. print(
  354. "Each module profile is listed after its name in the following order: \nparams, percentage of total params, MACs, percentage of total MACs, fwd latency, percentage of total fwd latency, fwd FLOPS"
  355. )
  356. print(
  357. "\nNote: 1. A module can have torch.nn.module or torch.nn.functional to compute logits (e.g. CrossEntropyLoss). They are not counted as submodules, thus not to be printed out. However they make up the difference between a parent's MACs (or latency) and the sum of its submodules'.\n2. Number of floating-point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.\n3. The fwd latency listed in the top module's profile is directly captured at the module forward function in PyTorch, thus it's less than the fwd latency shown above which is captured in DeepSpeed.\n"
  358. )
  359. print(self.model)
  360. self.model.apply(del_extra_repr)
  361. print("------------------------------------------------------------------------------")
  362. if output_file:
  363. sys.stdout = original_stdout
  364. f.close()
  365. def print_model_aggregated_profile(self, module_depth=-1, top_modules=1):
  366. """Prints the names of the top top_modules modules in terms of aggregated time, flops, and parameters at depth module_depth.
  367. Args:
  368. module_depth (int, optional): the depth of the modules to show. Defaults to -1 (the innermost modules).
  369. top_modules (int, optional): the number of top modules to show. Defaults to 1.
  370. """
  371. info = {}
  372. if not hasattr(self.model, "__flops__"):
  373. print("no __flops__ attribute in the model, call this function after start_profile and before end_profile")
  374. return
  375. def walk_module(module, curr_depth, info):
  376. if curr_depth not in info:
  377. info[curr_depth] = {}
  378. if module.__class__.__name__ not in info[curr_depth]:
  379. info[curr_depth][module.__class__.__name__] = [
  380. 0,
  381. 0,
  382. 0,
  383. ] # macs, params, time
  384. info[curr_depth][module.__class__.__name__][0] += get_module_macs(module)
  385. info[curr_depth][module.__class__.__name__][1] += module.__params__ + module.__expert_params__
  386. info[curr_depth][module.__class__.__name__][2] += get_module_duration(module)
  387. has_children = len(module._modules.items()) != 0
  388. if has_children:
  389. for child in module.children():
  390. walk_module(child, curr_depth + 1, info)
  391. walk_module(self.model, 0, info)
  392. depth = module_depth
  393. if module_depth == -1:
  394. depth = len(info) - 1
  395. print(f'Top {top_modules} modules in terms of params, MACs or fwd latency at different model depths:')
  396. for d in range(depth):
  397. num_items = min(top_modules, len(info[d]))
  398. sort_macs = {
  399. k: macs_to_string(v[0])
  400. for k, v in sorted(info[d].items(), key=lambda item: item[1][0], reverse=True)[:num_items]
  401. }
  402. sort_params = {
  403. k: params_to_string(v[1])
  404. for k, v in sorted(info[d].items(), key=lambda item: item[1][1], reverse=True)[:num_items]
  405. }
  406. sort_time = {
  407. k: duration_to_string(v[2])
  408. for k, v in sorted(info[d].items(), key=lambda item: item[1][2], reverse=True)[:num_items]
  409. }
  410. print(f"depth {d}:")
  411. print(f" params - {sort_params}")
  412. print(f" MACs - {sort_macs}")
  413. print(f" fwd latency - {sort_time}")
  414. def _prod(dims):
  415. p = 1
  416. for v in dims:
  417. p *= v
  418. return p
  419. def _linear_flops_compute(input, weight, bias=None):
  420. out_features = weight.shape[0]
  421. macs = input.numel() * out_features
  422. return 2 * macs, macs
  423. def _relu_flops_compute(input, inplace=False):
  424. return input.numel(), 0
  425. def _prelu_flops_compute(input: Tensor, weight: Tensor):
  426. return input.numel(), 0
  427. def _elu_flops_compute(input: Tensor, alpha: float = 1.0, inplace: bool = False):
  428. return input.numel(), 0
  429. def _leaky_relu_flops_compute(input: Tensor, negative_slope: float = 0.01, inplace: bool = False):
  430. return input.numel(), 0
  431. def _relu6_flops_compute(input: Tensor, inplace: bool = False):
  432. return input.numel(), 0
  433. def _silu_flops_compute(input: Tensor, inplace: bool = False):
  434. return input.numel(), 0
  435. def _gelu_flops_compute(input, **kwargs):
  436. return input.numel(), 0
  437. def _pool_flops_compute(input,
  438. kernel_size,
  439. stride=None,
  440. padding=0,
  441. dilation=None,
  442. ceil_mode=False,
  443. count_include_pad=True,
  444. divisor_override=None,
  445. return_indices=None):
  446. return input.numel(), 0
  447. def _conv_flops_compute(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
  448. assert weight.shape[1] * groups == input.shape[1]
  449. batch_size = input.shape[0]
  450. in_channels = input.shape[1]
  451. out_channels = weight.shape[0]
  452. kernel_dims = list(weight.shape[2:])
  453. input_dims = list(input.shape[2:])
  454. length = len(input_dims)
  455. strides = stride if type(stride) is tuple else (stride, ) * length
  456. dilations = dilation if type(dilation) is tuple else (dilation, ) * length
  457. if isinstance(padding, str):
  458. if padding == 'valid':
  459. paddings = (0, ) * length
  460. elif padding == 'same':
  461. paddings = ()
  462. for d, k in zip(dilations, kernel_dims):
  463. total_padding = d * (k - 1)
  464. paddings += (total_padding // 2, )
  465. elif isinstance(padding, tuple):
  466. paddings = padding
  467. else:
  468. paddings = (padding, ) * length
  469. output_dims = []
  470. for idx, input_dim in enumerate(input_dims):
  471. output_dim = (input_dim + 2 * paddings[idx] - (dilations[idx] *
  472. (kernel_dims[idx] - 1) + 1)) // strides[idx] + 1
  473. output_dims.append(output_dim)
  474. filters_per_channel = out_channels // groups
  475. conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel
  476. active_elements_count = batch_size * int(_prod(output_dims))
  477. overall_conv_macs = conv_per_position_macs * active_elements_count
  478. overall_conv_flops = 2 * overall_conv_macs
  479. bias_flops = 0
  480. if bias is not None:
  481. bias_flops = out_channels * active_elements_count
  482. return int(overall_conv_flops + bias_flops), int(overall_conv_macs)
  483. def _conv_trans_flops_compute(
  484. input,
  485. weight,
  486. bias=None,
  487. stride=1,
  488. padding=0,
  489. output_padding=0,
  490. groups=1,
  491. dilation=1,
  492. ):
  493. batch_size = input.shape[0]
  494. in_channels = input.shape[1]
  495. out_channels = weight.shape[1]
  496. kernel_dims = list(weight.shape[2:])
  497. input_dims = list(input.shape[2:])
  498. length = len(input_dims)
  499. paddings = padding if type(padding) is tuple else (padding, ) * length
  500. strides = stride if type(stride) is tuple else (stride, ) * length
  501. dilations = dilation if type(dilation) is tuple else (dilation, ) * length
  502. output_dims = []
  503. for idx, input_dim in enumerate(input_dims):
  504. output_dim = (input_dim + 2 * paddings[idx] - (dilations[idx] *
  505. (kernel_dims[idx] - 1) + 1)) // strides[idx] + 1
  506. output_dims.append(output_dim)
  507. paddings = padding if type(padding) is tuple else (padding, padding)
  508. strides = stride if type(stride) is tuple else (stride, stride)
  509. dilations = dilation if type(dilation) is tuple else (dilation, dilation)
  510. filters_per_channel = out_channels // groups
  511. conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel
  512. active_elements_count = batch_size * int(_prod(input_dims))
  513. overall_conv_macs = conv_per_position_macs * active_elements_count
  514. overall_conv_flops = 2 * overall_conv_macs
  515. bias_flops = 0
  516. if bias is not None:
  517. bias_flops = out_channels * batch_size * int(_prod(output_dims))
  518. return int(overall_conv_flops + bias_flops), int(overall_conv_macs)
  519. def _batch_norm_flops_compute(
  520. input,
  521. running_mean,
  522. running_var,
  523. weight=None,
  524. bias=None,
  525. training=False,
  526. momentum=0.1,
  527. eps=1e-05,
  528. ):
  529. has_affine = weight is not None
  530. if training:
  531. # estimation
  532. return input.numel() * (5 if has_affine else 4), 0
  533. flops = input.numel() * (2 if has_affine else 1)
  534. return flops, 0
  535. def _layer_norm_flops_compute(
  536. input: Tensor,
  537. normalized_shape: List[int],
  538. weight: Optional[Tensor] = None,
  539. bias: Optional[Tensor] = None,
  540. eps: float = 1e-5,
  541. ):
  542. has_affine = weight is not None
  543. # estimation
  544. return input.numel() * (5 if has_affine else 4), 0
  545. def _group_norm_flops_compute(input: Tensor,
  546. num_groups: int,
  547. weight: Optional[Tensor] = None,
  548. bias: Optional[Tensor] = None,
  549. eps: float = 1e-5):
  550. has_affine = weight is not None
  551. # estimation
  552. return input.numel() * (5 if has_affine else 4), 0
  553. def _instance_norm_flops_compute(
  554. input: Tensor,
  555. running_mean: Optional[Tensor] = None,
  556. running_var: Optional[Tensor] = None,
  557. weight: Optional[Tensor] = None,
  558. bias: Optional[Tensor] = None,
  559. use_input_stats: bool = True,
  560. momentum: float = 0.1,
  561. eps: float = 1e-5,
  562. ):
  563. has_affine = weight is not None
  564. # estimation
  565. return input.numel() * (5 if has_affine else 4), 0
  566. def _upsample_flops_compute(*args, **kwargs):
  567. input = args[0]
  568. size = kwargs.get('size', None)
  569. if size is None and len(args) > 1:
  570. size = args[1]
  571. if size is not None:
  572. if isinstance(size, tuple) or isinstance(size, list):
  573. return int(_prod(size)), 0
  574. else:
  575. return int(size), 0
  576. scale_factor = kwargs.get('scale_factor', None)
  577. if scale_factor is None and len(args) > 2:
  578. scale_factor = args[2]
  579. assert scale_factor is not None, "either size or scale_factor should be defined"
  580. flops = input.numel()
  581. if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
  582. flops *= int(_prod(scale_factor))
  583. else:
  584. flops *= scale_factor**len(input)
  585. return flops, 0
  586. def _softmax_flops_compute(input, dim=None, _stacklevel=3, dtype=None):
  587. return input.numel(), 0
  588. def _embedding_flops_compute(
  589. input,
  590. weight,
  591. padding_idx=None,
  592. max_norm=None,
  593. norm_type=2.0,
  594. scale_grad_by_freq=False,
  595. sparse=False,
  596. ):
  597. return 0, 0
  598. def _dropout_flops_compute(input, p=0.5, training=True, inplace=False):
  599. return 0, 0
  600. def _matmul_flops_compute(input, other, *, out=None):
  601. """
  602. Count flops for the matmul operation.
  603. """
  604. macs = _prod(input.shape) * other.shape[-1]
  605. return 2 * macs, macs
  606. def _addmm_flops_compute(input, mat1, mat2, *, beta=1, alpha=1, out=None):
  607. """
  608. Count flops for the addmm operation.
  609. """
  610. macs = _prod(mat1.shape) * mat2.shape[-1]
  611. return 2 * macs + _prod(input.shape), macs
  612. def _einsum_flops_compute(equation, *operands):
  613. """
  614. Count flops for the einsum operation.
  615. """
  616. equation = equation.replace(" ", "")
  617. input_shapes = [o.shape for o in operands]
  618. # Re-map equation so that same equation with different alphabet
  619. # representations will look the same.
  620. letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys()
  621. mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)}
  622. equation = equation.translate(mapping)
  623. np_arrs = [np.zeros(s) for s in input_shapes]
  624. optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
  625. for line in optim.split("\n"):
  626. if "optimized flop" in line.lower():
  627. flop = int(float(line.split(":")[-1]))
  628. return flop, 0
  629. raise NotImplementedError("Unsupported einsum operation.")
  630. def _tensor_addmm_flops_compute(self, mat1, mat2, *, beta=1, alpha=1, out=None):
  631. """
  632. Count flops for the tensor addmm operation.
  633. """
  634. macs = _prod(mat1.shape) * mat2.shape[-1]
  635. return 2 * macs + _prod(self.shape), macs
  636. def _mul_flops_compute(input, other, *, out=None):
  637. return _elementwise_flops_compute(input, other)
  638. def _add_flops_compute(input, other, *, alpha=1, out=None):
  639. return _elementwise_flops_compute(input, other)
  640. def _elementwise_flops_compute(input, other):
  641. if not torch.is_tensor(input):
  642. if torch.is_tensor(other):
  643. return _prod(other.shape), 0
  644. else:
  645. return 1, 0
  646. elif not torch.is_tensor(other):
  647. return _prod(input.shape), 0
  648. else:
  649. dim_input = len(input.shape)
  650. dim_other = len(other.shape)
  651. max_dim = max(dim_input, dim_other)
  652. final_shape = []
  653. for i in range(max_dim):
  654. in_i = input.shape[i] if i < dim_input else 1
  655. ot_i = other.shape[i] if i < dim_other else 1
  656. if in_i > ot_i:
  657. final_shape.append(in_i)
  658. else:
  659. final_shape.append(ot_i)
  660. flops = _prod(final_shape)
  661. return flops, 0
  662. def wrapFunc(func, funcFlopCompute):
  663. oldFunc = func
  664. name = func.__str__
  665. old_functions[name] = oldFunc
  666. def newFunc(*args, **kwds):
  667. flops, macs = funcFlopCompute(*args, **kwds)
  668. if module_flop_count:
  669. module_flop_count[-1].append((name, flops))
  670. if module_mac_count and macs:
  671. module_mac_count[-1].append((name, macs))
  672. return oldFunc(*args, **kwds)
  673. newFunc.__str__ = func.__str__
  674. return newFunc
  675. def _patch_functionals():
  676. # FC
  677. F.linear = wrapFunc(F.linear, _linear_flops_compute)
  678. # convolutions
  679. F.conv1d = wrapFunc(F.conv1d, _conv_flops_compute)
  680. F.conv2d = wrapFunc(F.conv2d, _conv_flops_compute)
  681. F.conv3d = wrapFunc(F.conv3d, _conv_flops_compute)
  682. # conv transposed
  683. F.conv_transpose1d = wrapFunc(F.conv_transpose1d, _conv_trans_flops_compute)
  684. F.conv_transpose2d = wrapFunc(F.conv_transpose2d, _conv_trans_flops_compute)
  685. F.conv_transpose3d = wrapFunc(F.conv_transpose3d, _conv_trans_flops_compute)
  686. # activations
  687. F.relu = wrapFunc(F.relu, _relu_flops_compute)
  688. F.prelu = wrapFunc(F.prelu, _prelu_flops_compute)
  689. F.elu = wrapFunc(F.elu, _elu_flops_compute)
  690. F.leaky_relu = wrapFunc(F.leaky_relu, _leaky_relu_flops_compute)
  691. F.relu6 = wrapFunc(F.relu6, _relu6_flops_compute)
  692. if hasattr(F, "silu"):
  693. F.silu = wrapFunc(F.silu, _silu_flops_compute)
  694. F.gelu = wrapFunc(F.gelu, _gelu_flops_compute)
  695. # Normalizations
  696. F.batch_norm = wrapFunc(F.batch_norm, _batch_norm_flops_compute)
  697. F.layer_norm = wrapFunc(F.layer_norm, _layer_norm_flops_compute)
  698. F.instance_norm = wrapFunc(F.instance_norm, _instance_norm_flops_compute)
  699. F.group_norm = wrapFunc(F.group_norm, _group_norm_flops_compute)
  700. # poolings
  701. F.avg_pool1d = wrapFunc(F.avg_pool1d, _pool_flops_compute)
  702. F.avg_pool2d = wrapFunc(F.avg_pool2d, _pool_flops_compute)
  703. F.avg_pool3d = wrapFunc(F.avg_pool3d, _pool_flops_compute)
  704. F.max_pool1d = wrapFunc(F.max_pool1d, _pool_flops_compute)
  705. F.max_pool2d = wrapFunc(F.max_pool2d, _pool_flops_compute)
  706. F.max_pool3d = wrapFunc(F.max_pool3d, _pool_flops_compute)
  707. F.adaptive_avg_pool1d = wrapFunc(F.adaptive_avg_pool1d, _pool_flops_compute)
  708. F.adaptive_avg_pool2d = wrapFunc(F.adaptive_avg_pool2d, _pool_flops_compute)
  709. F.adaptive_avg_pool3d = wrapFunc(F.adaptive_avg_pool3d, _pool_flops_compute)
  710. F.adaptive_max_pool1d = wrapFunc(F.adaptive_max_pool1d, _pool_flops_compute)
  711. F.adaptive_max_pool2d = wrapFunc(F.adaptive_max_pool2d, _pool_flops_compute)
  712. F.adaptive_max_pool3d = wrapFunc(F.adaptive_max_pool3d, _pool_flops_compute)
  713. # upsample
  714. F.upsample = wrapFunc(F.upsample, _upsample_flops_compute)
  715. F.interpolate = wrapFunc(F.interpolate, _upsample_flops_compute)
  716. # softmax
  717. F.softmax = wrapFunc(F.softmax, _softmax_flops_compute)
  718. # embedding
  719. F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)
  720. def _patch_tensor_methods():
  721. torch.matmul = wrapFunc(torch.matmul, _matmul_flops_compute)
  722. torch.Tensor.matmul = wrapFunc(torch.Tensor.matmul, _matmul_flops_compute)
  723. torch.mm = wrapFunc(torch.mm, _matmul_flops_compute)
  724. torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute)
  725. torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute)
  726. torch.Tensor.bmm = wrapFunc(torch.Tensor.bmm, _matmul_flops_compute)
  727. torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute)
  728. torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute)
  729. torch.mul = wrapFunc(torch.mul, _mul_flops_compute)
  730. torch.Tensor.mul = wrapFunc(torch.Tensor.mul, _mul_flops_compute)
  731. torch.add = wrapFunc(torch.add, _add_flops_compute)
  732. torch.Tensor.add = wrapFunc(torch.Tensor.add, _add_flops_compute)
  733. torch.einsum = wrapFunc(torch.einsum, _einsum_flops_compute)
  734. torch.baddbmm = wrapFunc(torch.baddbmm, _tensor_addmm_flops_compute)
  735. def _reload_functionals():
  736. # torch.nn.functional does not support importlib.reload()
  737. F.linear = old_functions[F.linear.__str__]
  738. F.conv1d = old_functions[F.conv1d.__str__]
  739. F.conv2d = old_functions[F.conv2d.__str__]
  740. F.conv3d = old_functions[F.conv3d.__str__]
  741. F.conv_transpose1d = old_functions[F.conv_transpose1d.__str__]
  742. F.conv_transpose2d = old_functions[F.conv_transpose2d.__str__]
  743. F.conv_transpose3d = old_functions[F.conv_transpose3d.__str__]
  744. F.relu = old_functions[F.relu.__str__]
  745. F.prelu = old_functions[F.prelu.__str__]
  746. F.elu = old_functions[F.elu.__str__]
  747. F.leaky_relu = old_functions[F.leaky_relu.__str__]
  748. F.relu6 = old_functions[F.relu6.__str__]
  749. if hasattr(F, "silu"):
  750. F.silu = old_functions[F.silu.__str__]
  751. F.gelu = old_functions[F.gelu.__str__]
  752. F.batch_norm = old_functions[F.batch_norm.__str__]
  753. F.layer_norm = old_functions[F.layer_norm.__str__]
  754. F.instance_norm = old_functions[F.instance_norm.__str__]
  755. F.group_norm = old_functions[F.group_norm.__str__]
  756. F.avg_pool1d = old_functions[F.avg_pool1d.__str__]
  757. F.avg_pool2d = old_functions[F.avg_pool2d.__str__]
  758. F.avg_pool3d = old_functions[F.avg_pool3d.__str__]
  759. F.max_pool1d = old_functions[F.max_pool1d.__str__]
  760. F.max_pool2d = old_functions[F.max_pool2d.__str__]
  761. F.max_pool3d = old_functions[F.max_pool3d.__str__]
  762. F.adaptive_avg_pool1d = old_functions[F.adaptive_avg_pool1d.__str__]
  763. F.adaptive_avg_pool2d = old_functions[F.adaptive_avg_pool2d.__str__]
  764. F.adaptive_avg_pool3d = old_functions[F.adaptive_avg_pool3d.__str__]
  765. F.adaptive_max_pool1d = old_functions[F.adaptive_max_pool1d.__str__]
  766. F.adaptive_max_pool2d = old_functions[F.adaptive_max_pool2d.__str__]
  767. F.adaptive_max_pool3d = old_functions[F.adaptive_max_pool3d.__str__]
  768. F.upsample = old_functions[F.upsample.__str__]
  769. F.interpolate = old_functions[F.interpolate.__str__]
  770. F.softmax = old_functions[F.softmax.__str__]
  771. F.embedding = old_functions[F.embedding.__str__]
  772. def _reload_tensor_methods():
  773. torch.matmul = old_functions[torch.matmul.__str__]
  774. torch.Tensor.matmul = old_functions[torch.Tensor.matmul.__str__]
  775. torch.mm = old_functions[torch.mm.__str__]
  776. torch.Tensor.mm = old_functions[torch.Tensor.mm.__str__]
  777. torch.bmm = old_functions[torch.matmul.__str__]
  778. torch.Tensor.bmm = old_functions[torch.Tensor.bmm.__str__]
  779. torch.addmm = old_functions[torch.addmm.__str__]
  780. torch.Tensor.addmm = old_functions[torch.Tensor.addmm.__str__]
  781. torch.mul = old_functions[torch.mul.__str__]
  782. torch.Tensor.mul = old_functions[torch.Tensor.mul.__str__]
  783. torch.add = old_functions[torch.add.__str__]
  784. torch.Tensor.add = old_functions[torch.Tensor.add.__str__]
  785. torch.einsum = old_functions[torch.einsum.__str__]
  786. torch.baddbmm = old_functions[torch.baddbmm.__str__]
  787. def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
  788. gates_size = w_ih.shape[0]
  789. # matrix matrix mult ih state and internal state
  790. flops += 2 * w_ih.shape[0] * w_ih.shape[1] - gates_size
  791. # matrix matrix mult hh state and internal state
  792. flops += 2 * w_hh.shape[0] * w_hh.shape[1] - gates_size
  793. if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):
  794. # add both operations
  795. flops += rnn_module.hidden_size
  796. elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)):
  797. # hadamard of r
  798. flops += rnn_module.hidden_size
  799. # adding operations from both states
  800. flops += rnn_module.hidden_size * 3
  801. # last two hadamard _product and add
  802. flops += rnn_module.hidden_size * 3
  803. elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)):
  804. # adding operations from both states
  805. flops += rnn_module.hidden_size * 4
  806. # two hadamard _product and add for C state
  807. flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
  808. # final hadamard
  809. flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
  810. return flops
  811. def _rnn_forward_hook(rnn_module, input, output):
  812. flops = 0
  813. # input is a tuple containing a sequence to process and (optionally) hidden state
  814. inp = input[0]
  815. batch_size = inp.shape[0]
  816. seq_length = inp.shape[1]
  817. num_layers = rnn_module.num_layers
  818. for i in range(num_layers):
  819. w_ih = rnn_module.__getattr__("weight_ih_l" + str(i))
  820. w_hh = rnn_module.__getattr__("weight_hh_l" + str(i))
  821. if i == 0:
  822. input_size = rnn_module.input_size
  823. else:
  824. input_size = rnn_module.hidden_size
  825. flops = _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size)
  826. if rnn_module.bias:
  827. b_ih = rnn_module.__getattr__("bias_ih_l" + str(i))
  828. b_hh = rnn_module.__getattr__("bias_hh_l" + str(i))
  829. flops += b_ih.shape[0] + b_hh.shape[0]
  830. flops *= batch_size
  831. flops *= seq_length
  832. if rnn_module.bidirectional:
  833. flops *= 2
  834. rnn_module.__flops__ += int(flops)
  835. def _rnn_cell_forward_hook(rnn_cell_module, input, output):
  836. flops = 0
  837. inp = input[0]
  838. batch_size = inp.shape[0]
  839. w_ih = rnn_cell_module.__getattr__("weight_ih")
  840. w_hh = rnn_cell_module.__getattr__("weight_hh")
  841. input_size = inp.shape[1]
  842. flops = _rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size)
  843. if rnn_cell_module.bias:
  844. b_ih = rnn_cell_module.__getattr__("bias_ih")
  845. b_hh = rnn_cell_module.__getattr__("bias_hh")
  846. flops += b_ih.shape[0] + b_hh.shape[0]
  847. flops *= batch_size
  848. rnn_cell_module.__flops__ += int(flops)
  849. MODULE_HOOK_MAPPING = {
  850. # RNN
  851. nn.RNN: _rnn_forward_hook,
  852. nn.GRU: _rnn_forward_hook,
  853. nn.LSTM: _rnn_forward_hook,
  854. nn.RNNCell: _rnn_cell_forward_hook,
  855. nn.LSTMCell: _rnn_cell_forward_hook,
  856. nn.GRUCell: _rnn_cell_forward_hook,
  857. }
  858. def macs_to_string(macs, units=None, precision=DEFAULT_PRECISION):
  859. return f"{number_to_string(macs, units=units, precision=precision)}MACs"
  860. def number_to_string(num, units=None, precision=DEFAULT_PRECISION):
  861. if units is None:
  862. if num >= 1e12:
  863. magnitude, units = 1e12, "T"
  864. elif num >= 1e9:
  865. magnitude, units = 1e9, "G"
  866. elif num >= 1e6:
  867. magnitude, units = 1e6, "M"
  868. elif num >= 1e3:
  869. magnitude, units = 1e3, "K"
  870. elif num >= 1 or num == 0:
  871. magnitude, units = 1, ""
  872. elif num >= 1e-3:
  873. magnitude, units = 1e-3, "m"
  874. else:
  875. magnitude, units = 1e-6, "u"
  876. else:
  877. if units == "T":
  878. magnitude = 1e12
  879. elif units == "G":
  880. magnitude = 1e9
  881. elif units == "M":
  882. magnitude = 1e6
  883. elif units == "K":
  884. magnitude = 1e3
  885. elif units == "m":
  886. magnitude = 1e-3
  887. elif units == "u":
  888. magnitude = 1e-6
  889. else:
  890. magnitude = 1
  891. return f"{round(num / magnitude, precision):g} {units}"
  892. def flops_to_string(flops, units=None, precision=DEFAULT_PRECISION):
  893. return f"{number_to_string(flops, units=units, precision=precision)}FLOPS"
  894. def bytes_to_string(b, units=None, precision=DEFAULT_PRECISION):
  895. return f"{number_to_string(b, units=units, precision=precision)}B"
  896. def params_to_string(params_num, units=None, precision=DEFAULT_PRECISION):
  897. units = units.replace("B", "G") if units else units
  898. return number_to_string(params_num, units=units, precision=precision).replace("G", "B").strip()
  899. def duration_to_string(duration, units=None, precision=DEFAULT_PRECISION):
  900. return f"{number_to_string(duration, units=units, precision=precision)}s"
  901. # can not iterate over all submodules using self.model.modules()
  902. # since modules() returns duplicate modules only once
  903. def get_module_flops(module):
  904. sum = module.__flops__
  905. # iterate over immediate children modules
  906. for child in module.children():
  907. sum += get_module_flops(child)
  908. return sum
  909. def get_module_macs(module):
  910. sum = module.__macs__
  911. # iterate over immediate children modules
  912. for child in module.children():
  913. sum += get_module_macs(child)
  914. return sum
  915. def get_module_duration(module):
  916. duration = module.__duration__
  917. if duration == 0: # e.g. ModuleList
  918. for m in module.children():
  919. duration += get_module_duration(m)
  920. return duration
  921. def get_model_profile(model,
  922. input_shape=None,
  923. args=[],
  924. kwargs={},
  925. print_profile=True,
  926. detailed=True,
  927. module_depth=-1,
  928. top_modules=1,
  929. warm_up=1,
  930. as_string=True,
  931. output_file=None,
  932. ignore_modules=None,
  933. mode='forward'):
  934. """Returns the total floating-point operations, MACs, and parameters of a model.
  935. Example:
  936. .. code-block:: python
  937. model = torchvision.models.alexnet()
  938. batch_size = 256
  939. flops, macs, params = get_model_profile(model=model, input_shape=(batch_size, 3, 224, 224)))
  940. Args:
  941. model ([torch.nn.Module]): the PyTorch model to be profiled.
  942. input_shape (tuple): input shape to the model. If specified, the model takes a tensor with this shape as the only positional argument.
  943. args (list): list of positional arguments to the model.
  944. kwargs (dict): dictionary of keyword arguments to the model.
  945. print_profile (bool, optional): whether to print the model profile. Defaults to True.
  946. detailed (bool, optional): whether to print the detailed model profile. Defaults to True.
  947. module_depth (int, optional): the depth into the nested modules. Defaults to -1 (the inner most modules).
  948. top_modules (int, optional): the number of top modules to print in the aggregated profile. Defaults to 3.
  949. warm_up (int, optional): the number of warm-up steps before measuring the latency of each module. Defaults to 1.
  950. as_string (bool, optional): whether to print the output as string. Defaults to True.
  951. output_file (str, optional): path to the output file. If None, the profiler prints to stdout.
  952. ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None.
  953. Returns:
  954. The number of floating-point operations, multiply-accumulate operations (MACs), and parameters in the model.
  955. """
  956. assert isinstance(model, nn.Module), "model must be a PyTorch module"
  957. prof = FlopsProfiler(model)
  958. model.eval()
  959. if input_shape is not None:
  960. assert type(input_shape) is tuple, "input_shape must be a tuple"
  961. assert len(input_shape) >= 1, "input_shape must have at least one element"
  962. try:
  963. input = torch.ones(()).new_empty(
  964. (*input_shape, ),
  965. dtype=next(model.parameters()).dtype,
  966. device=next(model.parameters()).device,
  967. )
  968. except StopIteration:
  969. input = torch.ones(()).new_empty((*input_shape, ))
  970. args = [input]
  971. assert (len(args) > 0) or (len(kwargs) > 0), "args and/or kwargs must be specified if input_shape is None"
  972. logger.info("Flops profiler warming-up...")
  973. for _ in range(warm_up):
  974. if kwargs:
  975. if mode == 'forward':
  976. _ = model(*args, **kwargs)
  977. if mode == 'generate':
  978. _ = model.generate(*args, **kwargs)
  979. else:
  980. if mode == 'forward':
  981. _ = model(*args)
  982. if mode == 'generate':
  983. _ = model.generate(*args)
  984. prof.start_profile(ignore_list=ignore_modules)
  985. if kwargs:
  986. if mode == 'forward':
  987. _ = model(*args, **kwargs)
  988. if mode == 'generate':
  989. _ = model.generate(*args, **kwargs)
  990. else:
  991. if mode == 'forward':
  992. _ = model(*args)
  993. if mode == 'generate':
  994. _ = model.generate(*args)
  995. flops = prof.get_total_flops()
  996. macs = prof.get_total_macs()
  997. params = prof.get_total_params()
  998. if print_profile:
  999. prof.print_model_profile(profile_step=warm_up,
  1000. module_depth=module_depth,
  1001. top_modules=top_modules,
  1002. detailed=detailed,
  1003. output_file=output_file)
  1004. prof.end_profile()
  1005. if as_string:
  1006. return number_to_string(flops), macs_to_string(macs), params_to_string(params)
  1007. return flops, macs, params