check-torchcuda.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. #!/usr/bin/env python3
  2. # Copyright (c) Microsoft Corporation.
  3. # SPDX-License-Identifier: Apache-2.0
  4. # DeepSpeed Team
  5. from __future__ import annotations
  6. '''Copyright The Microsoft DeepSpeed Team'''
  7. """
  8. Checks each file in sys.argv for the strings "torch.cuda", ".cuda()" and ".is_cuda" (.cpp files are exempt from the ".is_cuda" check).
  9. Modified from https://github.com/jlebar/pre-commit-hooks/blob/master/check_do_not_submit.py
  10. """
  11. import subprocess
  12. import sys
  13. def err(s: str) -> None:
  14. print(s, file=sys.stderr)
  15. print(*sys.argv[1:])
  16. # There are many ways we could search for the string "torch.cuda", but `git
  17. # grep --no-index` is nice because
  18. # - it's very fast (as compared to iterating over the file in Python)
  19. # - we can reasonably assume it's available on all machines
  20. # - unlike plain grep, which is slower and has different flags on MacOS versus
  21. # Linux, git grep is always the same.
  22. res = subprocess.run(
  23. ["git", "grep", "-Hn", "--no-index", "-e", r"torch\.cuda", "--and", "--not", "-e", "#ignore-cuda", *sys.argv[1:]],
  24. capture_output=True,
  25. )
  26. if res.returncode == 0:
  27. err('Error: The string "torch.cuda" was found.\nPlease replace all calls to torch.cuda with "get_accelerator()" and add the following import line:\n\n from deepspeed.accelerator import get_accelerator\n\nIf your code is mean to be cuda specific, please add the following comment in the line with torch.cuda:\n\n #ignore-cuda\n'
  28. )
  29. err(res.stdout.decode("utf-8"))
  30. sys.exit(1)
  31. elif res.returncode == 2:
  32. err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
  33. err(res.stderr.decode("utf-8"))
  34. sys.exit(2)
  35. res = subprocess.run(
  36. ["git", "grep", "-Hn", "--no-index", r"\.cuda()", *sys.argv[1:]],
  37. capture_output=True,
  38. )
  39. if res.returncode == 0:
  40. err('Error: The string ".cuda()" was found. This implies convert a tensor to cuda tensor. Please replace all calls to tensor.cuda() with "tensor.to(get_accelerator().device_name())" and add the following import line:\nfrom deepspeed.accelerator import get_accelerator'
  41. )
  42. err(res.stdout.decode("utf-8"))
  43. sys.exit(1)
  44. elif res.returncode == 2:
  45. err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
  46. err(res.stderr.decode("utf-8"))
  47. sys.exit(2)
  48. files = []
  49. for file in sys.argv[1:]:
  50. if not file.endswith(".cpp"):
  51. files.append(file)
  52. res = subprocess.run(
  53. ["git", "grep", "-Hn", "--no-index", r"\.is_cuda", *files],
  54. capture_output=True,
  55. )
  56. if res.returncode == 0:
  57. err('''
  58. Error: The string ".is_cuda" was found. This implies checking if a tensor is a cuda tensor.
  59. Please replace all calls to "tensor.is_cuda" with "get_accelerator().on_accelerator(tensor)",
  60. and add the following import line:
  61. 'from deepspeed.accelerator import get_accelerator'
  62. ''')
  63. err(res.stdout.decode("utf-8"))
  64. sys.exit(1)
  65. elif res.returncode == 2:
  66. err(f"Error invoking grep on {', '.join(files)}:")
  67. err(res.stderr.decode("utf-8"))
  68. sys.exit(2)