flatten_bench.py (3.0 KB)
  1. #!/usr/bin/env python
  2. # run the benchmark under timeit (-t), cProfile (-c), line_profiler (-l)
  3. #
  4. # usage:
  5. # ./flatten_bench.py -t
  6. # ./flatten_bench.py -c
  7. # kernprof -l flatten_bench.py -l; python -m line_profiler flatten_bench.py.lprof
  8. import argparse
  9. import gc
  10. import torch
  11. from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
  12. from deepspeed.ops.op_builder import UtilsBuilder
  13. from apex_C import flatten as flatten_apex
  14. util_ops = UtilsBuilder().load()
  15. flatten = util_ops.flatten
  16. unflatten = util_ops.unflatten
  17. torch.manual_seed(0)
  18. # emulate a small typical model weights
  19. x = [
  20. torch.rand((512,
  21. 512)).cuda(),
  22. torch.rand((512,
  23. 1024)).cuda(),
  24. torch.rand((512,
  25. 30000)).cuda()
  26. ]
  27. t = x * 30
  28. # warm up and check that the same output is produced
  29. flat_py = _flatten_dense_tensors(t)
  30. flat_cpp = flatten(t)
  31. flat_apex = flatten_apex(t)
  32. #numel = flat_cpp.numel()
  33. assert torch.eq(flat_py, flat_cpp).all(), "both produce the same tensor"
  34. assert torch.eq(flat_py, flat_apex).all(), "both produce the same tensor"
  35. TIMES = 1000
  36. # the programs being tested
  37. def py():
  38. for i in range(TIMES):
  39. flat = _flatten_dense_tensors(t)
  40. def cpp():
  41. for i in range(TIMES):
  42. flat = flatten(t)
  43. def apex():
  44. for i in range(TIMES):
  45. flat = flatten_apex(t)
  46. #### cProfile ####
  47. import cProfile
  48. def cprofileme():
  49. print("--------------- cProfile -----------------")
  50. print("py")
  51. cProfile.run("py()", sort=-1)
  52. gc.collect()
  53. torch.cuda.empty_cache()
  54. print("cpp")
  55. cProfile.run("cpp()", sort=-1)
  56. gc.collect()
  57. torch.cuda.empty_cache()
  58. print("apex")
  59. cProfile.run("apex()", sort=-1)
  60. gc.collect()
  61. torch.cuda.empty_cache()
  62. #### timeit ####
  63. import timeit
  64. def timeme():
  65. print("--------------- timeit -----------------")
  66. print(f'py ={timeit.Timer("py()", globals=globals()).timeit(number=1)}')
  67. gc.collect()
  68. torch.cuda.empty_cache()
  69. print(f'cpp ={timeit.Timer("cpp()", globals=globals()).timeit(number=1)}')
  70. gc.collect()
  71. torch.cuda.empty_cache()
  72. print(f'apex={timeit.Timer("apex()", globals=globals()).timeit(number=1)}')
  73. gc.collect()
  74. torch.cuda.empty_cache()
  75. #### line_profiler ####
  76. # this one requires a special way to be called
  77. # pip install line_profiler
  78. # kernprof -l flatten_bench.py -l; python -m line_profiler flatten_bench.py.lprof
  79. def line_profileme():
  80. print("--------------- line_profier -----------------")
  81. print("py")
  82. profile(py)()
  83. gc.collect()
  84. torch.cuda.empty_cache()
  85. print("cpp")
  86. profile(cpp)()
  87. gc.collect()
  88. torch.cuda.empty_cache()
  89. print("apex")
  90. profile(apex)()
  91. gc.collect()
  92. torch.cuda.empty_cache()
  93. if __name__ == "__main__":
  94. parser = argparse.ArgumentParser()
  95. parser.add_argument("-l", action='store_true')
  96. parser.add_argument("-c", action='store_true')
  97. parser.add_argument("-t", action='store_true')
  98. args = parser.parse_args()
  99. if args.l:
  100. line_profileme()
  101. elif args.c:
  102. cprofileme()
  103. elif args.t:
  104. timeme()