# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

#!/usr/bin/env python
# Run the benchmark under timeit (-t), cProfile (-c), or line_profiler (-l).
#
# usage:
# ./flatten_bench.py -t
# ./flatten_bench.py -c
# kernprof -l flatten_bench.py -l; python -m line_profiler flatten_bench.py.lprof
- import argparse
- import gc
- import torch
- from torch._utils import _flatten_dense_tensors
- from deepspeed.accelerator import get_accelerator
- from deepspeed.ops.op_builder import UtilsBuilder
- from apex_C import flatten as flatten_apex
- util_ops = UtilsBuilder().load()
- flatten = util_ops.flatten
- unflatten = util_ops.unflatten
- torch.manual_seed(0)
- # emulate a small typical model weights
- x = [
- torch.rand((512, 512)).to(get_accelerator().device_name()),
- torch.rand((512, 1024)).to(get_accelerator().device_name()),
- torch.rand((512, 30000)).to(get_accelerator().device_name())
- ]
- t = x * 30
- # warm up and check that the same output is produced
- flat_py = _flatten_dense_tensors(t)
- flat_cpp = flatten(t)
- flat_apex = flatten_apex(t)
- #numel = flat_cpp.numel()
- assert torch.eq(flat_py, flat_cpp).all(), "both produce the same tensor"
- assert torch.eq(flat_py, flat_apex).all(), "both produce the same tensor"
- TIMES = 1000
- # the programs being tested
- def py():
- for i in range(TIMES):
- flat = _flatten_dense_tensors(t)
- def cpp():
- for i in range(TIMES):
- flat = flatten(t)
- def apex():
- for i in range(TIMES):
- flat = flatten_apex(t)
- #### cProfile ####
- import cProfile
- def cprofileme():
- print("--------------- cProfile -----------------")
- print("py")
- cProfile.run("py()", sort=-1)
- gc.collect()
- get_accelerator().empty_cache()
- print("cpp")
- cProfile.run("cpp()", sort=-1)
- gc.collect()
- get_accelerator().empty_cache()
- print("apex")
- cProfile.run("apex()", sort=-1)
- gc.collect()
- get_accelerator().empty_cache()
- #### timeit ####
- import timeit
- def timeme():
- print("--------------- timeit -----------------")
- print(f'py ={timeit.Timer("py()", globals=globals()).timeit(number=1)}')
- gc.collect()
- get_accelerator().empty_cache()
- print(f'cpp ={timeit.Timer("cpp()", globals=globals()).timeit(number=1)}')
- gc.collect()
- get_accelerator().empty_cache()
- print(f'apex={timeit.Timer("apex()", globals=globals()).timeit(number=1)}')
- gc.collect()
- get_accelerator().empty_cache()
- #### line_profiler ####
- # this one requires a special way to be called
- # pip install line_profiler
- # kernprof -l flatten_bench.py -l; python -m line_profiler flatten_bench.py.lprof
- def line_profileme():
- print("--------------- line_profiler -----------------")
- print("py")
- profile(py)() # noqa: F821 # type: ignore
- gc.collect()
- get_accelerator().empty_cache()
- print("cpp")
- profile(cpp)() # noqa: F821 # type: ignore
- gc.collect()
- get_accelerator().empty_cache()
- print("apex")
- profile(apex)() # noqa: F821 # type: ignore
- gc.collect()
- get_accelerator().empty_cache()
- if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("-l", action='store_true')
- parser.add_argument("-c", action='store_true')
- parser.add_argument("-t", action='store_true')
- args = parser.parse_args()
- if args.l:
- line_profileme()
- elif args.c:
- cprofileme()
- elif args.t:
- timeme()