random_ltd.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
  4. from .builder import CUDAOpBuilder
  5. class RandomLTDBuilder(CUDAOpBuilder):
  6. BUILD_VAR = "DS_BUILD_RANDOM_LTD"
  7. NAME = "random_ltd"
  8. def __init__(self, name=None):
  9. name = self.NAME if name is None else name
  10. super().__init__(name=name)
  11. def absolute_name(self):
  12. return f'deepspeed.ops.{self.NAME}_op'
  13. def extra_ldflags(self):
  14. if not self.is_rocm_pytorch():
  15. return ['-lcurand']
  16. else:
  17. return []
  18. def sources(self):
  19. return [
  20. 'csrc/random_ltd/pt_binding.cpp',
  21. 'csrc/random_ltd/gather_scatter.cu',
  22. 'csrc/random_ltd/slice_attn_masks.cu',
  23. 'csrc/random_ltd/token_sort.cu'
  24. ]
  25. def include_paths(self):
  26. includes = ['csrc/includes']
  27. if self.is_rocm_pytorch():
  28. from torch.utils.cpp_extension import ROCM_HOME
  29. includes += [
  30. '{}/hiprand/include'.format(ROCM_HOME),
  31. '{}/rocrand/include'.format(ROCM_HOME)
  32. ]
  33. return includes