# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os

import torch

from deepspeed import utils

from .utils import *
from .backend import *
from .comm import *

DS_COMM_ALL_GATHER_OFF = False
DS_COMM_REDUCE_SCATTER_OFF = False
DS_COMM_BROADCAST_OFF = False
DS_COMM_ALL_REDUCE_OFF = False
DS_COMM_REDUCE_OFF = False


def is_torch_ver_eq_2_0():
    TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
    return TORCH_MAJOR == 2 and TORCH_MINOR == 0


def is_torch_ver_ge_2_1():
    TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
    # Compare (major, minor) as a tuple so that e.g. 3.0 is still treated as >= 2.1
    return (TORCH_MAJOR, TORCH_MINOR) >= (2, 1)


def torch_ver_ge_1_13():
    TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
    # Compare (major, minor) as a tuple so that e.g. 2.0 is still treated as >= 1.13
    return (TORCH_MAJOR, TORCH_MINOR) >= (1, 13)


def has_coalescing_manager():
    has_c10d = hasattr(torch.distributed, 'distributed_c10d')
    return has_c10d and hasattr(torch.distributed.distributed_c10d, '_coalescing_manager')


def has_all_reduce_coalesced():
    return hasattr(torch.distributed, "all_reduce_coalesced") and torch_ver_ge_1_13()


def get_coalescing_manager(group, device, reqs, async_op):
    if is_torch_ver_eq_2_0():
        return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, reqs=reqs)
    elif is_torch_ver_ge_2_1():
        return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, async_ops=async_op)
    else:
        return torch.distributed.distributed_c10d._coalescing_manager(group, reqs)


## Utilities to turn communication off
## TODO: move to base comm (wrapper)
def all_gather_comm_off(flag=False):
    global DS_COMM_ALL_GATHER_OFF
    DS_COMM_ALL_GATHER_OFF = flag


def reduce_scatter_comm_off(flag=False):
    global DS_COMM_REDUCE_SCATTER_OFF
    DS_COMM_REDUCE_SCATTER_OFF = flag


def broadcast_comm_off(flag=False):
    global DS_COMM_BROADCAST_OFF
    DS_COMM_BROADCAST_OFF = flag


def all_reduce_comm_off(flag=False):
    global DS_COMM_ALL_REDUCE_OFF
    DS_COMM_ALL_REDUCE_OFF = flag


def reduce_comm_off(flag=False):
    global DS_COMM_REDUCE_OFF
    DS_COMM_REDUCE_OFF = flag


# Assumption: all_gather and reduce_scatter are the collectives that matter for
# the backward pass, so backward_comm_off only toggles those two.
def backward_comm_off(flag=False):
    all_gather_comm_off(flag)
    reduce_scatter_comm_off(flag)
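
# Illustrative usage sketch (not part of the original file): when one of these
# flags is set, the corresponding TorchBackend methods below return a Noop
# handle instead of issuing the collective, e.g.
#
#   backward_comm_off(True)    # all_gather / reduce_scatter become no-ops
#   # ... run a step to measure compute-only time ...
#   backward_comm_off(False)   # restore normal communication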


class Noop:
    """Stand-in work handle returned when a collective has been turned off."""

    def wait(self):
        return None


class TorchBackend(Backend):
    """
    A light-weight wrapper class for the torch.distributed API.
    Only a subset of functions are wrapped. Once the process group is
    initialized, standard torch.distributed.* functions can be used directly,
    so there is no need to wrap everything. We can keep adding wrappers as
    needed.
    """

    def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
        super(TorchBackend, self).__init__()
        self.has_all_reduce_coalesced = has_all_reduce_coalesced()
        self.has_coalescing_manager = has_coalescing_manager()
        self.all_gather_function = self.get_all_gather_function()
        self.reduce_scatter_function = self.get_reduce_scatter_function()
        self.initialized = True
        self.name = name
        # Future functionality to support ds.initialize() on a single GPU
        # The idea is to fake that dist backend is initialized even when
        # it is not so we can run on a single GPU without doing any init_process_group
        self.single_gpu_mode = True
        self.init_process_group(backend, timeout, init_method, rank, world_size)

    @classmethod
    def get_all_gather_function(cls):
        if hasattr(torch.distributed, "all_gather_into_tensor"):
            return torch.distributed.all_gather_into_tensor
        elif hasattr(torch.distributed, "_all_gather_base"):
            return torch.distributed._all_gather_base
        return None

    @classmethod
    def get_reduce_scatter_function(cls):
        if hasattr(torch.distributed, "reduce_scatter_tensor"):
            return torch.distributed.reduce_scatter_tensor
        elif hasattr(torch.distributed, "_reduce_scatter_base"):
            return torch.distributed._reduce_scatter_base
        return None

    def has_all_gather_into_tensor(self):
        return self.all_gather_function is not None

    def has_reduce_scatter_tensor(self):
        return self.reduce_scatter_function is not None

    def init_process_group(self, backend, timeout, init_method, rank, world_size):
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend,
                                                 timeout=timeout,
                                                 init_method=init_method,
                                                 rank=rank,
                                                 world_size=world_size)
        self.using_mpi = torch.distributed.get_backend() == 'mpi'

    def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
        op = self._reduce_op(op)
        return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

    def inference_all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
        op = self._reduce_op(op)
        return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

    def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
        """ proxy func to torch.distributed.all_reduce_coalesced,
        which is included in PyTorch 1.13 and above
        """
        if not self.has_all_reduce_coalesced:
            raise RuntimeError(f"Current torch version does not have all_reduce_coalesced "
                               f"api (torch.__version__: {torch.__version__})")
        op = self._reduce_op(op)
        return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

    def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
        if DS_COMM_REDUCE_OFF:
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.warning("REDUCE is OFF")
            return Noop()
        return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

    def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
        if DS_COMM_REDUCE_SCATTER_OFF:
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.warning("REDUCE SCATTER is OFF")
            return Noop()
        else:
            return torch.distributed.reduce_scatter(output=output,
                                                    input_list=input_list,
                                                    op=self._reduce_op(op),
                                                    group=group,
                                                    async_op=async_op)

    def broadcast(self, tensor, src, group=None, async_op=False):
        if DS_COMM_BROADCAST_OFF:
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.warning("BROADCAST is OFF")
            return Noop()
        else:
            return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

    def all_gather(self, tensor_list, tensor, group=None, async_op=False):
        if DS_COMM_ALL_GATHER_OFF:
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.warning("All Gather is OFF")
            return Noop()
        else:
            return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

    def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
        if self.has_all_gather_into_tensor():
            return self.all_gather_function(output_tensor=output_tensor,
                                            input_tensor=input_tensor,
                                            group=group,
                                            async_op=async_op)

    def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
        if DS_COMM_ALL_GATHER_OFF:
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.warning("All Gather is OFF")
            return Noop()
        else:
            if hasattr(torch.distributed.distributed_c10d, "_all_gather_base"):
                return torch.distributed.distributed_c10d._all_gather_base(output_tensor=output_tensor,
                                                                           input_tensor=input_tensor,
                                                                           group=group,
                                                                           async_op=async_op)
            else:
                utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
                                     "torch.distributed.all_gather which will result in suboptimal performance. "
                                     "please consider upgrading your pytorch installation.")
                pass

    def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
        """Coalesced all_gather of a list of input tensors into a list of output tensors."""
        assert len(output_tensors) == len(input_tensors), \
            "output_tensors and input_tensors must have the same length"
        if hasattr(torch.distributed.distributed_c10d, '_all_gather_base_coalesced'):
            # customized PyTorch
            return torch.distributed.distributed_c10d._all_gather_base_coalesced(output_tensors,
                                                                                 input_tensors,
                                                                                 group=group,
                                                                                 async_op=async_op)
        elif has_coalescing_manager():
            reqs = []
            with get_coalescing_manager(group, input_tensors[0].device, reqs, async_op):
                for output, input in zip(output_tensors, input_tensors):
                    handle = torch.distributed.distributed_c10d.all_gather_into_tensor(output,
                                                                                       input,
                                                                                       group=group,
                                                                                       async_op=True)
                    reqs.append(handle)
            if async_op:
                return reqs[-1]
            else:
                reqs[-1].wait()

    def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
        if self.has_reduce_scatter_tensor():
            return self.reduce_scatter_function(output_tensor,
                                                input_tensor,
                                                op=self._reduce_op(op),
                                                group=group,
                                                async_op=async_op)
        else:
            utils.logger.warning("unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
                                 "torch.distributed.reduce_scatter which will result in suboptimal performance. "
                                 "please consider upgrading your pytorch installation.")
            pass

    def all_to_all_single(self,
                          output,
                          input,
                          output_split_sizes=None,
                          input_split_sizes=None,
                          group=None,
                          async_op=False):
        return torch.distributed.all_to_all_single(output=output,
                                                   input=input,
                                                   output_split_sizes=output_split_sizes,
                                                   input_split_sizes=input_split_sizes,
                                                   group=group,
                                                   async_op=async_op)

    def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
        return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

    def send(self, tensor, dst, group=None, tag=0):
        return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

    def recv(self, tensor, src=None, group=None, tag=0):
        return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

    def isend(self, tensor, dst, group=None, tag=0):
        return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

    def irecv(self, tensor, src=None, group=None, tag=0):
        return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

    def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
        return torch.distributed.gather(tensor=tensor,
                                        gather_list=gather_list,
                                        dst=dst,
                                        group=group,
                                        async_op=async_op)

    def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
        return torch.distributed.scatter(tensor=tensor,
                                         scatter_list=scatter_list,
                                         src=src,
                                         group=group,
                                         async_op=async_op)

    def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
        if group is None:
            group = torch.distributed.GroupMember.WORLD
        return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

    def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
        if group is None:
            group = torch.distributed.GroupMember.WORLD
        return torch.distributed.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)

    def get_rank(self, group=None):
        return torch.distributed.get_rank(group=group)

    def get_world_size(self, group=None):
        return torch.distributed.get_world_size(group=group)

    def is_initialized(self):
        return torch.distributed.is_initialized()

    def get_backend(self, group=None):
        return torch.distributed.get_backend(group=group)

    def new_group(self, ranks):
        return torch.distributed.new_group(ranks)

    def get_global_rank(self, group, group_rank):
        if hasattr(torch.distributed.distributed_c10d, "get_global_rank"):
            from torch.distributed.distributed_c10d import get_global_rank as _get_global_rank
        else:
            from torch.distributed.distributed_c10d import _get_global_rank
        return _get_global_rank(group, group_rank)

    def get_world_group(self):
        return torch.distributed.group.WORLD

    def destroy_process_group(self, group=None):
        return torch.distributed.destroy_process_group(group=group)

    def _reduce_op(self, op):
        '''
        Helper function. If the op provided is not a torch.dist.ReduceOp, convert it and return
        '''
        if not isinstance(op, torch.distributed.ReduceOp):
            if op == ReduceOp.SUM:
                op = torch.distributed.ReduceOp.SUM
            elif op == ReduceOp.PRODUCT:
                op = torch.distributed.ReduceOp.PRODUCT
            elif op == ReduceOp.AVG:
                op = torch.distributed.ReduceOp.AVG
            elif op == ReduceOp.MIN:
                op = torch.distributed.ReduceOp.MIN
            elif op == ReduceOp.MAX:
                op = torch.distributed.ReduceOp.MAX
            elif op == ReduceOp.BAND:
                op = torch.distributed.ReduceOp.BAND
            elif op == ReduceOp.BOR:
                op = torch.distributed.ReduceOp.BOR
            elif op == ReduceOp.BXOR:
                op = torch.distributed.ReduceOp.BXOR
        return op


# This will become a light-weight wrapper around torch.distributed functions
# TODO: create some example to show how this wrapper can help profile communication
# TODO: make sure there is no performance regression with this approach
# TODO: explore monkey-patching if this does not work
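
# Illustrative sketch for the first TODO above (not part of DeepSpeed): a wrapper
# like TorchBackend can help profile communication by timing the wrapped calls.
# The helper below is hypothetical and assumes a process group plus a CUDA tensor.
#
#   import time
#
#   def timed_all_reduce(backend, tensor, group=None):
#       """Return wall-clock seconds spent in one blocking all_reduce (sketch)."""
#       torch.cuda.synchronize()                 # assumes a CUDA tensor
#       start = time.perf_counter()
#       backend.all_reduce(tensor, group=group)  # async_op defaults to False
#       torch.cuda.synchronize()
#       return time.perf_counter() - start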