# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

#!/usr/bin/env python

# run the benchmark under timeit (-t), cProfile (-c), line_profiler (-l)
#
# usage:
# ./unflatten_bench.py -t
# ./unflatten_bench.py -c
# kernprof -l unflatten_bench.py -l; python -m line_profiler unflatten_bench.py.lprof
import argparse
import gc

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import UtilsBuilder

from apex_C import flatten as flatten_apex
from apex_C import unflatten as unflatten_apex
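
# UtilsBuilder().load() JIT-compiles (unless pre-built) and loads DeepSpeed's C++
# flatten/unflatten utility ops, which are benchmarked against the pure-Python and
# apex implementations below.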
util_ops = UtilsBuilder().load()
flatten = util_ops.flatten
unflatten = util_ops.unflatten

torch.manual_seed(0)

# emulate the weights of a small, typical model
x = [
    torch.rand((512, 512)).to(get_accelerator().device_name()),
    torch.rand((512, 1024)).to(get_accelerator().device_name()),
    torch.rand((512, 30000)).to(get_accelerator().device_name())
]
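# replicate the three tensors 30x (90 tensors in total) to mimic a model with many
# parameter tensors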
unflat_t = x * 30

# warm up and check that all implementations produce the same output
flat_py = _flatten_dense_tensors(unflat_t)
flat_cpp = flatten(unflat_t)
flat_apex = flatten_apex(unflat_t)
#numel = flat_cpp.numel()
assert torch.eq(flat_py, flat_cpp).all(), "py and cpp flatten should produce the same tensor"
assert torch.eq(flat_py, flat_apex).all(), "py and apex flatten should produce the same tensor"
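
# all three unflatten benchmarks consume the same flattened tensor (the Python result),
# so any timing differences come from the unflatten implementations themselves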
flat_t = flat_py

unflat_py = _unflatten_dense_tensors(flat_py, unflat_t)
for i in range(len(unflat_t)):
    assert torch.eq(unflat_t[i], unflat_py[i]).all()
# also exercise the C++ and apex unflatten kernels so all three code paths are warmed up
unflat_cpp = unflatten(flat_cpp, unflat_t)
for i in range(len(unflat_t)):
    assert torch.eq(unflat_t[i], unflat_cpp[i]).all()
unflat_apex = unflatten_apex(flat_apex, unflat_t)
for i in range(len(unflat_t)):
    assert torch.eq(unflat_t[i], unflat_apex[i]).all()


# the functions being benchmarked: each calls its unflatten implementation 1000 times
# on the same inputs
def py():
    for i in range(1000):
        unflat = _unflatten_dense_tensors(flat_t, unflat_t)


def cpp():
    for i in range(1000):
        unflat = unflatten(flat_t, unflat_t)


def apex():
    for i in range(1000):
        unflat = unflatten_apex(flat_t, unflat_t)
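
# each harness below calls gc.collect() and empty_cache() between implementations so
# that one run's allocations and cached accelerator memory don't skew the next measurement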

#### cProfile ####

import cProfile


def cprofileme():
    print("--------------- cProfile -----------------")
    print("py")
    cProfile.run("py()", sort=-1)
    gc.collect()
    get_accelerator().empty_cache()
    print("cpp")
    cProfile.run("cpp()", sort=-1)
    gc.collect()
    get_accelerator().empty_cache()
    print("apex")
    cProfile.run("apex()", sort=-1)
    gc.collect()
    get_accelerator().empty_cache()
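
# timeit: number=1 runs each 1000-iteration benchmark loop exactly once and reports
# elapsed wall-clock seconds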

#### timeit ####

import timeit


def timeme():
    print("--------------- timeit -----------------")
    print(f'py  ={timeit.Timer("py()", globals=globals()).timeit(number=1)}')
    gc.collect()
    get_accelerator().empty_cache()
    print(f'cpp ={timeit.Timer("cpp()", globals=globals()).timeit(number=1)}')
    gc.collect()
    get_accelerator().empty_cache()
    print(f'apex={timeit.Timer("apex()", globals=globals()).timeit(number=1)}')
    gc.collect()
    get_accelerator().empty_cache()

#### line_profiler ####
# this one requires a special way to be called
# pip install line_profiler
# kernprof -l unflatten_bench.py -l; python -m line_profiler unflatten_bench.py.lprof
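# `profile` is injected into builtins by kernprof when the script is run with -l,
# so the calls below resolve only under kernprof; the noqa/type-ignore markers
# silence undefined-name warnings for normal runs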


def line_profileme():
    print("--------------- line_profiler -----------------")
    print("py")
    profile(py)()  # noqa: F821 # type: ignore
    gc.collect()
    get_accelerator().empty_cache()
    print("cpp")
    profile(cpp)()  # noqa: F821 # type: ignore
    gc.collect()
    get_accelerator().empty_cache()
    print("apex")
    profile(apex)()  # noqa: F821 # type: ignore
    gc.collect()
    get_accelerator().empty_cache()
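
# entry point: pick exactly one of -l / -c / -t; the flags are checked in that order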

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-l", action='store_true', help="profile with line_profiler (run via kernprof -l)")
    parser.add_argument("-c", action='store_true', help="profile with cProfile")
    parser.add_argument("-t", action='store_true', help="time with timeit")
    args = parser.parse_args()
    if args.l:
        line_profileme()
    elif args.c:
        cprofileme()
    elif args.t:
        timeme()