abstract_accelerator.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import abc
  5. from abc import ABC
  6. class DeepSpeedAccelerator(ABC):
  7. def __init__(self):
  8. self._name = None
  9. self._communication_backend_name = None
  10. # Device APIs
  11. @abc.abstractmethod
  12. def device_name(self, device_index):
  13. ...
  14. @abc.abstractmethod
  15. def device(self, device_index):
  16. ...
  17. @abc.abstractmethod
  18. def set_device(self, device_index):
  19. ...
  20. @abc.abstractmethod
  21. def current_device(self):
  22. ...
  23. @abc.abstractmethod
  24. def current_device_name(self):
  25. ...
  26. @abc.abstractmethod
  27. def device_count(self):
  28. ...
  29. @abc.abstractmethod
  30. def synchronize(self, device_index=None):
  31. ...
  32. # RNG APIs
  33. @abc.abstractmethod
  34. def random(self):
  35. ...
  36. @abc.abstractmethod
  37. def set_rng_state(self, new_state, device_index=None):
  38. ...
  39. @abc.abstractmethod
  40. def get_rng_state(self, device_index=None):
  41. ...
  42. @abc.abstractmethod
  43. def manual_seed(self, seed):
  44. ...
  45. @abc.abstractmethod
  46. def manual_seed_all(self, seed):
  47. ...
  48. @abc.abstractmethod
  49. def initial_seed(self, seed):
  50. ...
  51. @abc.abstractmethod
  52. def default_generator(self, device_index):
  53. ...
  54. # Streams/Events
  55. @property
  56. @abc.abstractmethod
  57. def Stream(self):
  58. ...
  59. @abc.abstractmethod
  60. def stream(self, stream):
  61. ...
  62. @abc.abstractmethod
  63. def current_stream(self, device_index=None):
  64. ...
  65. @abc.abstractmethod
  66. def default_stream(self, device_index=None):
  67. ...
  68. @property
  69. @abc.abstractmethod
  70. def Event(self):
  71. ...
  72. # Memory management
  73. @abc.abstractmethod
  74. def empty_cache(self):
  75. ...
  76. @abc.abstractmethod
  77. def memory_allocated(self, device_index=None):
  78. ...
  79. @abc.abstractmethod
  80. def max_memory_allocated(self, device_index=None):
  81. ...
  82. @abc.abstractmethod
  83. def reset_max_memory_allocated(self, device_index=None):
  84. ...
  85. @abc.abstractmethod
  86. def memory_cached(self, device_index=None):
  87. ...
  88. @abc.abstractmethod
  89. def max_memory_cached(self, device_index=None):
  90. ...
  91. @abc.abstractmethod
  92. def reset_max_memory_cached(self, device_index=None):
  93. ...
  94. @abc.abstractmethod
  95. def memory_stats(self, device_index=None):
  96. ...
  97. @abc.abstractmethod
  98. def reset_peak_memory_stats(self, device_index=None):
  99. ...
  100. @abc.abstractmethod
  101. def memory_reserved(self, device_index=None):
  102. ...
  103. @abc.abstractmethod
  104. def max_memory_reserved(self, device_index=None):
  105. ...
  106. @abc.abstractmethod
  107. def total_memory(self, device_index=None):
  108. ...
  109. # Data types
  110. @abc.abstractmethod
  111. def is_bf16_supported(self):
  112. ...
  113. @abc.abstractmethod
  114. def is_fp16_supported(self):
  115. ...
  116. # Misc
  117. @abc.abstractmethod
  118. def amp(self):
  119. ...
  120. @abc.abstractmethod
  121. def is_available(self):
  122. ...
  123. @abc.abstractmethod
  124. def range_push(self, msg):
  125. ...
  126. @abc.abstractmethod
  127. def range_pop(self):
  128. ...
  129. @abc.abstractmethod
  130. def lazy_call(self, callback):
  131. ...
  132. @abc.abstractmethod
  133. def communication_backend_name(self):
  134. ...
  135. # Tensor operations
  136. @property
  137. @abc.abstractmethod
  138. def BFloat16Tensor(self):
  139. ...
  140. @property
  141. @abc.abstractmethod
  142. def ByteTensor(self):
  143. ...
  144. @property
  145. @abc.abstractmethod
  146. def DoubleTensor(self):
  147. ...
  148. @property
  149. @abc.abstractmethod
  150. def FloatTensor(self):
  151. ...
  152. @property
  153. @abc.abstractmethod
  154. def HalfTensor(self):
  155. ...
  156. @property
  157. @abc.abstractmethod
  158. def IntTensor(self):
  159. ...
  160. @property
  161. @abc.abstractmethod
  162. def LongTensor(self):
  163. ...
  164. @abc.abstractmethod
  165. def pin_memory(self, tensor):
  166. ...
  167. @abc.abstractmethod
  168. def on_accelerator(self, tensor):
  169. ...
  170. @abc.abstractmethod
  171. def op_builder_dir(self):
  172. ...
  173. # create an instance of op builder, specified by class_name
  174. @abc.abstractmethod
  175. def create_op_builder(self, class_name):
  176. ...
  177. # return an op builder class, specified by class_name
  178. @abc.abstractmethod
  179. def get_op_builder(self, class_name):
  180. ...
  181. @abc.abstractmethod
  182. def build_extension(self):
  183. ...