profiler.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217
  1. import time
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from functools import partial
  6. from typing import Callable, List, Optional, Tuple
  7. from collections import OrderedDict
  8. import numpy as np
# Alias so the flops-compute helpers can annotate tensor arguments as ``Tensor``.
Tensor = torch.Tensor
# Stack holding one list per module whose forward pass is currently running.
# Patched functions append ``(name, flops)`` to the top list; the module's
# forward hook sums that list into ``module.__flops__`` and pops it.
module_flop_count = []
# Same stack discipline as ``module_flop_count`` but for ``(name, macs)`` entries.
module_mac_count = []
# Maps a function's ``__name__`` to the original callable saved by ``wrapFunc``.
old_functions = {}
  13. class FlopsProfiler(object):
  14. """Measures the latency, number of estimated floating-point operations and parameters of each module in a PyTorch model.
  15. The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how latency, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated latency, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input.
  16. The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package.
  17. When using DeepSpeed for model training, the flops profiler can be configured in the deepspeed_config file and no user code change is required.
  18. If using the profiler as a standalone package, one imports the flops_profiler package and use the APIs.
  19. Here is an example for usage in a typical training workflow:
  20. .. code-block:: python
  21. model = Model()
  22. prof = FlopsProfiler(model)
  23. for step, batch in enumerate(data_loader):
  24. if step == profile_step:
  25. prof.start_profile()
  26. loss = model(batch)
  27. if step == profile_step:
  28. flops = prof.get_total_flops(as_string=True)
  29. params = prof.get_total_params(as_string=True)
  30. prof.print_model_profile(profile_step=profile_step)
  31. prof.end_profile()
  32. loss.backward()
  33. optimizer.step()
  34. To profile a trained model in inference, use the `get_model_profile` API.
  35. Args:
  36. object (torch.nn.Module): The PyTorch model to profile.
  37. """
  38. def __init__(self, model, ds_engine=None):
  39. self.model = model
  40. self.ds_engine = ds_engine
  41. self.started = False
  42. self.func_patched = False
  43. def start_profile(self, ignore_list=None):
  44. """Starts profiling.
  45. Extra attributes are added recursively to all the modules and the profiled torch.nn.functionals are monkey patched.
  46. Args:
  47. ignore_list (list, optional): the list of modules to ignore while profiling. Defaults to None.
  48. """
  49. self.reset_profile()
  50. _patch_functionals()
  51. _patch_tensor_methods()
  52. def register_module_hooks(module, ignore_list):
  53. if ignore_list and type(module) in ignore_list:
  54. return
  55. # if computing the flops of a module directly
  56. if type(module) in MODULE_HOOK_MAPPING:
  57. module.__flops_handle__ = module.register_forward_hook(
  58. MODULE_HOOK_MAPPING[type(module)])
  59. return
  60. # if computing the flops of the functionals in a module
  61. def pre_hook(module, input):
  62. module_flop_count.append([])
  63. module_mac_count.append([])
  64. module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)
  65. def post_hook(module, input, output):
  66. if module_flop_count:
  67. module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]])
  68. module_flop_count.pop()
  69. module.__macs__ += sum([elem[1] for elem in module_mac_count[-1]])
  70. module_mac_count.pop()
  71. module.__post_hook_handle__ = module.register_forward_hook(post_hook)
  72. def start_time_hook(module, input):
  73. module.__start_time__ = time.time()
  74. module.__start_time_hook_handle__ = module.register_forward_pre_hook(
  75. start_time_hook)
  76. def end_time_hook(module, input, output):
  77. module.__duration__ += time.time() - module.__start_time__
  78. module.__end_time_hook_handle__ = module.register_forward_hook(end_time_hook)
  79. self.model.apply(partial(register_module_hooks, ignore_list=ignore_list))
  80. self.started = True
  81. self.func_patched = True
  82. def stop_profile(self):
  83. """Stop profiling.
  84. All torch.nn.functionals are restored to their originals.
  85. """
  86. if self.started and self.func_patched:
  87. _reload_functionals()
  88. _reload_tensor_methods()
  89. self.func_patched = False
  90. def remove_profile_attrs(module):
  91. if hasattr(module, "__pre_hook_handle__"):
  92. module.__pre_hook_handle__.remove()
  93. del module.__pre_hook_handle__
  94. if hasattr(module, "__post_hook_handle__"):
  95. module.__post_hook_handle__.remove()
  96. del module.__post_hook_handle__
  97. if hasattr(module, "__flops_handle__"):
  98. module.__flops_handle__.remove()
  99. del module.__flops_handle__
  100. if hasattr(module, "__start_time_hook_handle__"):
  101. module.__start_time_hook_handle__.remove()
  102. del module.__start_time_hook_handle__
  103. if hasattr(module, "__end_time_hook_handle__"):
  104. module.__end_time_hook_handle__.remove()
  105. del module.__end_time_hook_handle__
  106. self.model.apply(remove_profile_attrs)
  107. def reset_profile(self):
  108. """Resets the profiling.
  109. Adds or resets the extra attributes.
  110. """
  111. def add_or_reset_attrs(module):
  112. module.__flops__ = 0
  113. module.__macs__ = 0
  114. module.__params__ = sum(p.numel() for p in module.parameters()
  115. if p.requires_grad)
  116. module.__start_time__ = 0
  117. module.__duration__ = 0
  118. self.model.apply(add_or_reset_attrs)
  119. def end_profile(self):
  120. """Ends profiling.
  121. The added attributes and handles are removed recursively on all the modules.
  122. """
  123. if not self.started:
  124. return
  125. self.stop_profile()
  126. self.started = False
  127. def remove_profile_attrs(module):
  128. if hasattr(module, "__flops__"):
  129. del module.__flops__
  130. if hasattr(module, "__macs__"):
  131. del module.__macs__
  132. if hasattr(module, "__params__"):
  133. del module.__params__
  134. if hasattr(module, "__start_time__"):
  135. del module.__start_time__
  136. if hasattr(module, "__duration__"):
  137. del module.__duration__
  138. self.model.apply(remove_profile_attrs)
  139. def get_total_flops(self, as_string=False):
  140. """Returns the total flops of the model.
  141. Args:
  142. as_string (bool, optional): whether to output the flops as string. Defaults to False.
  143. Returns:
  144. The number of multiply-accumulate operations of the model forward pass.
  145. """
  146. total_flops = get_module_flops(self.model)
  147. return num_to_string(total_flops) if as_string else total_flops
  148. def get_total_macs(self, as_string=False):
  149. """Returns the total MACs of the model.
  150. Args:
  151. as_string (bool, optional): whether to output the flops as string. Defaults to False.
  152. Returns:
  153. The number of multiply-accumulate operations of the model forward pass.
  154. """
  155. total_macs = get_module_macs(self.model)
  156. return macs_to_string(total_macs) if as_string else total_macs
  157. def get_total_duration(self, as_string=False):
  158. """Returns the total duration of the model forward pass.
  159. Args:
  160. as_string (bool, optional): whether to output the duration as string. Defaults to False.
  161. Returns:
  162. The latency of the model forward pass.
  163. """
  164. total_duration = get_module_duration(self.model)
  165. return duration_to_string(total_duration) if as_string else total_duration
  166. def get_total_params(self, as_string=False):
  167. """Returns the total parameters of the model.
  168. Args:
  169. as_string (bool, optional): whether to output the parameters as string. Defaults to False.
  170. Returns:
  171. The number of parameters in the model.
  172. """
  173. return params_to_string(
  174. self.model.__params__) if as_string else self.model.__params__
  175. def print_model_profile(self,
  176. profile_step=1,
  177. module_depth=-1,
  178. top_modules=1,
  179. detailed=True,
  180. output_file=None):
  181. """Prints the model graph with the measured profile attached to each module.
  182. Args:
  183. profile_step (int, optional): The global training step at which to profile. Note that warm up steps are needed for accurate time measurement.
  184. module_depth (int, optional): The depth of the model to which to print the aggregated module information. When set to -1, it prints information from the top to the innermost modules (the maximum depth).
  185. top_modules (int, optional): Limits the aggregated profile output to the number of top modules specified.
  186. detailed (bool, optional): Whether to print the detailed model profile.
  187. output_file (str, optional): Path to the output file. If None, the profiler prints to stdout.
  188. """
  189. if not self.started:
  190. return
  191. import sys
  192. import os.path
  193. from os import path
  194. original_stdout = None
  195. f = None
  196. if output_file and output_file != "":
  197. dir_path = os.path.dirname(output_file)
  198. if not os.path.exists(dir_path):
  199. os.makedirs(dir_path)
  200. original_stdout = sys.stdout
  201. f = open(output_file, "w")
  202. sys.stdout = f
  203. total_flops = self.get_total_flops()
  204. total_macs = self.get_total_macs()
  205. total_duration = self.get_total_duration()
  206. total_params = self.get_total_params()
  207. self.flops = total_flops
  208. self.macs = total_macs
  209. self.params = total_params
  210. print(
  211. "\n-------------------------- DeepSpeed Flops Profiler --------------------------"
  212. )
  213. print(f'Profile Summary at step {profile_step}:')
  214. print(
  215. "Notations:\ndata parallel size (dp_size), model parallel size(mp_size),\nnumber of parameters (params), number of multiply-accumulate operations(MACs),\nnumber of floating-point operations (flops), floating-point operations per second (FLOPS),\nfwd latency (forward propagation latency), bwd latency (backward propagation latency),\nstep (weights update latency), iter latency (sum of fwd, bwd and step latency)\n"
  216. )
  217. if self.ds_engine:
  218. print('{:<60} {:<8}'.format('world size: ', self.ds_engine.world_size))
  219. print('{:<60} {:<8}'.format('data parallel size: ',
  220. self.ds_engine.dp_world_size))
  221. print('{:<60} {:<8}'.format('model parallel size: ',
  222. self.ds_engine.mp_world_size))
  223. print('{:<60} {:<8}'.format(
  224. 'batch size per GPU: ',
  225. self.ds_engine.train_micro_batch_size_per_gpu()))
  226. print('{:<60} {:<8}'.format('params per gpu: ', params_to_string(total_params)))
  227. print('{:<60} {:<8}'.format(
  228. 'params of model = params per GPU * mp_size: ',
  229. params_to_string(total_params *
  230. (self.ds_engine.mp_world_size) if self.ds_engine else 1)))
  231. print('{:<60} {:<8}'.format('fwd MACs per GPU: ', macs_to_string(total_macs)))
  232. print('{:<60} {:<8}'.format('fwd flops per GPU: ', num_to_string(total_flops)))
  233. print('{:<60} {:<8}'.format(
  234. 'fwd flops of model = fwd flops per GPU * mp_size: ',
  235. num_to_string(total_flops *
  236. (self.ds_engine.mp_world_size) if self.ds_engine else 1)))
  237. fwd_latency = self.get_total_duration()
  238. if self.ds_engine and self.ds_engine.wall_clock_breakdown():
  239. fwd_latency = self.ds_engine.timers('forward').elapsed(False)
  240. print('{:<60} {:<8}'.format('fwd latency: ', duration_to_string(fwd_latency)))
  241. print('{:<60} {:<8}'.format(
  242. 'fwd FLOPS per GPU = fwd flops per GPU / fwd latency: ',
  243. flops_to_string(total_flops / fwd_latency)))
  244. if self.ds_engine and self.ds_engine.wall_clock_breakdown():
  245. bwd_latency = self.ds_engine.timers('backward').elapsed(False)
  246. step_latency = self.ds_engine.timers('step').elapsed(False)
  247. print('{:<60} {:<8}'.format('bwd latency: ',
  248. duration_to_string(bwd_latency)))
  249. print('{:<60} {:<8}'.format(
  250. 'bwd FLOPS per GPU = 2 * fwd flops per GPU / bwd latency: ',
  251. flops_to_string(2 * total_flops / bwd_latency)))
  252. print('{:<60} {:<8}'.format(
  253. 'fwd+bwd FLOPS per GPU = 3 * fwd flops per GPU / (fwd+bwd latency): ',
  254. flops_to_string(3 * total_flops / (fwd_latency + bwd_latency))))
  255. print('{:<60} {:<8}'.format('step latency: ',
  256. duration_to_string(step_latency)))
  257. iter_latency = fwd_latency + bwd_latency + step_latency
  258. print('{:<60} {:<8}'.format('iter latency: ',
  259. duration_to_string(iter_latency)))
  260. print('{:<60} {:<8}'.format(
  261. 'FLOPS per GPU = 3 * fwd flops per GPU / iter latency: ',
  262. flops_to_string(3 * total_flops / iter_latency)))
  263. samples_per_iter = self.ds_engine.train_micro_batch_size_per_gpu(
  264. ) * self.ds_engine.world_size
  265. print('{:<60} {:<8.2f}'.format('samples/second: ',
  266. samples_per_iter / iter_latency))
  267. def flops_repr(module):
  268. params = module.__params__
  269. flops = get_module_flops(module)
  270. macs = get_module_macs(module)
  271. items = [
  272. params_to_string(params),
  273. "{:.2%} Params".format(params / total_params),
  274. macs_to_string(macs),
  275. "{:.2%} MACs".format(0.0 if total_macs == 0 else macs / total_macs),
  276. ]
  277. duration = get_module_duration(module)
  278. items.append(duration_to_string(duration))
  279. items.append(
  280. "{:.2%} latency".format(0.0 if total_duration == 0 else duration /
  281. total_duration))
  282. items.append(flops_to_string(0.0 if duration == 0 else flops / duration))
  283. items.append(module.original_extra_repr())
  284. return ", ".join(items)
  285. def add_extra_repr(module):
  286. flops_extra_repr = flops_repr.__get__(module)
  287. if module.extra_repr != flops_extra_repr:
  288. module.original_extra_repr = module.extra_repr
  289. module.extra_repr = flops_extra_repr
  290. assert module.extra_repr != module.original_extra_repr
  291. def del_extra_repr(module):
  292. if hasattr(module, "original_extra_repr"):
  293. module.extra_repr = module.original_extra_repr
  294. del module.original_extra_repr
  295. self.model.apply(add_extra_repr)
  296. print(
  297. "\n----------------------------- Aggregated Profile per GPU -----------------------------"
  298. )
  299. self.print_model_aggregated_profile(module_depth=module_depth,
  300. top_modules=top_modules)
  301. if detailed:
  302. print(
  303. "\n------------------------------ Detailed Profile per GPU ------------------------------"
  304. )
  305. print(
  306. "Each module profile is listed after its name in the following order: \nparams, percentage of total params, MACs, percentage of total MACs, fwd latency, percentage of total fwd latency, fwd FLOPS"
  307. )
  308. print(
  309. "\nNote: 1. A module can have torch.nn.module or torch.nn.functional to compute logits (e.g. CrossEntropyLoss). They are not counted as submodules, thus not to be printed out. However they make up the difference between a parent's MACs (or latency) and the sum of its submodules'.\n2. Number of floating-point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.\n3. The fwd latency listed in the top module's profile is directly captured at the module forward function in PyTorch, thus it's less than the fwd latency shown above which is captured in DeepSpeed.\n"
  310. )
  311. print(self.model)
  312. self.model.apply(del_extra_repr)
  313. print(
  314. "------------------------------------------------------------------------------"
  315. )
  316. if output_file:
  317. sys.stdout = original_stdout
  318. f.close()
  319. def print_model_aggregated_profile(self, module_depth=-1, top_modules=1):
  320. """Prints the names of the top top_modules modules in terms of aggregated time, flops, and parameters at depth module_depth.
  321. Args:
  322. module_depth (int, optional): the depth of the modules to show. Defaults to -1 (the innermost modules).
  323. top_modules (int, optional): the number of top modules to show. Defaults to 1.
  324. """
  325. info = {}
  326. if not hasattr(self.model, "__flops__"):
  327. print(
  328. "no __flops__ attribute in the model, call this function after start_profile and before end_profile"
  329. )
  330. return
  331. def walk_module(module, curr_depth, info):
  332. if curr_depth not in info:
  333. info[curr_depth] = {}
  334. if module.__class__.__name__ not in info[curr_depth]:
  335. info[curr_depth][module.__class__.__name__] = [
  336. 0,
  337. 0,
  338. 0,
  339. ] # macs, params, time
  340. info[curr_depth][module.__class__.__name__][0] += get_module_macs(module)
  341. info[curr_depth][module.__class__.__name__][1] += module.__params__
  342. info[curr_depth][module.__class__.__name__][2] += get_module_duration(module)
  343. has_children = len(module._modules.items()) != 0
  344. if has_children:
  345. for child in module.children():
  346. walk_module(child, curr_depth + 1, info)
  347. walk_module(self.model, 0, info)
  348. depth = module_depth
  349. if module_depth == -1:
  350. depth = len(info) - 1
  351. print(
  352. f'Top {top_modules} modules in terms of params, MACs or fwd latency at different model depths:'
  353. )
  354. for d in range(depth):
  355. num_items = min(top_modules, len(info[d]))
  356. sort_macs = {
  357. k: macs_to_string(v[0])
  358. for k,
  359. v in sorted(info[d].items(),
  360. key=lambda item: item[1][0],
  361. reverse=True)[:num_items]
  362. }
  363. sort_params = {
  364. k: params_to_string(v[1])
  365. for k,
  366. v in sorted(info[d].items(),
  367. key=lambda item: item[1][1],
  368. reverse=True)[:num_items]
  369. }
  370. sort_time = {
  371. k: duration_to_string(v[2])
  372. for k,
  373. v in sorted(info[d].items(),
  374. key=lambda item: item[1][2],
  375. reverse=True)[:num_items]
  376. }
  377. print(f"depth {d}:")
  378. print(f" params - {sort_params}")
  379. print(f" MACs - {sort_macs}")
  380. print(f" fwd latency - {sort_time}")
  381. def _prod(dims):
  382. p = 1
  383. for v in dims:
  384. p *= v
  385. return p
  386. def _linear_flops_compute(input, weight, bias=None):
  387. out_features = weight.shape[0]
  388. macs = torch.numel(input) * out_features
  389. return 2 * macs, macs
  390. def _relu_flops_compute(input, inplace=False):
  391. return torch.numel(input), 0
  392. def _prelu_flops_compute(input: Tensor, weight: Tensor):
  393. return torch.numel(input), 0
  394. def _elu_flops_compute(input: Tensor, alpha: float = 1.0, inplace: bool = False):
  395. return torch.numel(input), 0
  396. def _leaky_relu_flops_compute(input: Tensor,
  397. negative_slope: float = 0.01,
  398. inplace: bool = False):
  399. return torch.numel(input), 0
  400. def _relu6_flops_compute(input: Tensor, inplace: bool = False):
  401. return torch.numel(input), 0
  402. def _silu_flops_compute(input: Tensor, inplace: bool = False):
  403. return torch.numel(input), 0
  404. def _gelu_flops_compute(input):
  405. return torch.numel(input), 0
  406. def _pool_flops_compute(
  407. input,
  408. kernel_size,
  409. stride=None,
  410. padding=0,
  411. ceil_mode=False,
  412. count_include_pad=True,
  413. divisor_override=None,
  414. ):
  415. return torch.numel(input), 0
  416. def _conv_flops_compute(input,
  417. weight,
  418. bias=None,
  419. stride=1,
  420. padding=0,
  421. dilation=1,
  422. groups=1):
  423. assert weight.shape[1] * groups == input.shape[1]
  424. batch_size = input.shape[0]
  425. in_channels = input.shape[1]
  426. out_channels = weight.shape[0]
  427. kernel_dims = list(weight.shape[-2:])
  428. input_dims = list(input.shape[2:])
  429. length = len(input_dims)
  430. paddings = padding if type(padding) is tuple else (padding, ) * length
  431. strides = stride if type(stride) is tuple else (stride, ) * length
  432. dilations = dilation if type(dilation) is tuple else (dilation, ) * length
  433. output_dims = []
  434. for idx, input_dim in enumerate(input_dims):
  435. output_dim = (input_dim + 2 * paddings[idx] -
  436. (dilations[idx] * (kernel_dims[idx] - 1) + 1)) // strides[idx] + 1
  437. output_dims.append(output_dim)
  438. filters_per_channel = out_channels // groups
  439. conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel
  440. active_elements_count = batch_size * int(_prod(output_dims))
  441. overall_conv_macs = conv_per_position_macs * active_elements_count
  442. overall_conv_flops = 2 * overall_conv_macs
  443. bias_flops = 0
  444. if bias is not None:
  445. bias_flops = out_channels * active_elements_count
  446. return int(overall_conv_flops + bias_flops), int(overall_conv_macs)
  447. def _conv_trans_flops_compute(
  448. input,
  449. weight,
  450. bias=None,
  451. stride=1,
  452. padding=0,
  453. output_padding=0,
  454. groups=1,
  455. dilation=1,
  456. ):
  457. batch_size = input.shape[0]
  458. in_channels = input.shape[1]
  459. out_channels = weight.shape[0]
  460. kernel_dims = list(weight.shape[-2:])
  461. input_dims = list(input.shape[2:])
  462. length = len(input_dims)
  463. paddings = padding if type(padding) is tuple else (padding, ) * length
  464. strides = stride if type(stride) is tuple else (stride, ) * length
  465. dilations = dilation if type(dilation) is tuple else (dilation, ) * length
  466. output_dims = []
  467. for idx, input_dim in enumerate(input_dims):
  468. output_dim = (input_dim + 2 * paddings[idx] -
  469. (dilations[idx] * (kernel_dims[idx] - 1) + 1)) // strides[idx] + 1
  470. output_dims.append(output_dim)
  471. paddings = padding if type(padding) is tuple else (padding, padding)
  472. strides = stride if type(stride) is tuple else (stride, stride)
  473. dilations = dilation if type(dilation) is tuple else (dilation, dilation)
  474. filters_per_channel = out_channels // groups
  475. conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel
  476. active_elements_count = batch_size * int(_prod(input_dims))
  477. overall_conv_macs = conv_per_position_macs * active_elements_count
  478. overall_conv_flops = 2 * overall_conv_macs
  479. bias_flops = 0
  480. if bias is not None:
  481. bias_flops = out_channels * batch_size * int(_prod(output_dims))
  482. return int(overall_conv_flops + bias_flops), int(overall_conv_macs)
  483. def _batch_norm_flops_compute(
  484. input,
  485. running_mean,
  486. running_var,
  487. weight=None,
  488. bias=None,
  489. training=False,
  490. momentum=0.1,
  491. eps=1e-05,
  492. ):
  493. has_affine = weight is not None
  494. if training:
  495. # estimation
  496. return torch.numel(input) * (5 if has_affine else 4), 0
  497. flops = torch.numel(input) * (2 if has_affine else 1)
  498. return flops, 0
  499. def _layer_norm_flops_compute(
  500. input: Tensor,
  501. normalized_shape: List[int],
  502. weight: Optional[Tensor] = None,
  503. bias: Optional[Tensor] = None,
  504. eps: float = 1e-5,
  505. ):
  506. has_affine = weight is not None
  507. # estimation
  508. return torch.numel(input) * (5 if has_affine else 4), 0
  509. def _group_norm_flops_compute(input: Tensor,
  510. num_groups: int,
  511. weight: Optional[Tensor] = None,
  512. bias: Optional[Tensor] = None,
  513. eps: float = 1e-5):
  514. has_affine = weight is not None
  515. # estimation
  516. return torch.numel(input) * (5 if has_affine else 4), 0
  517. def _instance_norm_flops_compute(
  518. input: Tensor,
  519. running_mean: Optional[Tensor] = None,
  520. running_var: Optional[Tensor] = None,
  521. weight: Optional[Tensor] = None,
  522. bias: Optional[Tensor] = None,
  523. use_input_stats: bool = True,
  524. momentum: float = 0.1,
  525. eps: float = 1e-5,
  526. ):
  527. has_affine = weight is not None
  528. # estimation
  529. return torch.numel(input) * (5 if has_affine else 4), 0
  530. def _upsample_flops_compute(input,
  531. size=None,
  532. scale_factor=None,
  533. mode="nearest",
  534. align_corners=None):
  535. if size is not None:
  536. if isinstance(size, tuple):
  537. return int(_prod(size)), 0
  538. else:
  539. return int(size), 0
  540. assert scale_factor is not None, "either size or scale_factor should be defined"
  541. flops = torch.numel(input)
  542. if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
  543. flops * int(_prod(scale_factor))
  544. else:
  545. flops * scale_factor**len(input)
  546. return flops, 0
  547. def _softmax_flops_compute(input, dim=None, _stacklevel=3, dtype=None):
  548. return torch.numel(input), 0
  549. def _embedding_flops_compute(
  550. input,
  551. weight,
  552. padding_idx=None,
  553. max_norm=None,
  554. norm_type=2.0,
  555. scale_grad_by_freq=False,
  556. sparse=False,
  557. ):
  558. return 0, 0
  559. def _dropout_flops_compute(input, p=0.5, training=True, inplace=False):
  560. return 0, 0
  561. def _matmul_flops_compute(input, other, *, out=None):
  562. """
  563. Count flops for the matmul operation.
  564. """
  565. macs = _prod(input.shape) * other.shape[-1]
  566. return 2 * macs, macs
  567. def _addmm_flops_compute(input, mat1, mat2, *, beta=1, alpha=1, out=None):
  568. """
  569. Count flops for the addmm operation.
  570. """
  571. macs = _prod(mat1.shape) * mat2.shape[-1]
  572. return 2 * macs + _prod(input.shape), macs
  573. def _einsum_flops_compute(equation, *operands):
  574. """
  575. Count flops for the einsum operation.
  576. """
  577. equation = equation.replace(" ", "")
  578. input_shapes = [o.shape for o in operands]
  579. # Re-map equation so that same equation with different alphabet
  580. # representations will look the same.
  581. letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys()
  582. mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)}
  583. equation = equation.translate(mapping)
  584. np_arrs = [np.zeros(s) for s in input_shapes]
  585. optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
  586. for line in optim.split("\n"):
  587. print(line.lower())
  588. if "optimized flop" in line.lower():
  589. flop = int(float(line.split(":")[-1]))
  590. return flop, 0
  591. raise NotImplementedError("Unsupported einsum operation.")
  592. def _tensor_addmm_flops_compute(self, mat1, mat2, *, beta=1, alpha=1, out=None):
  593. """
  594. Count flops for the tensor addmm operation.
  595. """
  596. macs = _prod(mat1.shape) * mat2.shape[-1]
  597. return 2 * macs + _prod(self.shape), macs
  598. def _elementwise_flops_compute(input, other, *, out=None):
  599. if not torch.is_tensor(input):
  600. if torch.is_tensor(other):
  601. return _prod(other.shape), 0
  602. else:
  603. return 1, 0
  604. elif not torch.is_tensor(other):
  605. return _prod(input.shape), 0
  606. else:
  607. dim_input = len(input.shape)
  608. dim_other = len(other.shape)
  609. max_dim = max(dim_input, dim_other)
  610. final_shape = []
  611. for i in range(max_dim):
  612. in_i = input.shape[i] if i < dim_input else 1
  613. ot_i = other.shape[i] if i < dim_other else 1
  614. if in_i > ot_i:
  615. final_shape.append(in_i)
  616. else:
  617. final_shape.append(ot_i)
  618. flops = _prod(final_shape)
  619. return flops, 0
  620. def wrapFunc(func, funcFlopCompute):
  621. oldFunc = func
  622. name = func.__name__
  623. old_functions[name] = oldFunc
  624. def newFunc(*args, **kwds):
  625. flops, macs = funcFlopCompute(*args, **kwds)
  626. if module_flop_count:
  627. module_flop_count[-1].append((name, flops))
  628. if module_mac_count and macs:
  629. module_mac_count[-1].append((name, macs))
  630. return oldFunc(*args, **kwds)
  631. newFunc.__name__ = func.__name__
  632. return newFunc
  633. def _patch_functionals():
  634. # FC
  635. F.linear = wrapFunc(F.linear, _linear_flops_compute)
  636. # convolutions
  637. F.conv1d = wrapFunc(F.conv1d, _conv_flops_compute)
  638. F.conv2d = wrapFunc(F.conv2d, _conv_flops_compute)
  639. F.conv3d = wrapFunc(F.conv3d, _conv_flops_compute)
  640. # conv transposed
  641. F.conv_transpose1d = wrapFunc(F.conv_transpose1d, _conv_trans_flops_compute)
  642. F.conv_transpose2d = wrapFunc(F.conv_transpose2d, _conv_trans_flops_compute)
  643. F.conv_transpose3d = wrapFunc(F.conv_transpose3d, _conv_trans_flops_compute)
  644. # activations
  645. F.relu = wrapFunc(F.relu, _relu_flops_compute)
  646. F.prelu = wrapFunc(F.prelu, _prelu_flops_compute)
  647. F.elu = wrapFunc(F.elu, _elu_flops_compute)
  648. F.leaky_relu = wrapFunc(F.leaky_relu, _leaky_relu_flops_compute)
  649. F.relu6 = wrapFunc(F.relu6, _relu6_flops_compute)
  650. if hasattr(F, "silu"):
  651. F.silu = wrapFunc(F.silu, _silu_flops_compute)
  652. F.gelu = wrapFunc(F.gelu, _gelu_flops_compute)
  653. # Normalizations
  654. F.batch_norm = wrapFunc(F.batch_norm, _batch_norm_flops_compute)
  655. F.layer_norm = wrapFunc(F.layer_norm, _layer_norm_flops_compute)
  656. F.instance_norm = wrapFunc(F.instance_norm, _instance_norm_flops_compute)
  657. F.group_norm = wrapFunc(F.group_norm, _group_norm_flops_compute)
  658. # poolings
  659. F.avg_pool1d = wrapFunc(F.avg_pool1d, _pool_flops_compute)
  660. F.avg_pool2d = wrapFunc(F.avg_pool2d, _pool_flops_compute)
  661. F.avg_pool3d = wrapFunc(F.avg_pool3d, _pool_flops_compute)
  662. F.max_pool1d = wrapFunc(F.max_pool1d, _pool_flops_compute)
  663. F.max_pool2d = wrapFunc(F.max_pool2d, _pool_flops_compute)
  664. F.max_pool3d = wrapFunc(F.max_pool3d, _pool_flops_compute)
  665. F.adaptive_avg_pool1d = wrapFunc(F.adaptive_avg_pool1d, _pool_flops_compute)
  666. F.adaptive_avg_pool2d = wrapFunc(F.adaptive_avg_pool2d, _pool_flops_compute)
  667. F.adaptive_avg_pool3d = wrapFunc(F.adaptive_avg_pool3d, _pool_flops_compute)
  668. F.adaptive_max_pool1d = wrapFunc(F.adaptive_max_pool1d, _pool_flops_compute)
  669. F.adaptive_max_pool2d = wrapFunc(F.adaptive_max_pool2d, _pool_flops_compute)
  670. F.adaptive_max_pool3d = wrapFunc(F.adaptive_max_pool3d, _pool_flops_compute)
  671. # upsample
  672. F.upsample = wrapFunc(F.upsample, _upsample_flops_compute)
  673. F.interpolate = wrapFunc(F.interpolate, _upsample_flops_compute)
  674. # softmax
  675. F.softmax = wrapFunc(F.softmax, _softmax_flops_compute)
  676. # embedding
  677. F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)
  678. def _patch_tensor_methods():
  679. torch.matmul = wrapFunc(torch.matmul, _matmul_flops_compute)
  680. torch.Tensor.matmul = wrapFunc(torch.Tensor.matmul, _matmul_flops_compute)
  681. torch.mm = wrapFunc(torch.mm, _matmul_flops_compute)
  682. torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute)
  683. torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute)
  684. torch.Tensor.bmm = wrapFunc(torch.bmm, _matmul_flops_compute)
  685. torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute)
  686. torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute)
  687. torch.mul = wrapFunc(torch.mul, _elementwise_flops_compute)
  688. torch.Tensor.mul = wrapFunc(torch.Tensor.mul, _elementwise_flops_compute)
  689. torch.add = wrapFunc(torch.add, _elementwise_flops_compute)
  690. torch.Tensor.add = wrapFunc(torch.Tensor.add, _elementwise_flops_compute)
  691. torch.einsum = wrapFunc(torch.einsum, _einsum_flops_compute)
  692. def _reload_functionals():
  693. # torch.nn.functional does not support importlib.reload()
  694. F.linear = old_functions[F.linear.__name__]
  695. F.conv1d = old_functions[F.conv1d.__name__]
  696. F.conv2d = old_functions[F.conv2d.__name__]
  697. F.conv3d = old_functions[F.conv3d.__name__]
  698. F.conv_transpose1d = old_functions[F.conv_transpose1d.__name__]
  699. F.conv_transpose2d = old_functions[F.conv_transpose2d.__name__]
  700. F.conv_transpose3d = old_functions[F.conv_transpose3d.__name__]
  701. F.relu = old_functions[F.relu.__name__]
  702. F.prelu = old_functions[F.prelu.__name__]
  703. F.elu = old_functions[F.elu.__name__]
  704. F.leaky_relu = old_functions[F.leaky_relu.__name__]
  705. F.relu6 = old_functions[F.relu6.__name__]
  706. F.batch_norm = old_functions[F.batch_norm.__name__]
  707. F.avg_pool1d = old_functions[F.avg_pool1d.__name__]
  708. F.avg_pool2d = old_functions[F.avg_pool2d.__name__]
  709. F.avg_pool3d = old_functions[F.avg_pool3d.__name__]
  710. F.max_pool1d = old_functions[F.max_pool1d.__name__]
  711. F.max_pool2d = old_functions[F.max_pool2d.__name__]
  712. F.max_pool3d = old_functions[F.max_pool3d.__name__]
  713. F.adaptive_avg_pool1d = old_functions[F.adaptive_avg_pool1d.__name__]
  714. F.adaptive_avg_pool2d = old_functions[F.adaptive_avg_pool2d.__name__]
  715. F.adaptive_avg_pool3d = old_functions[F.adaptive_avg_pool3d.__name__]
  716. F.adaptive_max_pool1d = old_functions[F.adaptive_max_pool1d.__name__]
  717. F.adaptive_max_pool2d = old_functions[F.adaptive_max_pool2d.__name__]
  718. F.adaptive_max_pool3d = old_functions[F.adaptive_max_pool3d.__name__]
  719. F.upsample = old_functions[F.upsample.__name__]
  720. F.interpolate = old_functions[F.interpolate.__name__]
  721. F.softmax = old_functions[F.softmax.__name__]
  722. F.embedding = old_functions[F.embedding.__name__]
  723. def _reload_tensor_methods():
  724. torch.matmul = old_functions[torch.matmul.__name__]
  725. def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
  726. # matrix matrix mult ih state and internal state
  727. flops += w_ih.shape[0] * w_ih.shape[1]
  728. # matrix matrix mult hh state and internal state
  729. flops += w_hh.shape[0] * w_hh.shape[1]
  730. if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):
  731. # add both operations
  732. flops += rnn_module.hidden_size
  733. elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)):
  734. # hadamard of r
  735. flops += rnn_module.hidden_size
  736. # adding operations from both states
  737. flops += rnn_module.hidden_size * 3
  738. # last two hadamard _product and add
  739. flops += rnn_module.hidden_size * 3
  740. elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)):
  741. # adding operations from both states
  742. flops += rnn_module.hidden_size * 4
  743. # two hadamard _product and add for C state
  744. flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
  745. # final hadamard
  746. flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
  747. return flops
  748. def _rnn_forward_hook(rnn_module, input, output):
  749. flops = 0
  750. # input is a tuple containing a sequence to process and (optionally) hidden state
  751. inp = input[0]
  752. batch_size = inp.shape[0]
  753. seq_length = inp.shape[1]
  754. num_layers = rnn_module.num_layers
  755. for i in range(num_layers):
  756. w_ih = rnn_module.__getattr__("weight_ih_l" + str(i))
  757. w_hh = rnn_module.__getattr__("weight_hh_l" + str(i))
  758. if i == 0:
  759. input_size = rnn_module.input_size
  760. else:
  761. input_size = rnn_module.hidden_size
  762. flops = _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size)
  763. if rnn_module.bias:
  764. b_ih = rnn_module.__getattr__("bias_ih_l" + str(i))
  765. b_hh = rnn_module.__getattr__("bias_hh_l" + str(i))
  766. flops += b_ih.shape[0] + b_hh.shape[0]
  767. flops *= batch_size
  768. flops *= seq_length
  769. if rnn_module.bidirectional:
  770. flops *= 2
  771. rnn_module.__flops__ += int(flops)
  772. def _rnn_cell_forward_hook(rnn_cell_module, input, output):
  773. flops = 0
  774. inp = input[0]
  775. batch_size = inp.shape[0]
  776. w_ih = rnn_cell_module.__getattr__("weight_ih")
  777. w_hh = rnn_cell_module.__getattr__("weight_hh")
  778. input_size = inp.shape[1]
  779. flops = _rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size)
  780. if rnn_cell_module.bias:
  781. b_ih = rnn_cell_module.__getattr__("bias_ih")
  782. b_hh = rnn_cell_module.__getattr__("bias_hh")
  783. flops += b_ih.shape[0] + b_hh.shape[0]
  784. flops *= batch_size
  785. rnn_cell_module.__flops__ += int(flops)
  786. MODULE_HOOK_MAPPING = {
  787. # RNN
  788. nn.RNN: _rnn_forward_hook,
  789. nn.GRU: _rnn_forward_hook,
  790. nn.LSTM: _rnn_forward_hook,
  791. nn.RNNCell: _rnn_cell_forward_hook,
  792. nn.LSTMCell: _rnn_cell_forward_hook,
  793. nn.GRUCell: _rnn_cell_forward_hook,
  794. }
  795. def num_to_string(num, precision=2):
  796. if num // 10**9 > 0:
  797. return str(round(num / 10.0**9, precision)) + " G"
  798. elif num // 10**6 > 0:
  799. return str(round(num / 10.0**6, precision)) + " M"
  800. elif num // 10**3 > 0:
  801. return str(round(num / 10.0**3, precision)) + " K"
  802. else:
  803. return str(num)
  804. def macs_to_string(macs, units=None, precision=2):
  805. if units is None:
  806. if macs // 10**9 > 0:
  807. return str(round(macs / 10.0**9, precision)) + " GMACs"
  808. elif macs // 10**6 > 0:
  809. return str(round(macs / 10.0**6, precision)) + " MMACs"
  810. elif macs // 10**3 > 0:
  811. return str(round(macs / 10.0**3, precision)) + " KMACs"
  812. else:
  813. return str(macs) + " MACs"
  814. else:
  815. if units == "GMACs":
  816. return str(round(macs / 10.0**9, precision)) + " " + units
  817. elif units == "MMACs":
  818. return str(round(macs / 10.0**6, precision)) + " " + units
  819. elif units == "KMACs":
  820. return str(round(macs / 10.0**3, precision)) + " " + units
  821. else:
  822. return str(macs) + " MACs"
  823. def number_to_string(num, units=None, precision=2):
  824. if units is None:
  825. if num // 10**9 > 0:
  826. return str(round(num / 10.0**9, precision)) + " G"
  827. elif num // 10**6 > 0:
  828. return str(round(num / 10.0**6, precision)) + " M"
  829. elif num // 10**3 > 0:
  830. return str(round(num / 10.0**3, precision)) + " K"
  831. else:
  832. return str(num) + " "
  833. else:
  834. if units == "G":
  835. return str(round(num / 10.0**9, precision)) + " " + units
  836. elif units == "M":
  837. return str(round(num / 10.0**6, precision)) + " " + units
  838. elif units == "K":
  839. return str(round(num / 10.0**3, precision)) + " " + units
  840. else:
  841. return str(num) + " "
  842. def flops_to_string(flops, units=None, precision=2):
  843. if units is None:
  844. if flops // 10**12 > 0:
  845. return str(round(flops / 10.0**12, precision)) + " TFLOPS"
  846. if flops // 10**9 > 0:
  847. return str(round(flops / 10.0**9, precision)) + " GFLOPS"
  848. elif flops // 10**6 > 0:
  849. return str(round(flops / 10.0**6, precision)) + " MFLOPS"
  850. elif flops // 10**3 > 0:
  851. return str(round(flops / 10.0**3, precision)) + " KFLOPS"
  852. else:
  853. return str(flops) + " FLOPS"
  854. else:
  855. if units == "TFLOPS":
  856. return str(round(flops / 10.0**12, precision)) + " " + units
  857. if units == "GFLOPS":
  858. return str(round(flops / 10.0**9, precision)) + " " + units
  859. elif units == "MFLOPS":
  860. return str(round(flops / 10.0**6, precision)) + " " + units
  861. elif units == "KFLOPS":
  862. return str(round(flops / 10.0**3, precision)) + " " + units
  863. else:
  864. return str(flops) + " FLOPS"
  865. def params_to_string(params_num, units=None, precision=2):
  866. if units is None:
  867. if params_num // 10**6 > 0:
  868. return str(round(params_num / 10**6, 2)) + " M"
  869. elif params_num // 10**3:
  870. return str(round(params_num / 10**3, 2)) + " k"
  871. else:
  872. return str(params_num)
  873. else:
  874. if units == "M":
  875. return str(round(params_num / 10.0**6, precision)) + " " + units
  876. elif units == "K":
  877. return str(round(params_num / 10.0**3, precision)) + " " + units
  878. else:
  879. return str(params_num)
  880. def duration_to_string(duration, units=None, precision=2):
  881. if units is None:
  882. if duration > 1:
  883. return str(round(duration, precision)) + " s"
  884. elif duration * 10**3 > 1:
  885. return str(round(duration * 10**3, precision)) + " ms"
  886. elif duration * 10**6 > 1:
  887. return str(round(duration * 10**6, precision)) + " us"
  888. else:
  889. return str(duration)
  890. else:
  891. if units == "us":
  892. return str(round(duration * 10.0**6, precision)) + " " + units
  893. elif units == "ms":
  894. return str(round(duration * 10.0**3, precision)) + " " + units
  895. else:
  896. return str(round(duration, precision)) + " s"
