random_ltd.py 879 B

12345678910111213141516171819202122232425262728293031323334
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from .builder import CUDAOpBuilder
  5. class RandomLTDBuilder(CUDAOpBuilder):
  6. BUILD_VAR = "DS_BUILD_RANDOM_LTD"
  7. NAME = "random_ltd"
  8. def __init__(self, name=None):
  9. name = self.NAME if name is None else name
  10. super().__init__(name=name)
  11. def absolute_name(self):
  12. return f'deepspeed.ops.{self.NAME}_op'
  13. def extra_ldflags(self):
  14. if not self.is_rocm_pytorch():
  15. return ['-lcurand']
  16. else:
  17. return []
  18. def sources(self):
  19. return [
  20. 'csrc/random_ltd/pt_binding.cpp', 'csrc/random_ltd/gather_scatter.cu',
  21. 'csrc/random_ltd/slice_attn_masks.cu', 'csrc/random_ltd/token_sort.cu'
  22. ]
  23. def include_paths(self):
  24. includes = ['csrc/includes']
  25. return includes