ds_bench 757 B

1234567891011121314151617
  1. #!/usr/bin/env python3
  2. from benchmarks.communication.run_all import main
  3. from benchmarks.communication.constants import *
  4. from benchmarks.communication.utils import *
  5. import os
  6. import sys
  7. # Run the same file with deepspeed launcher. This is required since setuptools will auto-detect python files and insert a python shebang for both 'scripts' and 'entry_points', and this benchmarks require the DS launcher
  8. required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
  9. if not all(map(lambda v: v in os.environ, required_env)):
  10. import subprocess
  11. subprocess.run("deepspeed $(which ds_bench) " + " ".join(sys.argv[1:]), shell=True)
  12. else:
  13. args = benchmark_parser().parse_args()
  14. rank = args.local_rank
  15. main(args, rank)