1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import torch
- from deepspeed.accelerator import get_accelerator
- from ..features.cuda_graph import CUDAGraph
- class DSClipEncoder(CUDAGraph, torch.nn.Module):
- def __init__(self, enc, enable_cuda_graph=False):
- super().__init__(enable_cuda_graph=enable_cuda_graph)
- enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask
- self.enc = enc
- self.device = self.enc.device
- self.dtype = self.enc.dtype
- self.cuda_graph_created = [False, False]
- self.static_inputs = [None, None]
- self.static_kwargs = [None, None]
- self.static_output = [None, None]
- self._cuda_graphs = [None, None]
- self.iter = 0
- self.config = self.enc.config
- def _build_causal_attention_mask(self, bsz, seq_len, dtype):
- mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=get_accelerator().current_device_name())
- mask.fill_(torch.tensor(torch.finfo(dtype).min))
- mask.triu_(1)
- mask = mask.unsqueeze(1)
- return mask
- def _graph_replay(self, *inputs, **kwargs):
- for i in range(len(inputs)):
- if torch.is_tensor(inputs[i]):
- self.static_inputs[self.iter][i].copy_(inputs[i])
- for k in kwargs:
- if torch.is_tensor(kwargs[k]):
- self.static_kwargs[self.iter][k].copy_(kwargs[k])
- self._cuda_graphs[self.iter].replay()
- return self.static_output[self.iter]
- def forward(self, *inputs, **kwargs):
- if self.enable_cuda_graph:
- if self.cuda_graph_created[self.iter]:
- outputs = self._graph_replay(*inputs, **kwargs)
- else:
- self._create_cuda_graph(*inputs, **kwargs)
- outputs = self._graph_replay(*inputs, **kwargs)
- self.iter = (self.iter + 1) % 2
- return outputs
- else:
- return self.enc(*inputs, **kwargs)
- def _create_cuda_graph(self, *inputs, **kwargs):
- # warmup to create the workspace and cublas handle
- cuda_stream = torch.cuda.Stream()
- cuda_stream.wait_stream(torch.cuda.current_stream())
- with torch.cuda.stream(cuda_stream):
- for i in range(3):
- ret = self._forward(*inputs, **kwargs)
- torch.cuda.current_stream().wait_stream(cuda_stream)
- # create cuda_graph and assign static_inputs and static_outputs
- self._cuda_graphs[self.iter] = torch.cuda.CUDAGraph()
- self.static_inputs[self.iter] = inputs
- self.static_kwargs[self.iter] = kwargs
- with torch.cuda.graph(self._cuda_graphs[self.iter]):
- self.static_output[self.iter] = self._forward(*self.static_inputs[self.iter],
- **self.static_kwargs[self.iter])
- self.cuda_graph_created[self.iter] = True
- def _forward(self, *inputs, **kwargs):
- return self.enc(*inputs, **kwargs)
|