cpu_adam.py 798 B

123456789101112131415161718192021222324252627282930313233
# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .builder import CPUOpBuilder
  6. class CPUAdamBuilder(CPUOpBuilder):
  7. BUILD_VAR = "DS_BUILD_CPU_ADAM"
  8. NAME = "cpu_adam"
  9. def __init__(self):
  10. super().__init__(name=self.NAME)
  11. def absolute_name(self):
  12. return f'deepspeed.ops.adam.{self.NAME}_op'
  13. def sources(self):
  14. return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']
  15. def cxx_args(self):
  16. args = super().cxx_args()
  17. args += ['-DENABLE_BFLOAT16']
  18. return args
  19. def libraries_args(self):
  20. args = super().libraries_args()
  21. return args
  22. def include_paths(self):
  23. return ['csrc/includes']