#!/usr/bin/env python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Benchmark three implementations of flattening a list of dense tensors:
#   py   - torch._utils._flatten_dense_tensors
#   cpp  - DeepSpeed's UtilsBuilder op (util_ops.flatten)
#   apex - apex_C.flatten
#
# run the benchmark under timeit (-t), cProfile (-c), line_profiler (-l)
#
# usage:
# ./flatten_bench.py -t
# ./flatten_bench.py -c
# kernprof -l flatten_bench.py -l; python -m line_profiler flatten_bench.py.lprof

import argparse
import cProfile
import gc
import timeit

import torch
from torch._utils import _flatten_dense_tensors

from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import UtilsBuilder

from apex_C import flatten as flatten_apex

util_ops = UtilsBuilder().load()
flatten = util_ops.flatten
unflatten = util_ops.unflatten  # not used by this benchmark; kept as a module-level name

torch.manual_seed(0)

# emulate a small typical model weights
device = get_accelerator().device_name()
x = [
    torch.rand((512, 512)).to(device),
    torch.rand((512, 1024)).to(device),
    torch.rand((512, 30000)).to(device),
]
# 90 list entries: the same 3 tensor objects repeated 30 times (no copies are made).
t = x * 30

# warm up and check that the same output is produced.
# torch.equal compares shape as well as values (torch.eq(...).all() would broadcast).
flat_py = _flatten_dense_tensors(t)
flat_cpp = flatten(t)
flat_apex = flatten_apex(t)
assert torch.equal(flat_py, flat_cpp), "python and cpp flatten produce different tensors"
assert torch.equal(flat_py, flat_apex), "python and apex flatten produce different tensors"

TIMES = 1000

# the programs being tested
#
# Each one ends with an accelerator synchronize: on asynchronous backends the
# flatten calls may return before the device work is done, and without the
# sync the profilers/timers would only measure the launches.


def py():
    """Flatten ``t`` TIMES times with torch._utils._flatten_dense_tensors."""
    for _ in range(TIMES):
        _flatten_dense_tensors(t)
    get_accelerator().synchronize()


def cpp():
    """Flatten ``t`` TIMES times with DeepSpeed's util_ops.flatten."""
    for _ in range(TIMES):
        flatten(t)
    get_accelerator().synchronize()


def apex():
    """Flatten ``t`` TIMES times with apex_C.flatten."""
    for _ in range(TIMES):
        flatten_apex(t)
    get_accelerator().synchronize()


# (label, function) pairs, run in this order by every profiler driver below.
_BENCHMARKS = (("py", py), ("cpp", cpp), ("apex", apex))


def _release_memory():
    """Collect Python garbage and empty the accelerator cache between runs."""
    gc.collect()
    get_accelerator().empty_cache()


#### cProfile ####


def cprofileme():
    """Run each benchmark under cProfile and print its stats."""
    print("--------------- cProfile -----------------")
    for name, _ in _BENCHMARKS:
        print(name)
        # cProfile.run executes the statement in __main__'s namespace, which is
        # this module when the file is run as a script.
        cProfile.run(f"{name}()", sort=-1)
        _release_memory()


#### timeit ####


def timeme():
    """Time a single call of each benchmark with timeit and print the seconds."""
    print("--------------- timeit -----------------")
    for name, _ in _BENCHMARKS:
        elapsed = timeit.Timer(f"{name}()", globals=globals()).timeit(number=1)
        print(f"{name:<4}={elapsed}")
        _release_memory()


#### line_profiler ####
# this one requires a special way to be called
# pip install line_profiler
# kernprof -l flatten_bench.py -l; python -m line_profiler flatten_bench.py.lprof


def line_profileme():
    """Run each benchmark through line_profiler's ``profile`` decorator.

    ``profile`` is not defined in this file; it is expected to be injected into
    builtins by ``kernprof -l``, so this function only works when run that way.
    """
    print("--------------- line_profiler -----------------")
    for name, fn in _BENCHMARKS:
        print(name)
        profile(fn)()  # noqa: F821 # type: ignore
        _release_memory()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark py / cpp / apex tensor flattening.")
    parser.add_argument("-l", action='store_true', help="profile with line_profiler (run via kernprof -l)")
    parser.add_argument("-c", action='store_true', help="profile with cProfile")
    parser.add_argument("-t", action='store_true', help="time with timeit")
    args = parser.parse_args()

    # First matching flag wins, in the order -l, -c, -t.
    if args.l:
        line_profileme()
    elif args.c:
        cprofileme()
    elif args.t:
        timeme()
    else:
        parser.print_help()