clip_encoder.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from deepspeed.accelerator import get_accelerator
  6. from ..features.cuda_graph import CUDAGraph
  class DSClipEncoder(CUDAGraph, torch.nn.Module):
      """Wrap a CLIP encoder and optionally execute it through captured CUDA graphs.

      The wrapper keeps two independent graph "slots" (captured graph, static
      inputs, static kwargs, static output). ``self.iter`` selects the active
      slot and is advanced modulo 2 after every graph-mode ``forward`` call, so
      consecutive calls alternate between the slots.
      NOTE(review): presumably two slots exist because the encoder is invoked
      twice per step with differently shaped inputs -- confirm against callers.
      """

      def __init__(self, enc, enable_cuda_graph=False):
          """Store the wrapped encoder and initialise the per-slot graph state.

          Args:
              enc: Encoder to wrap. It must expose ``text_model``, ``device``,
                  ``dtype`` and ``config`` attributes and be callable.
              enable_cuda_graph: Forwarded to the parent initializer; when
                  truthy, ``forward`` goes through the CUDA-graph path.
          """
          super().__init__(enable_cuda_graph=enable_cuda_graph)
          # Make the wrapped text model build its causal mask with this class's
          # implementation (see ``_build_causal_attention_mask``).
          enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask
          self.enc = enc
          # Mirror a few attributes of the wrapped encoder on the wrapper itself.
          self.device = self.enc.device
          self.dtype = self.enc.dtype
          # Per-slot CUDA-graph bookkeeping, indexed by ``self.iter``.
          self.cuda_graph_created = [False, False]
          self.static_inputs = [None, None]
          self.static_kwargs = [None, None]
          self.static_output = [None, None]
          self._cuda_graphs = [None, None]
          self.iter = 0
          self.config = self.enc.config

      def _build_causal_attention_mask(self, bsz, seq_len, dtype):
          """Return a causal attention mask of shape ``(bsz, 1, seq_len, seq_len)``.

          Entries strictly above the main diagonal hold ``torch.finfo(dtype).min``;
          the diagonal and everything below it are zero. The tensor is allocated
          on the accelerator's current device.
          """
          mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=get_accelerator().current_device_name())
          mask.fill_(torch.tensor(torch.finfo(dtype).min))
          # Keep only the strictly upper triangle; the rest becomes zero.
          mask.triu_(1)
          # Insert a singleton dimension at position 1.
          return mask.unsqueeze(1)

      def _graph_replay(self, *inputs, **kwargs):
          """Replay the active slot's captured graph on fresh argument values.

          Tensor arguments are copied in place into the slot's static buffers;
          non-tensor arguments are not copied. Returns the slot's static output
          object, which is the same object on every replay of that slot.
          """
          slot = self.iter
          static_inputs = self.static_inputs[slot]
          for idx, value in enumerate(inputs):
              if torch.is_tensor(value):
                  static_inputs[idx].copy_(value)
          static_kwargs = self.static_kwargs[slot]
          for key, value in kwargs.items():
              if torch.is_tensor(value):
                  static_kwargs[key].copy_(value)
          self._cuda_graphs[slot].replay()
          return self.static_output[slot]

      def forward(self, *inputs, **kwargs):
          """Run the wrapped encoder, via CUDA graph replay when enabled.

          ``self.enable_cuda_graph`` is not assigned in this class; it is
          presumably set by the ``CUDAGraph`` parent initializer.
          """
          if not self.enable_cuda_graph:
              return self.enc(*inputs, **kwargs)

          # Capture lazily on the first call that lands on this slot, then
          # always produce the result through a replay.
          if not self.cuda_graph_created[self.iter]:
              self._create_cuda_graph(*inputs, **kwargs)
          outputs = self._graph_replay(*inputs, **kwargs)
          # Alternate between the two slots on successive calls.
          self.iter = (self.iter + 1) % 2
          return outputs

      def _create_cuda_graph(self, *inputs, **kwargs):
          """Warm up the encoder and capture a CUDA graph into the active slot.

          The arguments passed here become the slot's static inputs/kwargs, and
          the value returned during capture becomes its static output.
          """
          # warmup to create the workspace and cublas handle
          cuda_stream = torch.cuda.Stream()
          cuda_stream.wait_stream(torch.cuda.current_stream())
          with torch.cuda.stream(cuda_stream):
              # Three eager runs on the side stream; their results are discarded.
              for _ in range(3):
                  self._forward(*inputs, **kwargs)
          torch.cuda.current_stream().wait_stream(cuda_stream)

          # create cuda_graph and assign static_inputs and static_outputs
          slot = self.iter
          self._cuda_graphs[slot] = torch.cuda.CUDAGraph()
          self.static_inputs[slot] = inputs
          self.static_kwargs[slot] = kwargs

          with torch.cuda.graph(self._cuda_graphs[slot]):
              self.static_output[slot] = self._forward(*self.static_inputs[slot], **self.static_kwargs[slot])

          self.cuda_graph_created[slot] = True

      def _forward(self, *inputs, **kwargs):
          """Call the wrapped encoder directly with the given arguments."""
          return self.enc(*inputs, **kwargs)