vae.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from ..features.cuda_graph import CUDAGraph


class DSVAE(CUDAGraph, torch.nn.Module):

    def __init__(self, vae, enable_cuda_graph=True):
        super().__init__(enable_cuda_graph=enable_cuda_graph)
        self.vae = vae
        self.config = vae.config
        self.device = self.vae.device
        self.dtype = self.vae.dtype
        # Inference-only wrapper: freeze the wrapped VAE's parameters.
        self.vae.requires_grad_(requires_grad=False)
        self.decoder_cuda_graph_created = False
        self.encoder_cuda_graph_created = False
        self.all_cuda_graph_created = False

    def _graph_replay_decoder(self, *inputs, **kwargs):
        # Copy the new tensor arguments into the captured static buffers,
        # then replay the recorded decoder graph.
        for i in range(len(inputs)):
            if torch.is_tensor(inputs[i]):
                self.static_decoder_inputs[i].copy_(inputs[i])
        for k in kwargs:
            if torch.is_tensor(kwargs[k]):
                self.static_decoder_kwargs[k].copy_(kwargs[k])
        self._decoder_cuda_graph.replay()
        return self.static_decoder_output

    def _decode(self, x, return_dict=True):
        return self.vae.decode(x, return_dict=return_dict)

    def _create_cuda_graph_decoder(self, *inputs, **kwargs):
        # warmup to create the workspace and cublas handle
        cuda_stream = torch.cuda.Stream()
        cuda_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(cuda_stream):
            for i in range(3):
                ret = self._decode(*inputs, **kwargs)
        torch.cuda.current_stream().wait_stream(cuda_stream)

        # create cuda_graph and assign static_inputs and static_outputs
        self._decoder_cuda_graph = torch.cuda.CUDAGraph()
        self.static_decoder_inputs = inputs
        self.static_decoder_kwargs = kwargs

        with torch.cuda.graph(self._decoder_cuda_graph):
            self.static_decoder_output = self._decode(*self.static_decoder_inputs, **self.static_decoder_kwargs)

        self.decoder_cuda_graph_created = True

    def decode(self, *inputs, **kwargs):
        if self.enable_cuda_graph:
            if self.decoder_cuda_graph_created:
                outputs = self._graph_replay_decoder(*inputs, **kwargs)
            else:
                self._create_cuda_graph_decoder(*inputs, **kwargs)
                outputs = self._graph_replay_decoder(*inputs, **kwargs)
            return outputs
        else:
            return self._decode(*inputs, **kwargs)

    def _graph_replay_encoder(self, *inputs, **kwargs):
        for i in range(len(inputs)):
            if torch.is_tensor(inputs[i]):
                self.static_encoder_inputs[i].copy_(inputs[i])
        for k in kwargs:
            if torch.is_tensor(kwargs[k]):
                self.static_encoder_kwargs[k].copy_(kwargs[k])
        self._encoder_cuda_graph.replay()
        return self.static_encoder_output

    def _encode(self, x, return_dict=True):
        return self.vae.encode(x, return_dict=return_dict)

    def _create_cuda_graph_encoder(self, *inputs, **kwargs):
        # warmup to create the workspace and cublas handle
        cuda_stream = torch.cuda.Stream()
        cuda_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(cuda_stream):
            for i in range(3):
                ret = self._encode(*inputs, **kwargs)
        torch.cuda.current_stream().wait_stream(cuda_stream)

        # create cuda_graph and assign static_inputs and static_outputs
        self._encoder_cuda_graph = torch.cuda.CUDAGraph()
        self.static_encoder_inputs = inputs
        self.static_encoder_kwargs = kwargs

        with torch.cuda.graph(self._encoder_cuda_graph):
            self.static_encoder_output = self._encode(*self.static_encoder_inputs, **self.static_encoder_kwargs)

        self.encoder_cuda_graph_created = True

    def encode(self, *inputs, **kwargs):
        if self.enable_cuda_graph:
            if self.encoder_cuda_graph_created:
                outputs = self._graph_replay_encoder(*inputs, **kwargs)
            else:
                self._create_cuda_graph_encoder(*inputs, **kwargs)
                outputs = self._graph_replay_encoder(*inputs, **kwargs)
            return outputs
        else:
            return self._encode(*inputs, **kwargs)

    def _graph_replay(self, *inputs, **kwargs):
        for i in range(len(inputs)):
            if torch.is_tensor(inputs[i]):
                self.static_inputs[i].copy_(inputs[i])
        for k in kwargs:
            if torch.is_tensor(kwargs[k]):
                self.static_kwargs[k].copy_(kwargs[k])
        self._all_cuda_graph.replay()
        return self.static_output

    def forward(self, *inputs, **kwargs):
        if self.enable_cuda_graph:
            # Check the flag set in __init__ and _create_cuda_graph; the original
            # referenced an undefined self.cuda_graph_created attribute here.
            if self.all_cuda_graph_created:
                outputs = self._graph_replay(*inputs, **kwargs)
            else:
                self._create_cuda_graph(*inputs, **kwargs)
                outputs = self._graph_replay(*inputs, **kwargs)
            return outputs
        else:
            return self._forward(*inputs, **kwargs)

    def _create_cuda_graph(self, *inputs, **kwargs):
        # warmup to create the workspace and cublas handle
        cuda_stream = torch.cuda.Stream()
        cuda_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(cuda_stream):
            for i in range(3):
                ret = self._forward(*inputs, **kwargs)
        torch.cuda.current_stream().wait_stream(cuda_stream)

        # create cuda_graph and assign static_inputs and static_outputs
        self._all_cuda_graph = torch.cuda.CUDAGraph()
        self.static_inputs = inputs
        self.static_kwargs = kwargs

        with torch.cuda.graph(self._all_cuda_graph):
            self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)

        self.all_cuda_graph_created = True

    def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True):
        return self.vae(sample, timestamp, encoder_hidden_states, return_dict)
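
# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the class above). It assumes
# the `diffusers` package, the "runwayml/stable-diffusion-v1-5" checkpoint and
# a CUDA device; with enable_cuda_graph=True the first decode() call runs the
# warmup passes and captures the decoder graph, and later calls with
# identically shaped inputs only copy into the static buffers and replay it.
if __name__ == "__main__":
    from diffusers import AutoencoderKL  # assumed dependency for this sketch

    assert torch.cuda.is_available(), "CUDA graphs require a CUDA device"

    hf_vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5",
                                           subfolder="vae").to("cuda", torch.float16)
    ds_vae = DSVAE(hf_vae, enable_cuda_graph=True)

    # Dummy latents with the shape the SD 1.5 VAE expects (B, 4, H/8, W/8).
    latents = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)

    image = ds_vae.decode(latents).sample                    # warmup + capture
    image = ds_vae.decode(torch.randn_like(latents)).sample  # graph replay only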