cuda_graph.py 518 B

123456789101112131415161718192021222324
'''
Copyright 2023 The Microsoft DeepSpeed Team
'''
from abc import ABC, abstractmethod
  5. class CUDAGraph(ABC):
  6. def __init__(self, enable_cuda_graph=False):
  7. super().__init__()
  8. self.enable_cuda_graph = enable_cuda_graph
  9. @abstractmethod
  10. def _create_cuda_graph(self):
  11. """
  12. Create CUDA graph(s)
  13. """
  14. raise NotImplementedError
  15. @abstractmethod
  16. def _graph_replay(self):
  17. """
  18. Replay CUDA graph(s)
  19. """
  20. raise NotImplementedError