# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
  4. import torch
  5. from ..features.cuda_graph import CUDAGraph
  6. class DSUNet(CUDAGraph, torch.nn.Module):
  7. def __init__(self, unet, enable_cuda_graph=True):
  8. super().__init__(enable_cuda_graph=enable_cuda_graph)
  9. self.unet = unet
  10. # SD pipeline accesses this attribute
  11. self.in_channels = unet.in_channels
  12. self.device = self.unet.device
  13. self.dtype = self.unet.dtype
  14. self.config = self.unet.config
  15. self.fwd_count = 0
  16. self.unet.requires_grad_(requires_grad=False)
  17. self.unet.to(memory_format=torch.channels_last)
  18. self.cuda_graph_created = False
  19. def _graph_replay(self, *inputs, **kwargs):
  20. for i in range(len(inputs)):
  21. if torch.is_tensor(inputs[i]):
  22. self.static_inputs[i].copy_(inputs[i])
  23. for k in kwargs:
  24. if torch.is_tensor(kwargs[k]):
  25. self.static_kwargs[k].copy_(kwargs[k])
  26. self._cuda_graphs.replay()
  27. return self.static_output
  28. def forward(self, *inputs, **kwargs):
  29. if self.enable_cuda_graph:
  30. if self.cuda_graph_created:
  31. outputs = self._graph_replay(*inputs, **kwargs)
  32. else:
  33. self._create_cuda_graph(*inputs, **kwargs)
  34. outputs = self._graph_replay(*inputs, **kwargs)
  35. return outputs
  36. else:
  37. return self._forward(*inputs, **kwargs)
  38. def _create_cuda_graph(self, *inputs, **kwargs):
  39. # warmup to create the workspace and cublas handle
  40. cuda_stream = torch.cuda.Stream()
  41. cuda_stream.wait_stream(torch.cuda.current_stream())
  42. with torch.cuda.stream(cuda_stream):
  43. for i in range(3):
  44. ret = self._forward(*inputs, **kwargs)
  45. torch.cuda.current_stream().wait_stream(cuda_stream)
  46. # create cuda_graph and assign static_inputs and static_outputs
  47. self._cuda_graphs = torch.cuda.CUDAGraph()
  48. self.static_inputs = inputs
  49. self.static_kwargs = kwargs
  50. with torch.cuda.graph(self._cuda_graphs):
  51. self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)
  52. self.cuda_graph_created = True
  53. def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True, cross_attention_kwargs=None):
  54. if cross_attention_kwargs:
  55. return self.unet(sample,
  56. timestamp,
  57. encoder_hidden_states,
  58. return_dict,
  59. cross_attention_kwargs=cross_attention_kwargs)
  60. else:
  61. return self.unet(sample, timestamp, encoder_hidden_states, return_dict)