env_report.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. import torch
  6. import deepspeed
  7. import subprocess
  8. import argparse
  9. from .ops.op_builder.all_ops import ALL_OPS
  10. from .git_version_info import installed_ops, torch_info
  11. from deepspeed.accelerator import get_accelerator
  12. GREEN = '\033[92m'
  13. RED = '\033[91m'
  14. YELLOW = '\033[93m'
  15. END = '\033[0m'
  16. SUCCESS = f"{GREEN} [SUCCESS] {END}"
  17. OKAY = f"{GREEN}[OKAY]{END}"
  18. WARNING = f"{YELLOW}[WARNING]{END}"
  19. FAIL = f'{RED}[FAIL]{END}'
  20. INFO = '[INFO]'
  21. color_len = len(GREEN) + len(END)
  22. okay = f"{GREEN}[OKAY]{END}"
  23. warning = f"{YELLOW}[WARNING]{END}"
  24. def op_report(verbose=True):
  25. max_dots = 23
  26. max_dots2 = 11
  27. h = ["op name", "installed", "compatible"]
  28. print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
  29. print("DeepSpeed C++/CUDA extension op report")
  30. print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
  31. print("NOTE: Ops not installed will be just-in-time (JIT) compiled at\n"
  32. " runtime if needed. Op compatibility means that your system\n"
  33. " meet the required dependencies to JIT install the op.")
  34. print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
  35. print("JIT compiled ops requires ninja")
  36. ninja_status = OKAY if ninja_installed() else FAIL
  37. print('ninja', "." * (max_dots - 5), ninja_status)
  38. print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
  39. print(h[0], "." * (max_dots - len(h[0])), h[1], "." * (max_dots2 - len(h[1])), h[2])
  40. print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
  41. installed = f"{GREEN}[YES]{END}"
  42. no = f"{YELLOW}[NO]{END}"
  43. for op_name, builder in ALL_OPS.items():
  44. dots = "." * (max_dots - len(op_name))
  45. is_compatible = OKAY if builder.is_compatible(verbose) else no
  46. is_installed = installed if installed_ops.get(op_name, False) else no
  47. dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - (len(is_installed) - color_len))
  48. print(op_name, dots, is_installed, dots2, is_compatible)
  49. print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
  50. def ninja_installed():
  51. try:
  52. import ninja # noqa: F401 # type: ignore
  53. except ImportError:
  54. return False
  55. return True
  56. def nvcc_version():
  57. import torch.utils.cpp_extension
  58. cuda_home = torch.utils.cpp_extension.CUDA_HOME
  59. if cuda_home is None:
  60. return f"{RED} [FAIL] cannot find CUDA_HOME via torch.utils.cpp_extension.CUDA_HOME={torch.utils.cpp_extension.CUDA_HOME} {END}"
  61. try:
  62. output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
  63. except FileNotFoundError:
  64. return f"{RED} [FAIL] nvcc missing {END}"
  65. output_split = output.split()
  66. release_idx = output_split.index("release")
  67. release = output_split[release_idx + 1].replace(',', '').split(".")
  68. return ".".join(release)
  69. def get_shm_size():
  70. try:
  71. shm_stats = os.statvfs('/dev/shm')
  72. except (OSError, FileNotFoundError, ValueError):
  73. return "UNKNOWN", None
  74. shm_size = shm_stats.f_frsize * shm_stats.f_blocks
  75. shm_hbytes = human_readable_size(shm_size)
  76. warn = []
  77. if shm_size < 512 * 1024**2:
  78. warn.append(
  79. f" {YELLOW} [WARNING] /dev/shm size might be too small, if running in docker increase to at least --shm-size='1gb' {END}"
  80. )
  81. if get_accelerator().communication_backend_name() == "nccl":
  82. warn.append(
  83. f" {YELLOW} [WARNING] see more details about NCCL requirements: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#sharing-data {END}"
  84. )
  85. return shm_hbytes, warn
  86. def human_readable_size(size):
  87. units = ['B', 'KB', 'MB', 'GB', 'TB']
  88. i = 0
  89. while size >= 1024 and i < len(units) - 1:
  90. size /= 1024
  91. i += 1
  92. return f'{size:.2f} {units[i]}'
  93. def debug_report():
  94. max_dots = 33
  95. report = [("torch install path", torch.__path__), ("torch version", torch.__version__),
  96. ("deepspeed install path", deepspeed.__path__),
  97. ("deepspeed info", f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}")]
  98. if get_accelerator().device_name() == 'cuda':
  99. hip_version = getattr(torch.version, "hip", None)
  100. report.extend([("torch cuda version", torch.version.cuda), ("torch hip version", hip_version),
  101. ("nvcc version", (None if hip_version else nvcc_version())),
  102. ("deepspeed wheel compiled w.", f"torch {torch_info['version']}, " +
  103. (f"hip {torch_info['hip_version']}" if hip_version else f"cuda {torch_info['cuda_version']}"))
  104. ])
  105. else:
  106. report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']} ")])
  107. report.append(("shared memory (/dev/shm) size", get_shm_size()))
  108. print("DeepSpeed general environment info:")
  109. for name, value in report:
  110. warns = []
  111. if isinstance(value, tuple):
  112. value, warns = value
  113. print(name, "." * (max_dots - len(name)), value)
  114. if warns:
  115. for warn in warns:
  116. print(warn)
  117. def parse_arguments():
  118. parser = argparse.ArgumentParser()
  119. parser.add_argument('--hide_operator_status',
  120. action='store_true',
  121. help='Suppress display of installation and compatibility statuses of DeepSpeed operators. ')
  122. parser.add_argument('--hide_errors_and_warnings', action='store_true', help='Suppress warning and error messages.')
  123. args = parser.parse_args()
  124. return args
  125. def main(hide_operator_status=False, hide_errors_and_warnings=False):
  126. if not hide_operator_status:
  127. op_report(verbose=not hide_errors_and_warnings)
  128. debug_report()
  129. def cli_main():
  130. args = parse_arguments()
  131. main(hide_operator_status=args.hide_operator_status, hide_errors_and_warnings=args.hide_errors_and_warnings)
  132. if __name__ == "__main__":
  133. main()