# flatten_bench.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. #!/usr/bin/env python
  5. # run the benchmark under timeit (-t), cProfile (-c), line_profiler (-l)
  6. #
  7. # usage:
  8. # ./flatten_bench.py -t
  9. # ./flatten_bench.py -c
  10. # kernprof -l flatten_bench.py -l; python -m line_profiler flatten_bench.py.lprof
  11. import argparse
  12. import gc
  13. import torch
  14. from torch._utils import _flatten_dense_tensors
  15. from deepspeed.accelerator import get_accelerator
  16. from deepspeed.ops.op_builder import UtilsBuilder
  17. from apex_C import flatten as flatten_apex
  18. util_ops = UtilsBuilder().load()
  19. flatten = util_ops.flatten
  20. unflatten = util_ops.unflatten
  21. torch.manual_seed(0)
  22. # emulate a small typical model weights
  23. x = [
  24. torch.rand((512, 512)).to(get_accelerator().device_name()),
  25. torch.rand((512, 1024)).to(get_accelerator().device_name()),
  26. torch.rand((512, 30000)).to(get_accelerator().device_name())
  27. ]
  28. t = x * 30
  29. # warm up and check that the same output is produced
  30. flat_py = _flatten_dense_tensors(t)
  31. flat_cpp = flatten(t)
  32. flat_apex = flatten_apex(t)
  33. #numel = flat_cpp.numel()
  34. assert torch.eq(flat_py, flat_cpp).all(), "both produce the same tensor"
  35. assert torch.eq(flat_py, flat_apex).all(), "both produce the same tensor"
  36. TIMES = 1000
  37. # the programs being tested
  38. def py():
  39. for i in range(TIMES):
  40. flat = _flatten_dense_tensors(t)
  41. def cpp():
  42. for i in range(TIMES):
  43. flat = flatten(t)
  44. def apex():
  45. for i in range(TIMES):
  46. flat = flatten_apex(t)
  47. #### cProfile ####
  48. import cProfile
  49. def cprofileme():
  50. print("--------------- cProfile -----------------")
  51. print("py")
  52. cProfile.run("py()", sort=-1)
  53. gc.collect()
  54. get_accelerator().empty_cache()
  55. print("cpp")
  56. cProfile.run("cpp()", sort=-1)
  57. gc.collect()
  58. get_accelerator().empty_cache()
  59. print("apex")
  60. cProfile.run("apex()", sort=-1)
  61. gc.collect()
  62. get_accelerator().empty_cache()
  63. #### timeit ####
  64. import timeit
  65. def timeme():
  66. print("--------------- timeit -----------------")
  67. print(f'py ={timeit.Timer("py()", globals=globals()).timeit(number=1)}')
  68. gc.collect()
  69. get_accelerator().empty_cache()
  70. print(f'cpp ={timeit.Timer("cpp()", globals=globals()).timeit(number=1)}')
  71. gc.collect()
  72. get_accelerator().empty_cache()
  73. print(f'apex={timeit.Timer("apex()", globals=globals()).timeit(number=1)}')
  74. gc.collect()
  75. get_accelerator().empty_cache()
  76. #### line_profiler ####
  77. # this one requires a special way to be called
  78. # pip install line_profiler
  79. # kernprof -l flatten_bench.py -l; python -m line_profiler flatten_bench.py.lprof
  80. def line_profileme():
  81. print("--------------- line_profiler -----------------")
  82. print("py")
  83. profile(py)() # noqa: F821 # type: ignore
  84. gc.collect()
  85. get_accelerator().empty_cache()
  86. print("cpp")
  87. profile(cpp)() # noqa: F821 # type: ignore
  88. gc.collect()
  89. get_accelerator().empty_cache()
  90. print("apex")
  91. profile(apex)() # noqa: F821 # type: ignore
  92. gc.collect()
  93. get_accelerator().empty_cache()
  94. if __name__ == "__main__":
  95. parser = argparse.ArgumentParser()
  96. parser.add_argument("-l", action='store_true')
  97. parser.add_argument("-c", action='store_true')
  98. parser.add_argument("-t", action='store_true')
  99. args = parser.parse_args()
  100. if args.l:
  101. line_profileme()
  102. elif args.c:
  103. cprofileme()
  104. elif args.t:
  105. timeme()