vae.py

'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch

from ..features.cuda_graph import CUDAGraph


class DSVAE(CUDAGraph, torch.nn.Module):

    def __init__(self, vae, enable_cuda_graph=True):
        super().__init__(enable_cuda_graph=enable_cuda_graph)
        self.vae = vae
        self.device = self.vae.device
        self.dtype = self.vae.dtype
        self.vae.requires_grad_(requires_grad=False)
        self.decoder_cuda_graph_created = False
        self.encoder_cuda_graph_created = False
        self.all_cuda_graph_created = False
    def _graph_replay_decoder(self, *inputs, **kwargs):
        # Copy the new tensor arguments into the static buffers captured by the
        # graph, then replay the recorded decoder kernels.
        for i in range(len(inputs)):
            if torch.is_tensor(inputs[i]):
                self.static_decoder_inputs[i].copy_(inputs[i])
        for k in kwargs:
            if torch.is_tensor(kwargs[k]):
                self.static_decoder_kwargs[k].copy_(kwargs[k])
        self._decoder_cuda_graph.replay()
        return self.static_decoder_output

    def _decode(self, x, return_dict=True):
        return self.vae.decode(x, return_dict=return_dict)

    def _create_cuda_graph_decoder(self, *inputs, **kwargs):
        # warmup to create the workspace and cublas handle
        cuda_stream = torch.cuda.Stream()
        cuda_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(cuda_stream):
            for i in range(3):
                ret = self._decode(*inputs, **kwargs)
        torch.cuda.current_stream().wait_stream(cuda_stream)

        # create cuda_graph and assign static_inputs and static_outputs
        self._decoder_cuda_graph = torch.cuda.CUDAGraph()
        self.static_decoder_inputs = inputs
        self.static_decoder_kwargs = kwargs

        with torch.cuda.graph(self._decoder_cuda_graph):
            self.static_decoder_output = self._decode(*self.static_decoder_inputs,
                                                      **self.static_decoder_kwargs)

        self.decoder_cuda_graph_created = True

    def decode(self, *inputs, **kwargs):
        if self.enable_cuda_graph:
            # The first call captures the decoder graph; subsequent calls only
            # copy inputs into the static buffers and replay it.
            if self.decoder_cuda_graph_created:
                outputs = self._graph_replay_decoder(*inputs, **kwargs)
            else:
                self._create_cuda_graph_decoder(*inputs, **kwargs)
                outputs = self._graph_replay_decoder(*inputs, **kwargs)
            return outputs
        else:
            return self._decode(*inputs, **kwargs)
    def _graph_replay_encoder(self, *inputs, **kwargs):
        for i in range(len(inputs)):
            if torch.is_tensor(inputs[i]):
                self.static_encoder_inputs[i].copy_(inputs[i])
        for k in kwargs:
            if torch.is_tensor(kwargs[k]):
                self.static_encoder_kwargs[k].copy_(kwargs[k])
        self._encoder_cuda_graph.replay()
        return self.static_encoder_output

    def _encode(self, x, return_dict=True):
        return self.vae.encode(x, return_dict=return_dict)

    def _create_cuda_graph_encoder(self, *inputs, **kwargs):
        # warmup to create the workspace and cublas handle
        cuda_stream = torch.cuda.Stream()
        cuda_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(cuda_stream):
            for i in range(3):
                ret = self._encode(*inputs, **kwargs)
        torch.cuda.current_stream().wait_stream(cuda_stream)

        # create cuda_graph and assign static_inputs and static_outputs
        self._encoder_cuda_graph = torch.cuda.CUDAGraph()
        self.static_encoder_inputs = inputs
        self.static_encoder_kwargs = kwargs

        with torch.cuda.graph(self._encoder_cuda_graph):
            self.static_encoder_output = self._encode(*self.static_encoder_inputs,
                                                      **self.static_encoder_kwargs)

        self.encoder_cuda_graph_created = True

    def encode(self, *inputs, **kwargs):
        if self.enable_cuda_graph:
            if self.encoder_cuda_graph_created:
                outputs = self._graph_replay_encoder(*inputs, **kwargs)
            else:
                self._create_cuda_graph_encoder(*inputs, **kwargs)
                outputs = self._graph_replay_encoder(*inputs, **kwargs)
            return outputs
        else:
            return self._encode(*inputs, **kwargs)
    def _graph_replay(self, *inputs, **kwargs):
        for i in range(len(inputs)):
            if torch.is_tensor(inputs[i]):
                self.static_inputs[i].copy_(inputs[i])
        for k in kwargs:
            if torch.is_tensor(kwargs[k]):
                self.static_kwargs[k].copy_(kwargs[k])
        self._all_cuda_graph.replay()
        return self.static_output

    def forward(self, *inputs, **kwargs):
        if self.enable_cuda_graph:
            if self.all_cuda_graph_created:
                outputs = self._graph_replay(*inputs, **kwargs)
            else:
                self._create_cuda_graph(*inputs, **kwargs)
                outputs = self._graph_replay(*inputs, **kwargs)
            return outputs
        else:
            return self._forward(*inputs, **kwargs)
    def _create_cuda_graph(self, *inputs, **kwargs):
        # warmup to create the workspace and cublas handle
        cuda_stream = torch.cuda.Stream()
        cuda_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(cuda_stream):
            for i in range(3):
                ret = self._forward(*inputs, **kwargs)
        torch.cuda.current_stream().wait_stream(cuda_stream)

        # create cuda_graph and assign static_inputs and static_outputs
        self._all_cuda_graph = torch.cuda.CUDAGraph()
        self.static_inputs = inputs
        self.static_kwargs = kwargs

        with torch.cuda.graph(self._all_cuda_graph):
            self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)

        self.all_cuda_graph_created = True

    def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True):
        return self.vae(sample, timestamp, encoder_hidden_states, return_dict)
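

# Usage sketch (illustrative; not part of the upstream file). A minimal example of
# driving this wrapper directly, assuming a diffusers AutoencoderKL checkpoint and a
# CUDA device; the model id, latent shape, and dtype below are assumptions for
# illustration only. In practice the wrapper is typically constructed by DeepSpeed's
# inference engine when it replaces the `vae` module of a Stable Diffusion pipeline.
#
#     from diffusers import AutoencoderKL
#
#     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").half().to("cuda")
#     ds_vae = DSVAE(vae, enable_cuda_graph=True)
#     latents = torch.randn(1, 4, 64, 64, dtype=ds_vae.dtype, device=ds_vae.device)
#
#     # The first decode() captures the decoder CUDA graph; later calls with tensors
#     # of the same shapes only copy into the static buffers and replay the graph.
#     image = ds_vae.decode(latents).sample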