nvtx.py 499 B

12345678910111213141516171819
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from deepspeed.accelerator import get_accelerator
  5. def instrument_w_nvtx(func):
  6. """decorator that causes an NVTX range to be recorded for the duration of the
  7. function call."""
  8. def wrapped_fn(*args, **kwargs):
  9. get_accelerator().range_push(func.__qualname__)
  10. ret_val = func(*args, **kwargs)
  11. get_accelerator().range_pop()
  12. return ret_val
  13. return wrapped_fn