common.py
import os
import time

import torch
import torch.distributed as dist
from torch.multiprocessing import Process

import pytest

# Worker timeout *after* the first worker has completed.
DEEPSPEED_UNIT_WORKER_TIMEOUT = 120

def distributed_test(world_size=2, backend='nccl'):
    """A decorator for executing a function (e.g., a unit test) in a distributed manner.
    This decorator manages the spawning and joining of processes, initialization of
    torch.distributed, and catching of errors.

    Usage example:
        @distributed_test(world_size=[2, 3])
        def my_test():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            assert rank < world_size

    Arguments:
        world_size (int or list): number of ranks to spawn. Can be a list to spawn
            multiple tests.
        backend (str): torch.distributed backend passed to init_process_group
            (e.g. 'nccl' or 'gloo').
    """
    def dist_wrap(run_func):
        """Second-level decorator for dist_test. This actually wraps the function. """
        def dist_init(local_rank, num_procs, *func_args, **func_kwargs):
            """Initialize torch.distributed and execute the user function. """
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = '29500'
            dist.init_process_group(backend=backend,
                                    init_method='env://',
                                    rank=local_rank,
                                    world_size=num_procs)

            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)

            run_func(*func_args, **func_kwargs)
        def dist_launcher(num_procs, *func_args, **func_kwargs):
            """Launch processes and gracefully handle failures. """

            # Spawn all workers on subprocesses.
            processes = []
            for local_rank in range(num_procs):
                p = Process(target=dist_init,
                            args=(local_rank,
                                  num_procs,
                                  *func_args),
                            kwargs=func_kwargs)
                p.start()
                processes.append(p)

            # Now loop and wait for a test to complete. The spin-wait here isn't a big
            # deal because the number of processes will be O(#GPUs) << O(#CPUs).
            any_done = False
            while not any_done:
                for p in processes:
                    if not p.is_alive():
                        any_done = True
                        break

            # Wait for all other processes to complete
            for p in processes:
                p.join(DEEPSPEED_UNIT_WORKER_TIMEOUT)

            failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
            for rank, p in failed:
                # If it still hasn't terminated, kill it because it hung.
                if p.exitcode is None:
                    p.terminate()
                    pytest.fail(f'Worker {rank} hung.', pytrace=False)
                if p.exitcode < 0:
                    pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}',
                                pytrace=False)
                if p.exitcode > 0:
                    pytest.fail(f'Worker {rank} exited with code {p.exitcode}',
                                pytrace=False)
        def run_func_decorator(*func_args, **func_kwargs):
            """Entry point for @distributed_test(). """
            if isinstance(world_size, int):
                dist_launcher(world_size, *func_args, **func_kwargs)
            elif isinstance(world_size, list):
                for procs in world_size:
                    dist_launcher(procs, *func_args, **func_kwargs)
                    time.sleep(0.5)
            else:
                raise TypeError('world_size must be an integer or a list of integers.')

        return run_func_decorator

    return dist_wrap
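
# --- Illustrative usage (a minimal sketch, not part of the original helper) ---
# Shows how a test module might apply @distributed_test. The function name and
# the choice of the 'gloo' backend are assumptions made here so the sketch can
# run on CPU-only machines; real tests typically keep the default 'nccl'
# backend and require GPUs. The name intentionally avoids the `test_` prefix
# so pytest does not collect it automatically.
@distributed_test(world_size=2, backend='gloo')
def example_allreduce_check():
    """Each rank contributes its rank id; the all-reduced sum must match."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    x = torch.tensor([float(rank)])
    dist.all_reduce(x)  # defaults to ReduceOp.SUM across all ranks

    # Expected value is 0 + 1 + ... + (world_size - 1).
    assert x.item() == world_size * (world_size - 1) / 2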