# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed import utils
from .utils import *
from .backend import *
from .comm import *
import os

DS_COMM_ALL_GATHER_OFF = False
DS_COMM_REDUCE_SCATTER_OFF = False
DS_COMM_BROADCAST_OFF = False
DS_COMM_ALL_REDUCE_OFF = False
DS_COMM_REDUCE_OFF = False
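# The DS_COMM_*_OFF flags above let individual collectives be turned into no-ops
# for debugging and communication-cost experiments: when a flag is set, the call
# returns a Noop handle instead of performing communication. They are toggled by
# the *_comm_off helpers defined below.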


def is_torch_two():
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    if TORCH_MAJOR >= 2:
        return True
    else:
        return False


def torch_ver_ge_1_13():
    if is_torch_two():
        return True
    else:
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        assert TORCH_MAJOR == 1
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        return TORCH_MINOR >= 13


def has_coalescing_manager():
    has_c10d = hasattr(torch.distributed, 'distributed_c10d')
    return has_c10d and hasattr(torch.distributed.distributed_c10d, '_coalescing_manager')


def has_all_reduce_coalesced():
    return hasattr(torch.distributed, "all_reduce_coalesced") and torch_ver_ge_1_13()


def get_coalescing_manager(group, device, reqs):
    if is_torch_two():
        return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, reqs=reqs)
    else:
        return torch.distributed.distributed_c10d._coalescing_manager(group, reqs)
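# Note: the two branches above cover the different private `_coalescing_manager`
# signatures — torch 2.x takes the device and request list as keyword arguments,
# while older releases take the request list positionally.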


## Utilities to turn comm off
## TODO: move to base comm (wrapper)
def all_gather_comm_off(flag=False):
    global DS_COMM_ALL_GATHER_OFF
    DS_COMM_ALL_GATHER_OFF = flag


def reduce_scatter_comm_off(flag=False):
    global DS_COMM_REDUCE_SCATTER_OFF
    DS_COMM_REDUCE_SCATTER_OFF = flag


def broadcast_comm_off(flag=False):
    global DS_COMM_BROADCAST_OFF
    DS_COMM_BROADCAST_OFF = flag


def all_reduce_comm_off(flag=False):
    global DS_COMM_ALL_REDUCE_OFF
    DS_COMM_ALL_REDUCE_OFF = flag


def reduce_comm_off(flag=False):
    global DS_COMM_REDUCE_OFF
    DS_COMM_REDUCE_OFF = flag


# Assumption: all_gather and reduce_scatter are what we care about
def backward_comm_off(flag=False):
    all_gather_comm_off(flag)
    reduce_scatter_comm_off(flag)
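
# Illustrative usage sketch (not part of the module): skipping backward-pass
# collectives to isolate compute time during a debugging run. The training-step
# call below is a hypothetical user function, not a real API.
#
#   backward_comm_off(True)     # all_gather / reduce_scatter become no-ops
#   run_training_step()         # hypothetical user function
#   backward_comm_off(False)    # restore normal communication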


class Noop:

    def wait(self):
        return None
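
# Noop mimics the work handle returned by async collectives: call sites that expect
# an object with .wait() keep working when a collective has been disabled via the
# DS_COMM_*_OFF flags and a Noop is returned instead.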


class TorchBackend(Backend):
    """
    A light-weight wrapper class for the torch.distributed API.
    Only a subset of functions are wrapped. Once init_process_group
    has been called, the standard torch.distributed.* APIs can be used
    directly, so there is no need to wrap every function. We can keep
    adding wrappers as needed.
    """

    def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
        super(TorchBackend, self).__init__()
        self.has_all_reduce_coalesced = has_all_reduce_coalesced()
        self.has_coalescing_manager = has_coalescing_manager()
        self.all_gather_function = self.get_all_gather_function()
        self.reduce_scatter_function = self.get_reduce_scatter_function()
        self.initialized = True
        self.name = name
        # Future functionality to support ds.initialize() on a single GPU.
        # The idea is to fake that the dist backend is initialized even when
        # it is not, so we can run on a single GPU without calling init_process_group.
        self.single_gpu_mode = True
        self.init_process_group(backend, timeout, init_method, rank, world_size)

    @classmethod
    def get_all_gather_function(self):
        if hasattr(torch.distributed, "all_gather_into_tensor"):
            return torch.distributed.all_gather_into_tensor
        elif hasattr(torch.distributed, "_all_gather_base"):
            return torch.distributed._all_gather_base
        return None

    @classmethod
    def get_reduce_scatter_function(self):
        if hasattr(torch.distributed, "reduce_scatter_tensor"):
            return torch.distributed.reduce_scatter_tensor
        elif hasattr(torch.distributed, "_reduce_scatter_base"):
            return torch.distributed._reduce_scatter_base
        return None
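
    # Both lookups above prefer the public tensor-based collectives
    # (all_gather_into_tensor / reduce_scatter_tensor) and fall back to the older
    # private _all_gather_base / _reduce_scatter_base variants when the installed
    # torch only provides those; None means neither is available.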

    def has_all_gather_into_tensor(self):
        return self.all_gather_function is not None

    def has_reduce_scatter_tensor(self):
        return self.reduce_scatter_function is not None

    def init_process_group(self, backend, timeout, init_method, rank, world_size):
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend,
                                                 timeout=timeout,
                                                 init_method=init_method,
                                                 rank=rank,
                                                 world_size=world_size)
        self.using_mpi = torch.distributed.get_backend() == 'mpi'
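        # Note: `using_mpi` records whether the initialized process group ended up
        # on the MPI backend, so callers can special-case MPI behavior if needed.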

    def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
        op = self._reduce_op(op)
        return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

    def inference_all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
        op = self._reduce_op(op)
        return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

    def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
        """ proxy func to torch.distributed.all_reduce_coalesced,
        which is included in PyTorch 1.13 and above
        """
        if not self.has_all_reduce_coalesced:
            raise RuntimeError(f"Current torch version does not have all_reduce_coalesced "
                               f"api (torch.__version__: {torch.__version__})")
        op = self._reduce_op(op)
        return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

    def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
        if DS_COMM_REDUCE_OFF:
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.warning("REDUCE is OFF")
            return Noop()
        return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

    def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
        if DS_COMM_REDUCE_SCATTER_OFF:
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.warning("REDUCE SCATTER is OFF")
            return Noop()
        else:
            return torch.distributed.reduce_scatter(output=output,
                                                    input_list=input_list,
                                                    op=self._reduce_op(op),
                                                    group=group,
                                                    async_op=async_op)

    def broadcast(self, tensor, src, group=None, async_op=False):
        if DS_COMM_BROADCAST_OFF:
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.warning("BROADCAST is OFF")
            return Noop()
        else:
            return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

    def all_gather(self, tensor_list, tensor, group=None, async_op=False):
        if DS_COMM_ALL_GATHER_OFF:
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.warning("All Gather is OFF")
            return Noop()
        else:
            return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

    def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
        if self.has_all_gather_into_tensor():
            return self.all_gather_function(output_tensor=output_tensor,
                                            input_tensor=input_tensor,
                                            group=group,
                                            async_op=async_op)

    def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
        if DS_COMM_ALL_GATHER_OFF:
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.warning("All Gather is OFF")
            return Noop()
        else:
            if hasattr(torch.distributed.distributed_c10d, '_all_gather_base'):
                return torch.distributed.distributed_c10d._all_gather_base(output_tensor=output_tensor,
                                                                           input_tensor=input_tensor,
                                                                           group=group,
                                                                           async_op=async_op)
            else:
                utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
                                     "torch.distributed.all_gather which will result in suboptimal performance. "
                                     "please consider upgrading your pytorch installation.")
                pass

    def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
        """All-gather a list of tensors, coalescing the communication where the
        installed torch version supports it."""
        assert len(output_tensors) == len(input_tensors), "output_tensors and input_tensors must have the same length"
        if hasattr(torch.distributed.distributed_c10d, '_all_gather_base_coalesced'):
            # customized PyTorch
            return torch.distributed.distributed_c10d._all_gather_base_coalesced(output_tensors,
                                                                                 input_tensors,
                                                                                 group=group,
                                                                                 async_op=async_op)
        elif has_coalescing_manager():
            reqs = []
            with get_coalescing_manager(group, input_tensors[0].device, reqs):
                for output, input in zip(output_tensors, input_tensors):
                    handle = torch.distributed.distributed_c10d.all_gather_into_tensor(output,
                                                                                       input,
                                                                                       group=group,
                                                                                       async_op=True)
                    reqs.append(handle)
            if async_op:
                return reqs[-1]
            else:
                reqs[-1].wait()

    def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
        if self.has_reduce_scatter_tensor():
            return self.reduce_scatter_function(output_tensor,
                                                input_tensor,
                                                op=self._reduce_op(op),
                                                group=group,
                                                async_op=async_op)
        else:
            utils.logger.warning("unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
                                 "torch.distributed.reduce_scatter which will result in suboptimal performance. "
                                 "please consider upgrading your pytorch installation.")
            pass

    def all_to_all_single(self,
                          output,
                          input,
                          output_split_sizes=None,
                          input_split_sizes=None,
                          group=None,
                          async_op=False):
        return torch.distributed.all_to_all_single(output=output,
                                                   input=input,
                                                   output_split_sizes=output_split_sizes,
                                                   input_split_sizes=input_split_sizes,
                                                   group=group,
                                                   async_op=async_op)

    def send(self, tensor, dst, group=None, tag=0):
        return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

    def recv(self, tensor, src=None, group=None, tag=0):
        return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

    def isend(self, tensor, dst, group=None, tag=0):
        return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

    def irecv(self, tensor, src=None, group=None, tag=0):
        return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

    def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
        return torch.distributed.gather(tensor=tensor,
                                        gather_list=gather_list,
                                        dst=dst,
                                        group=group,
                                        async_op=async_op)

    def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
        return torch.distributed.scatter(tensor=tensor,
                                         scatter_list=scatter_list,
                                         src=src,
                                         group=group,
                                         async_op=async_op)

    def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
        if group is None:
            group = torch.distributed.GroupMember.WORLD
        return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

    def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
        if group is None:
            group = torch.distributed.GroupMember.WORLD
        return torch.distributed.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)

    def get_rank(self, group=None):
        return torch.distributed.get_rank(group=group)

    def get_world_size(self, group=None):
        return torch.distributed.get_world_size(group=group)

    def is_initialized(self):
        return torch.distributed.is_initialized()

    def get_backend(self, group=None):
        return torch.distributed.get_backend(group=group)

    def new_group(self, ranks):
        return torch.distributed.new_group(ranks)

    def get_global_rank(self, group, group_rank):
        if hasattr(torch.distributed.distributed_c10d, "get_global_rank"):
            from torch.distributed.distributed_c10d import get_global_rank as _get_global_rank
        else:
            from torch.distributed.distributed_c10d import _get_global_rank
        return _get_global_rank(group, group_rank)

    def get_world_group(self):
        return torch.distributed.group.WORLD

    def destroy_process_group(self, group=None):
        return torch.distributed.destroy_process_group(group=group)

    def _reduce_op(self, op):
        '''
            Helper function. If the op provided is not a torch.dist.ReduceOp, convert it and return
        '''
        if not isinstance(op, torch.distributed.ReduceOp):
            if op == ReduceOp.SUM:
                op = torch.distributed.ReduceOp.SUM
            elif op == ReduceOp.PRODUCT:
                op = torch.distributed.ReduceOp.PRODUCT
            elif op == ReduceOp.AVG:
                op = torch.distributed.ReduceOp.AVG
            elif op == ReduceOp.MIN:
                op = torch.distributed.ReduceOp.MIN
            elif op == ReduceOp.MAX:
                op = torch.distributed.ReduceOp.MAX
            elif op == ReduceOp.BAND:
                op = torch.distributed.ReduceOp.BAND
            elif op == ReduceOp.BOR:
                op = torch.distributed.ReduceOp.BOR
            elif op == ReduceOp.BXOR:
                op = torch.distributed.ReduceOp.BXOR
        return op


# This will become a light-weight wrapper around torch.distributed functions
# TODO: create some example to show how this wrapper can help profile communication
# TODO: make sure there is no performance regression with this approach
# TODO: explore monkey-patching if this does not work
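
# A minimal usage sketch (illustrative only, not part of the module): constructing
# the wrapper directly and issuing a collective. In practice this class is typically
# created via deepspeed.comm rather than by user code, and `my_tensor` below is a
# placeholder.
#
#   import datetime
#   backend = TorchBackend(backend='nccl',
#                          timeout=datetime.timedelta(minutes=30),
#                          init_method='env://')
#   backend.all_reduce(my_tensor)   # defaults to ReduceOp.SUM
#   backend.barrier()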