cpu_adagrad.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
  4. import os
  5. from .builder import TorchCPUOpBuilder
  6. class CPUAdagradBuilder(TorchCPUOpBuilder):
  7. BUILD_VAR = "DS_BUILD_CPU_ADAGRAD"
  8. NAME = "cpu_adagrad"
  9. def __init__(self):
  10. super().__init__(name=self.NAME)
  11. def absolute_name(self):
  12. return f'deepspeed.ops.adagrad.{self.NAME}_op'
  13. def sources(self):
  14. if self.build_for_cpu:
  15. return ['csrc/adagrad/cpu_adagrad.cpp']
  16. return ['csrc/adagrad/cpu_adagrad.cpp', 'csrc/common/custom_cuda_kernel.cu']
  17. def libraries_args(self):
  18. args = super().libraries_args()
  19. if self.build_for_cpu:
  20. return args
  21. if not self.is_rocm_pytorch():
  22. args += ['curand']
  23. return args
  24. def include_paths(self):
  25. import torch
  26. if self.build_for_cpu:
  27. CUDA_INCLUDE = []
  28. elif not self.is_rocm_pytorch():
  29. CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
  30. else:
  31. CUDA_INCLUDE = []
  32. return ['csrc/includes'] + CUDA_INCLUDE