# We cannot iterate over all submodules with self.model.modules() here,
# because modules() returns a module that appears more than once only once.
  899. def get_module_flops(module):
  900. sum = module.__flops__
  901. # iterate over immediate children modules
  902. for child in module.children():
  903. sum += get_module_flops(child)
  904. return sum
  905. def get_module_macs(module):
  906. sum = module.__macs__
  907. # iterate over immediate children modules
  908. for child in module.children():
  909. sum += get_module_macs(child)
  910. return sum
  911. def get_module_duration(module):
  912. duration = module.__duration__
  913. if duration == 0: # e.g. ModuleList
  914. for m in module.children():
  915. duration += m.__duration__
  916. return duration
  917. def get_model_profile(
  918. model,
  919. input_res,
  920. input_constructor=None,
  921. print_profile=True,
  922. detailed=True,
  923. module_depth=-1,
  924. top_modules=1,
  925. warm_up=1,
  926. as_string=True,
  927. output_file=None,
  928. ignore_modules=None,
  929. ):
  930. """Returns the total floating-point operations, MACs, and parameters of a model.
  931. Example:
  932. .. code-block:: python
  933. model = torchvision.models.alexnet()
  934. batch_size = 256
  935. flops, macs, params = get_model_profile(model=model, input_res= (batch_size, 3, 224, 224)))
  936. Args:
  937. model ([torch.nn.Module]): the PyTorch model to be profiled.
  938. input_res (list): input shape or input to the input_constructor
  939. input_constructor (func, optional): input constructor. If specified, the constructor is applied to input_res and the constructor output is used as the input to the model. Defaults to None.
  940. print_profile (bool, optional): whether to print the model profile. Defaults to True.
  941. detailed (bool, optional): whether to print the detailed model profile. Defaults to True.
  942. module_depth (int, optional): the depth into the nested modules. Defaults to -1 (the inner most modules).
  943. top_modules (int, optional): the number of top modules to print in the aggregated profile. Defaults to 3.
  944. warm_up (int, optional): the number of warm-up steps before measuring the latency of each module. Defaults to 1.
  945. as_string (bool, optional): whether to print the output as string. Defaults to True.
  946. output_file (str, optional): path to the output file. If None, the profiler prints to stdout.
  947. ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None.
  948. Returns:
  949. The number of floating-point operations, multiply-accumulate operations (MACs), and parameters in the model.
  950. """
  951. assert type(input_res) is tuple
  952. assert len(input_res) >= 1
  953. assert isinstance(model, nn.Module)
  954. prof = FlopsProfiler(model)
  955. model.eval()
  956. for _ in range(warm_up):
  957. if input_constructor:
  958. input = input_constructor(input_res)
  959. _ = model(**input)
  960. else:
  961. try:
  962. batch = torch.ones(()).new_empty(
  963. (*input_res,
  964. ),
  965. dtype=next(model.parameters()).dtype,
  966. device=next(model.parameters()).device,
  967. )
  968. except StopIteration:
  969. batch = torch.ones(()).new_empty((*input_res, ))
  970. _ = model(batch)
  971. prof.start_profile(ignore_list=ignore_modules)
  972. if input_constructor:
  973. input = input_constructor(input_res)
  974. _ = model(**input)
  975. else:
  976. try:
  977. batch = torch.ones(()).new_empty(
  978. (*input_res,
  979. ),
  980. dtype=next(model.parameters()).dtype,
  981. device=next(model.parameters()).device,
  982. )
  983. except StopIteration:
  984. batch = torch.ones(()).new_empty((*input_res, ))
  985. _ = model(batch)
  986. flops = prof.get_total_flops()
  987. macs = prof.get_total_macs()
  988. params = prof.get_total_params()
  989. if print_profile:
  990. prof.print_model_profile(profile_step=warm_up,
  991. module_depth=module_depth,
  992. top_modules=top_modules,
  993. detailed=detailed,
  994. output_file=output_file)
  995. prof.end_profile()
  996. if as_string:
  997. return number_to_string(flops), macs_to_string(macs), params_to_string(params)
  998. return flops, macs, params