123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- from torch.nn.parallel import DistributedDataParallel
- from torch.nn.parallel.distributed import _find_tensors
- import torch.optim
- import torch.utils.data
- import torch
- from packaging import version
def get_torch_version(torch_version=None):
    """Return the torch version with build-specific suffixes removed.

    The result is meant to be handed to ``version.parse`` for comparisons,
    so anything that is not part of the public version is stripped.

    Args:
        torch_version: Version string to normalise. Defaults to
            ``torch.__version__`` when omitted.

    Returns:
        The cleaned version string, e.g. ``"2.0.1"`` for ``"2.0.1+cu117"``
        or ``"2.0.1+cpu"``, and ``"1.13.0"`` for
        ``"1.13.0.dev20220101+cu117"``.
    """
    if torch_version is None:
        torch_version = torch.__version__
    # Everything after "+" is a local build label (cu117, cpu, rocm5.6,
    # git<hash>, ...). Dropping the whole label handles every build flavour
    # instead of only the CUDA ones.
    cleaned = str(torch_version).partition("+")[0]
    # Nightly builds: "1.13.0.dev20220101" -> "1.13.0."
    cleaned = cleaned.split("dev")[0]
    # Kept for strings that carry a "cu" tag without a "+" separator.
    cleaned = cleaned.split("cu")[0]
    # Remove the dot left behind by the "dev" split; also safe on "".
    return cleaned.rstrip(".")
