gds.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. from .async_io import AsyncIOBuilder
  6. class GDSBuilder(AsyncIOBuilder):
  7. BUILD_VAR = "DS_BUILD_GDS"
  8. NAME = "gds"
  9. def __init__(self):
  10. super().__init__()
  11. def absolute_name(self):
  12. return f'deepspeed.ops.gds.{self.NAME}_op'
  13. def lib_sources(self):
  14. src_list = ['csrc/gds/py_lib/deepspeed_py_gds_handle.cpp', 'csrc/gds/py_lib/deepspeed_gds_op.cpp']
  15. return super().lib_sources() + src_list
  16. def sources(self):
  17. return self.lib_sources() + ['csrc/gds/py_lib/py_ds_gds.cpp']
  18. def cxx_args(self):
  19. return super().cxx_args() + ['-lcufile']
  20. def include_paths(self):
  21. import torch
  22. CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
  23. return ['csrc/aio/py_lib', 'csrc/aio/common'] + CUDA_INCLUDE
  24. def extra_ldflags(self):
  25. return super().extra_ldflags() + ['-lcufile']
  26. def is_compatible(self, verbose=False):
  27. if self.is_rocm_pytorch():
  28. if verbose:
  29. self.warning(f'{self.NAME} is not compatible with ROCM')
  30. return False
  31. try:
  32. import torch.utils.cpp_extension
  33. except ImportError:
  34. if verbose:
  35. self.warning("Please install torch if trying to pre-compile GDS")
  36. return False
  37. CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME
  38. if CUDA_HOME is None:
  39. if verbose:
  40. self.warning("Please install torch CUDA if trying to pre-compile GDS with CUDA")
  41. return False
  42. CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64")
  43. gds_compatible = self.has_function(funcname="cuFileDriverOpen",
  44. libraries=("cufile", ),
  45. library_dirs=(
  46. CUDA_HOME,
  47. CUDA_LIB64,
  48. ),
  49. verbose=verbose)
  50. return gds_compatible and super().is_compatible(verbose)