nvtx.py 575 B

1234567891011121314151617181920212223
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from deepspeed.accelerator import get_accelerator
  5. enable_nvtx = True
  6. def instrument_w_nvtx(func):
  7. """decorator that causes an NVTX range to be recorded for the duration of the
  8. function call."""
  9. def wrapped_fn(*args, **kwargs):
  10. if enable_nvtx:
  11. get_accelerator().range_push(func.__qualname__)
  12. ret_val = func(*args, **kwargs)
  13. if enable_nvtx:
  14. get_accelerator().range_pop()
  15. return ret_val
  16. return wrapped_fn