123456789101112131415161718192021222324 |
- '''
- Copyright 2023 The Microsoft DeepSpeed Team
- '''
- from abc import ABC, abstractmethod
class CUDAGraph(ABC):
    """
    Abstract base class for objects that create and replay CUDA graph(s).

    This class only records whether CUDA graphs are enabled; it cannot be
    instantiated directly. Subclasses must override both
    ``_create_cuda_graph`` and ``_graph_replay``.
    """

    def __init__(self, enable_cuda_graph: bool = False) -> None:
        """
        Store the CUDA-graph toggle on the instance.

        Args:
            enable_cuda_graph: value saved as ``self.enable_cuda_graph``
                (defaults to ``False``). This base class never reads the
                flag itself; acting on it is left to subclasses.
        """
        # Cooperative call so the class keeps working in multiple-inheritance
        # hierarchies, where the next class in the MRO may not be ``object``.
        super().__init__()
        self.enable_cuda_graph = enable_cuda_graph

    @abstractmethod
    def _create_cuda_graph(self):
        """
        Create CUDA graph(s)

        Raises:
            NotImplementedError: if a subclass delegates to this base
                implementation instead of providing its own.
        """
        raise NotImplementedError

    @abstractmethod
    def _graph_replay(self):
        """
        Replay CUDA graph(s)

        Raises:
            NotImplementedError: if a subclass delegates to this base
                implementation instead of providing its own.
        """
        raise NotImplementedError
|