# cpu_lion.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
  4. import os
  5. from .builder import TorchCPUOpBuilder
  6. class CPULionBuilder(TorchCPUOpBuilder):
  7. BUILD_VAR = "DS_BUILD_CPU_LION"
  8. NAME = "cpu_lion"
  9. def __init__(self):
  10. super().__init__(name=self.NAME)
  11. def absolute_name(self):
  12. return f'deepspeed.ops.lion.{self.NAME}_op'
  13. def sources(self):
  14. if self.build_for_cpu:
  15. return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp']
  16. return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp', 'csrc/common/custom_cuda_kernel.cu']
  17. def libraries_args(self):
  18. args = super().libraries_args()
  19. if self.build_for_cpu:
  20. return args
  21. if not self.is_rocm_pytorch():
  22. args += ['curand']
  23. return args
  24. def include_paths(self):
  25. import torch
  26. if self.build_for_cpu:
  27. CUDA_INCLUDE = []
  28. elif not self.is_rocm_pytorch():
  29. CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
  30. else:
  31. CUDA_INCLUDE = [
  32. os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"),
  33. os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"),
  34. os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"),
  35. ]
  36. return ['csrc/includes'] + CUDA_INCLUDE