ddp_utils.py

from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
import torch.optim
import torch.utils.data
import torch
from packaging import version


def get_torch_version():
    torch_version = torch.__version__
    torch_version = torch_version.split("dev")[0]
    torch_version = torch_version.split("cu")[0]
    if torch_version[-1] == '.':
        torch_version = torch_version[:-1]
    torch_version = torch_version.replace("+", "")
    return torch_version
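
# Example (illustrative): get_torch_version() reduces "2.0.1+cu118" to "2.0.1"
# and "1.13.0.dev20221001" to "1.13.0", so the result can be compared with
# packaging.version.parse() in DDP.forward below.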


class DDP(DistributedDataParallel):
    """
    Override the forward call (as in Lightning) so it is dispatched to the
    module's training, test, or validation step respectively.
    """

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        torch_version = get_torch_version()
        if version.parse(torch_version) < version.parse("1.11"):
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            assert len(self.device_ids) == 1
            if self.module.training:
                output = self.module.training_step(*inputs[0], **kwargs[0])
            elif self.module.testing:
                output = self.module.test_step(*inputs[0], **kwargs[0])
            else:
                output = self.module.validation_step(*inputs[0], **kwargs[0])
            if torch.is_grad_enabled():
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        elif version.parse(torch_version) < version.parse("2.1"):
            from torch.nn.parallel.distributed import (
                logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref)
            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context that this process has not joined, if needed
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(
                        work, self._divide_by_initial_world_size
                    )

                # Calling _rebuild_buckets before forward computation,
                # It may allocate new buckets before deallocating old buckets
                # inside _rebuild_buckets. To save peak memory usage,
                # call _rebuild_buckets before the peak memory usage increases
                # during forward computation.
                # This should be called only once during whole training period.
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    logging.info("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                buffer_hook_registered = hasattr(self, 'buffer_hook')
                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    # Notify joined ranks whether they should sync in backwards pass or not.
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # We'll return the output object verbatim since it is a freeform
                    # object. We need to find any tensors in this object, though,
                    # because we need to figure out which parameters were used during
                    # this forward pass, to ensure we short circuit reduction for any
                    # unused parameters. Only if `find_unused_parameters` is set.
                    if self.find_unused_parameters and not self.static_graph:
                        # Do not need to populate this for static graph.
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

                # TODO: DDPSink is currently enabled for unused parameter detection and
                # static graph training for first iteration.
                if (self.find_unused_parameters and not self.static_graph) or (
                    self.static_graph and self.num_iterations == 1
                ):
                    state_dict = {
                        'static_graph': self.static_graph,
                        'num_iterations': self.num_iterations,
                    }

                    output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
                        output
                    )
                    output_placeholders = [None for _ in range(len(output_tensor_list))]
                    # Do not touch tensors that have no grad_fn, which can cause issues
                    # such as https://github.com/pytorch/pytorch/issues/60733
                    for i, output in enumerate(output_tensor_list):
                        if torch.is_tensor(output) and output.grad_fn is None:
                            output_placeholders[i] = output

                    # When find_unused_parameters=True, makes tensors which require grad
                    # run through the DDPSink backward pass. When not all outputs are
                    # used in loss, this makes those corresponding tensors receive
                    # undefined gradient which the reducer then handles to ensure
                    # param.grad field is not touched and we don't error out.
                    passthrough_tensor_list = _DDPSink.apply(
                        self.reducer,
                        state_dict,
                        *output_tensor_list,
                    )
                    for i in range(len(output_placeholders)):
                        if output_placeholders[i] is None:
                            output_placeholders[i] = passthrough_tensor_list[i]

                    # Reconstruct output data structure.
                    output = _tree_unflatten_with_rref(
                        output_placeholders, treespec, output_is_rref
                    )
        else:
            output = super().forward(*inputs, **kwargs)  # use _run_ddp_forward()
        return output
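
    # Note: for torch >= 2.1 the base-class forward() used above typically
    # routes the module call through _run_ddp_forward, so the override below
    # keeps the same training/validation/test dispatch on newer versions.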

    def _run_ddp_forward(self, *inputs, **kwargs):
        torch_version = get_torch_version()
        if version.parse(torch_version) < version.parse("2.1"):
            return super()._run_ddp_forward(*inputs, **kwargs)
        with self._inside_ddp_forward():
            if self.module.training:
                output = self.module.training_step(*inputs, **kwargs)
            elif self.module.testing:
                output = self.module.test_step(*inputs, **kwargs)
            else:
                output = self.module.validation_step(*inputs, **kwargs)
            return output
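

# Usage sketch (illustrative only; ``MyModule``, ``rank``, ``world_size`` and
# ``batch`` are hypothetical, not part of this file). Assumes the usual
# torch.distributed process-group setup and a module that defines
# training_step / validation_step / test_step plus a boolean ``testing`` flag:
#
#     import torch.distributed as dist
#     dist.init_process_group("nccl", rank=rank, world_size=world_size)
#     model = DDP(MyModule().to(rank), device_ids=[rank])
#     model.train()
#     loss = model(batch)   # dispatches to MyModule.training_step(batch)
#     loss.backward()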