torch.py

'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
import torch

from deepspeed import utils

from .utils import *
from .backend import *
from .comm import *


class TorchBackend(Backend):
    """
    A light-weight wrapper class for the torch.distributed API.
    Only a subset of functions is wrapped. Once init_process_group
    has been called, standard torch.distributed.* functions can be used
    directly, so there is no need to wrap them all. We can keep adding
    wrappers as needed.
    """

    def __init__(self,
                 backend,
                 timeout,
                 init_method,
                 rank=-1,
                 world_size=-1,
                 name='torch'):
        super(TorchBackend, self).__init__()
        self.torch_version_before_18 = older_torch()
        self.has_allgather_base = has_allgather_base()
        self.has_reduce_scatter_base = has_reduce_scatter_base()
        self.initialized = True
        self.name = name
        # Future functionality to support ds.initialize() on a single GPU.
        # The idea is to fake that the dist backend is initialized even when
        # it is not, so we can run on a single GPU without any init_process_group.
        self.single_gpu_mode = True
        self.init_process_group(backend, timeout, init_method, rank, world_size)

    def init_process_group(self, backend, timeout, init_method, rank, world_size):
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend,
                                                 timeout=timeout,
                                                 init_method=init_method,
                                                 rank=rank,
                                                 world_size=world_size)
        self.using_mpi = torch.distributed.get_backend() == 'mpi'

    def all_reduce(self,
                   tensor,
                   op=torch.distributed.ReduceOp.SUM,
                   group=None,
                   async_op=False):
        op = self._reduce_op(op)
        return torch.distributed.all_reduce(tensor=tensor,
                                            op=op,
                                            group=group,
                                            async_op=async_op)

    def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
        return torch.distributed.reduce(tensor=tensor,
                                        dst=dst,
                                        op=self._reduce_op(op),
                                        group=group,
                                        async_op=async_op)

    def reduce_scatter(self,
                       output,
                       input_list,
                       op=ReduceOp.SUM,
                       group=None,
                       async_op=False):
        return torch.distributed.reduce_scatter(output=output,
                                                input_list=input_list,
                                                op=self._reduce_op(op),
                                                group=group,
                                                async_op=async_op)

    def broadcast(self, tensor, src, group=None, async_op=False):
        return torch.distributed.broadcast(tensor=tensor,
                                           src=src,
                                           group=group,
                                           async_op=async_op)

    def all_gather(self, tensor_list, tensor, group=None, async_op=False):
        return torch.distributed.all_gather(tensor_list=tensor_list,
                                            tensor=tensor,
                                            group=group,
                                            async_op=async_op)

    def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
        if self.has_allgather_base:
            return torch.distributed.distributed_c10d._all_gather_base(
                output_tensor=output_tensor,
                input_tensor=input_tensor,
                group=group,
                async_op=async_op)
        else:
            utils.logger.warning(
                "unable to find torch.distributed._all_gather_base. will fall back to "
                "torch.distributed.all_gather which will result in suboptimal performance. "
                "please consider upgrading your pytorch installation.")
            # No fallback is performed here; the call returns None and the caller
            # is expected to use torch.distributed.all_gather instead.
            pass

    def reduce_scatter_base(self,
                            output_tensor,
                            input_tensor,
                            op=ReduceOp.SUM,
                            group=None,
                            async_op=False):
        if self.has_reduce_scatter_base:
            return torch.distributed._reduce_scatter_base(output_tensor,
                                                          input_tensor,
                                                          op=self._reduce_op(op),
                                                          group=group,
                                                          async_op=async_op)
        else:
            utils.logger.warning(
                "unable to find torch.distributed._reduce_scatter_base. will fall back to "
                "torch.distributed.reduce_scatter which will result in suboptimal performance. "
                "please consider upgrading your pytorch installation.")
            # No fallback is performed here; the call returns None and the caller
            # is expected to use torch.distributed.reduce_scatter instead.
            pass

    def all_to_all_single(self,
                          output,
                          input,
                          output_split_sizes=None,
                          input_split_sizes=None,
                          group=None,
                          async_op=False):
        return torch.distributed.all_to_all_single(output=output,
                                                   input=input,
                                                   output_split_sizes=output_split_sizes,
                                                   input_split_sizes=input_split_sizes,
                                                   group=group,
                                                   async_op=async_op)

    def send(self, tensor, dst, group=None, tag=0):
        return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

    def recv(self, tensor, src=None, group=None, tag=0):
        return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

    def isend(self, tensor, dst, group=None, tag=0):
        return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

    def irecv(self, tensor, src=None, group=None, tag=0):
        return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

    def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
        return torch.distributed.gather(tensor=tensor,
                                        gather_list=gather_list,
                                        dst=dst,
                                        group=group,
                                        async_op=async_op)

    def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
        return torch.distributed.scatter(tensor=tensor,
                                         scatter_list=scatter_list,
                                         src=src,
                                         group=group,
                                         async_op=async_op)

    def barrier(self,
                group=torch.distributed.GroupMember.WORLD,
                async_op=False,
                device_ids=None):
        if group is None:
            group = torch.distributed.GroupMember.WORLD
        return torch.distributed.barrier(group=group,
                                         async_op=async_op,
                                         device_ids=device_ids)

    def monitored_barrier(self,
                          group=torch.distributed.GroupMember.WORLD,
                          timeout=None,
                          wait_all_ranks=False):
        if group is None:
            group = torch.distributed.GroupMember.WORLD
        return torch.distributed.monitored_barrier(group=group,
                                                   timeout=timeout,
                                                   wait_all_ranks=wait_all_ranks)

    def get_rank(self, group=None):
        return torch.distributed.get_rank(group=group)

    def get_world_size(self, group=None):
        return torch.distributed.get_world_size(group=group)

    def is_initialized(self):
        return torch.distributed.is_initialized()

    def get_backend(self, group=None):
        return torch.distributed.get_backend(group=group)

    def new_group(self, ranks):
        return torch.distributed.new_group(ranks)

    def get_global_rank(self, group, group_rank):
        # Newer torch exposes get_global_rank publicly; older versions only
        # provide the private _get_global_rank helper.
        if hasattr(torch.distributed.distributed_c10d, "get_global_rank"):
            from torch.distributed.distributed_c10d import get_global_rank as _get_global_rank
        else:
            from torch.distributed.distributed_c10d import _get_global_rank
        return _get_global_rank(group, group_rank)

    def get_world_group(self):
        return torch.distributed.group.WORLD

    def destroy_process_group(self, group=None):
        return torch.distributed.destroy_process_group(group=group)

    def _reduce_op(self, op):
        '''
        Helper function. If the op provided is not a torch.distributed.ReduceOp,
        convert it to the equivalent torch.distributed.ReduceOp and return it.
        '''
        if not isinstance(op, torch.distributed.ReduceOp):
            if op == ReduceOp.SUM:
                op = torch.distributed.ReduceOp.SUM
            elif op == ReduceOp.PRODUCT:
                op = torch.distributed.ReduceOp.PRODUCT
            elif op == ReduceOp.AVG:
                op = torch.distributed.ReduceOp.AVG
            elif op == ReduceOp.MIN:
                op = torch.distributed.ReduceOp.MIN
            elif op == ReduceOp.MAX:
                op = torch.distributed.ReduceOp.MAX
            elif op == ReduceOp.BAND:
                op = torch.distributed.ReduceOp.BAND
            elif op == ReduceOp.BOR:
                op = torch.distributed.ReduceOp.BOR
            elif op == ReduceOp.BXOR:
                op = torch.distributed.ReduceOp.BXOR
        return op
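        # For example (illustrative only): self._reduce_op(ReduceOp.MAX) returns
        # torch.distributed.ReduceOp.MAX, while an op that is already a
        # torch.distributed.ReduceOp instance is passed through unchanged.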


# This will become a light-weight wrapper around torch.distributed functions.
# TODO: create some example to show how this wrapper can help profile communication
# TODO: make sure there is no performance regression with this approach
# TODO: explore monkey-patching if this does not work
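
# The TODO above asks for an example of how this wrapper could help profile
# communication. The helper below is a minimal, illustrative sketch (not part of
# the original file): _time_collective is a hypothetical name, and it assumes
# the wrapped collectives run on CUDA tensors when a GPU is present.
import functools
import time


def _time_collective(fn):
    """Return fn wrapped so that each call logs its wall-clock duration."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if torch.cuda.is_available():
            # Make sure previously queued GPU work does not pollute the timing.
            torch.cuda.synchronize()
        start = time.time()
        result = fn(*args, **kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # Note: with async_op=True this times only call submission, not completion.
        utils.logger.info(f"{fn.__name__} took {time.time() - start:.6f} s")
        return result

    return wrapper

# Possible usage (hypothetical):
#   backend.all_reduce = _time_collective(backend.all_reduce)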