nvtx.py 448 B

123456789101112131415
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. from deepspeed.accelerator import get_accelerator
  3. def instrument_w_nvtx(func):
  4. """decorator that causes an NVTX range to be recorded for the duration of the
  5. function call."""
  6. def wrapped_fn(*args, **kwargs):
  7. get_accelerator().range_push(func.__qualname__)
  8. ret_val = func(*args, **kwargs)
  9. get_accelerator().range_pop()
  10. return ret_val
  11. return wrapped_fn