cpu_adam.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. from .builder import TorchCPUOpBuilder
  6. class CPUAdamBuilder(TorchCPUOpBuilder):
  7. BUILD_VAR = "DS_BUILD_CPU_ADAM"
  8. NAME = "cpu_adam"
  9. def __init__(self):
  10. super().__init__(name=self.NAME)
  11. def absolute_name(self):
  12. return f'deepspeed.ops.adam.{self.NAME}_op'
  13. def sources(self):
  14. if self.build_for_cpu:
  15. return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']
  16. return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp', 'csrc/common/custom_cuda_kernel.cu']
  17. def libraries_args(self):
  18. args = super().libraries_args()
  19. if self.build_for_cpu:
  20. return args
  21. if not self.is_rocm_pytorch():
  22. args += ['curand']
  23. return args
  24. def include_paths(self):
  25. import torch
  26. if self.build_for_cpu:
  27. CUDA_INCLUDE = []
  28. elif not self.is_rocm_pytorch():
  29. CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
  30. else:
  31. CUDA_INCLUDE = []
  32. return ['csrc/includes'] + CUDA_INCLUDE