check-torchdist.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. '''Copyright The Microsoft DeepSpeed Team'''
  4. """
  5. Checks each file in sys.argv for the string "torch.distributed".
  6. Modified from https://github.com/jlebar/pre-commit-hooks/blob/master/check_do_not_submit.py
  7. """
  8. import subprocess
  9. import sys
  10. def err(s: str) -> None:
  11. print(s, file=sys.stderr)
  12. # There are many ways we could search for the string "torch.distributed", but `git
  13. # grep --no-index` is nice because
  14. # - it's very fast (as compared to iterating over the file in Python)
  15. # - we can reasonably assume it's available on all machines
  16. # - unlike plain grep, which is slower and has different flags on MacOS versus
  17. # Linux, git grep is always the same.
  18. res = subprocess.run(
  19. ["git",
  20. "grep",
  21. "-Hn",
  22. "--no-index",
  23. r"torch\.distributed",
  24. *sys.argv[1:]],
  25. capture_output=True,
  26. )
  27. if res.returncode == 0:
  28. err('Error: The string "torch.distributed" was found. Please replace all calls to torch.distributed with "deepspeed.comm"'
  29. )
  30. err(res.stdout.decode("utf-8"))
  31. sys.exit(1)
  32. elif res.returncode == 2:
  33. err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
  34. err(res.stderr.decode("utf-8"))
  35. sys.exit(2)