-
class DDP(DistributedDataParallel):
    """
    DistributedDataParallel variant that calls the wrapped module's step methods.

    Instead of invoking ``self.module(...)``, each forward pass is routed to
    ``training_step``, ``test_step`` or ``validation_step`` of the wrapped
    module, selected from ``module.training`` and ``module.testing``.
    Which DDP code path surrounds that call depends on the installed torch
    version (see ``forward``).
    """

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        """Run one DDP forward pass through the matching step method.

        * torch < 1.11: ``_forward_before_1_11``.
        * 1.11 <= torch < 2.1: ``_forward_before_2_1``.
        * otherwise: the parent ``forward``, which reaches the step methods
          through the overridden ``_run_ddp_forward``.

        Returns:
            Whatever the selected step method returned (possibly rebuilt by
            the DDPSink pass-through on the 1.11-2.0 path).
        """
        # Parse the installed version once per call rather than once per branch.
        installed = version.parse(get_torch_version())
        if installed < version.parse("1.11"):
            return self._forward_before_1_11(inputs, kwargs)
        if installed < version.parse("2.1"):
            return self._forward_before_2_1(inputs, kwargs)
        # use _run_ddp_forward()
        return super().forward(*inputs, **kwargs)

    def _dispatch_step(self, *args, **kwargs):
        """Call training_step, test_step or validation_step on the wrapped module."""
        wrapped = self.module
        if wrapped.training:
            return wrapped.training_step(*args, **kwargs)
        if wrapped.testing:
            return wrapped.test_step(*args, **kwargs)
        return wrapped.validation_step(*args, **kwargs)

    def _forward_before_1_11(self, inputs, kwargs):
        """Forward path used when torch is older than 1.11 (single device only)."""
        self._sync_params()
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        assert len(self.device_ids) == 1
        output = self._dispatch_step(*inputs[0], **kwargs[0])
        if torch.is_grad_enabled():
            # We'll return the output object verbatim since it is a freeform
            # object. We need to find any tensors in this object, though,
            # because we need to figure out which parameters were used during
            # this forward pass, to ensure we short circuit reduction for any
            # unused parameters. Only if `find_unused_parameters` is set.
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        return output

    def _forward_before_2_1(self, inputs, kwargs):
        """Forward path used for torch >= 1.11 and < 2.1.

        NOTE(review): this appears to follow the upstream
        ``DistributedDataParallel.forward`` of those releases, with the module
        call replaced by the step dispatch - confirm against the torch
        versions actually supported.
        """
        # These private names are imported lazily because this path only runs
        # on the torch versions selected in ``forward``.
        from torch.nn.parallel.distributed import \
            logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref

        with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
            if torch.is_grad_enabled() and self.require_backward_grad_sync:
                self.logger.set_runtime_stats_and_log()
                self.num_iterations += 1
                self.reducer.prepare_for_forward()

            # Notify the join context that this process has not joined, if
            # needed
            work = Join.notify_join_context(self)
            if work:
                self.reducer._set_forward_pass_work_handle(
                    work, self._divide_by_initial_world_size
                )

            # Calling _rebuild_buckets before forward computation,
            # It may allocate new buckets before deallocating old buckets
            # inside _rebuild_buckets. To save peak memory usage,
            # call _rebuild_buckets before the peak memory usage increases
            # during forward computation.
            # This should be called only once during whole training period.
            if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                logging.info("Reducer buckets have been rebuilt in this iteration.")
                self._has_rebuilt_buckets = True

            # sync params according to location (before/after forward) user
            # specified as part of hook, if hook was specified.
            if self._check_sync_bufs_pre_fwd():
                self._sync_buffers()

            if self._join_config.enable:
                # Notify joined ranks whether they should sync in backwards pass or not.
                self._check_global_requires_backward_grad_sync(is_joined_rank=False)

            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            output = self._dispatch_step(*inputs[0], **kwargs[0])

            # sync params according to location (before/after forward) user
            # specified as part of hook, if hook was specified.
            if self._check_sync_bufs_post_fwd():
                self._sync_buffers()

            if torch.is_grad_enabled() and self.require_backward_grad_sync:
                self.require_forward_param_sync = True
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters and not self.static_graph:
                    # Do not need to populate this for static graph.
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
            else:
                self.require_forward_param_sync = False

        # NOTE(review): the original indentation was lost; this section is
        # placed after the profiler block - confirm that matches the intent.
        # TODO: DDPSink is currently enabled for unused parameter detection and
        # static graph training for first iteration.
        if (self.find_unused_parameters and not self.static_graph) or (
            self.static_graph and self.num_iterations == 1
        ):
            state_dict = {
                'static_graph': self.static_graph,
                'num_iterations': self.num_iterations,
            }

            output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
                output
            )
            output_placeholders = [None] * len(output_tensor_list)
            # Do not touch tensors that have no grad_fn, which can cause issues
            # such as https://github.com/pytorch/pytorch/issues/60733
            for i, candidate in enumerate(output_tensor_list):
                if torch.is_tensor(candidate) and candidate.grad_fn is None:
                    output_placeholders[i] = candidate

            # When find_unused_parameters=True, makes tensors which require grad
            # run through the DDPSink backward pass. When not all outputs are
            # used in loss, this makes those corresponding tensors receive
            # undefined gradient which the reducer then handles to ensure
            # param.grad field is not touched and we don't error out.
            passthrough_tensor_list = _DDPSink.apply(
                self.reducer,
                state_dict,
                *output_tensor_list,
            )
            # Fill every slot that was not pinned above with the pass-through tensor.
            for i, placeholder in enumerate(output_placeholders):
                if placeholder is None:
                    output_placeholders[i] = passthrough_tensor_list[i]

            # Reconstruct output data structure.
            output = _tree_unflatten_with_rref(
                output_placeholders, treespec, output_is_rref
            )
        return output

    def _run_ddp_forward(self, *inputs, **kwargs):
        """Module-call hook: on torch >= 2.1 run the step dispatch instead.

        For torch < 2.1 this defers to the parent implementation. Otherwise
        the step method is invoked inside ``self._inside_ddp_forward()``.
        """
        if version.parse(get_torch_version()) < version.parse("2.1"):
            return super()._run_ddp_forward(*inputs, **kwargs)
        with self._inside_ddp_forward():
            output = self._dispatch_step(*inputs, **kwargs)
        return output
|