# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import time

from numpy import mean

from deepspeed.utils.logging import log_dist
from deepspeed.accelerator import get_accelerator

FORWARD_MICRO_TIMER = 'fwd_microstep'
FORWARD_GLOBAL_TIMER = 'fwd'
BACKWARD_MICRO_TIMER = 'bwd_microstep'
BACKWARD_GLOBAL_TIMER = 'bwd'
BACKWARD_INNER_MICRO_TIMER = 'bwd_inner_microstep'
BACKWARD_INNER_GLOBAL_TIMER = 'bwd_inner'
BACKWARD_REDUCE_MICRO_TIMER = 'bwd_allreduce_microstep'
BACKWARD_REDUCE_GLOBAL_TIMER = 'bwd_allreduce'
STEP_MICRO_TIMER = 'step_microstep'
STEP_GLOBAL_TIMER = 'step'

try:
    import psutil
    PSUTILS_INSTALLED = True
except ImportError:
    PSUTILS_INSTALLED = False


class CudaEventTimer(object):

    def __init__(self, start_event: get_accelerator().Event, end_event: get_accelerator().Event):
        self.start_event = start_event
        self.end_event = end_event

    def get_elapsed_msec(self):
        # Wait until the end event has been recorded, then read the device-side delta.
        get_accelerator().current_stream().wait_event(self.end_event)
        self.end_event.synchronize()
        return self.start_event.elapsed_time(self.end_event)
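
# Illustrative sketch (not part of this module): CudaEventTimer pairs two
# accelerator events so the elapsed time can be resolved lazily, after the
# queued GPU work has actually finished. Assumes a CUDA-like, non-synchronized
# accelerator backend.
#
#   start = get_accelerator().Event(enable_timing=True)
#   end = get_accelerator().Event(enable_timing=True)
#   start.record()
#   ...                          # launch kernels on the current stream
#   end.record()
#   t = CudaEventTimer(start, end)
#   print(t.get_elapsed_msec())  # blocks on `end`, then reads the event delta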


class SynchronizedWallClockTimer:
    """Group of timers. Borrowed from Nvidia Megatron code."""

    class Timer:
        """Timer."""

        def __init__(self, name):
            self.name_ = name
            self.started_ = False
            self.event_timers = []
            self.use_host_timer = get_accelerator().is_synchronized_device()
            self.start_event = None
            self.elapsed_records = None
            self.start_time = 0.0
            self.end_time = 0.0

        def start(self):
            """Start the timer."""
            assert not self.started_, f"{self.name_} timer has already been started"
            if self.use_host_timer:
                self.start_time = time.time()
            else:
                event_class = get_accelerator().Event
                self.start_event = event_class(enable_timing=True)
                self.start_event.record()
            self.started_ = True

        def stop(self, reset=False, record=False):
            """Stop the timer."""
            assert self.started_, "timer is not started"
            if self.use_host_timer:
                self.end_time = time.time()
                self.event_timers.append(self.end_time - self.start_time)
            else:
                event_class = get_accelerator().Event
                end_event = event_class(enable_timing=True)
                end_event.record()
                self.event_timers.append(CudaEventTimer(self.start_event, end_event))
                self.start_event = None
            self.started_ = False

        def _get_elapsed_msec(self):
            if self.use_host_timer:
                self.elapsed_records = [et * 1000.0 for et in self.event_timers]
            else:
                self.elapsed_records = [et.get_elapsed_msec() for et in self.event_timers]
            self.event_timers.clear()
            return sum(self.elapsed_records)

        def reset(self):
            """Reset timer."""
            self.started_ = False
            self.start_event = None
            self.elapsed_records = None
            self.event_timers.clear()

        def elapsed(self, reset=True):
            """Calculate the elapsed time."""
            started_ = self.started_
            # If timing is in progress, end it first.
            if self.started_:
                self.stop()
            # Get the elapsed time.
            elapsed_ = self._get_elapsed_msec()
            # Reset the elapsed time.
            if reset:
                self.reset()
            # If timing was in progress, restart it.
            if started_:
                self.start()
            return elapsed_

        def mean(self):
            self.elapsed(reset=False)
            return trim_mean(self.elapsed_records, 0.1)

    def __init__(self):
        self.timers = {}

    def get_timers(self):
        return self.timers

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = self.Timer(name)
        return self.timers[name]

    @staticmethod
    def memory_usage():
        alloc = "mem_allocated: {:.4f} GB".format(get_accelerator().memory_allocated() / (1024 * 1024 * 1024))
        max_alloc = "max_mem_allocated: {:.4f} GB".format(get_accelerator().max_memory_allocated() /
                                                          (1024 * 1024 * 1024))
        cache = "cache_allocated: {:.4f} GB".format(get_accelerator().memory_cached() / (1024 * 1024 * 1024))
        max_cache = "max_cache_allocated: {:.4f} GB".format(get_accelerator().max_memory_cached() /
                                                            (1024 * 1024 * 1024))
        return " | {} | {} | {} | {}".format(alloc, max_alloc, cache, max_cache)

    def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False, ranks=None):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = "time (ms)"
        for name in names:
            if name in self.timers:
                elapsed_time = self.timers[name].elapsed(reset=reset) / normalizer
                string += " | {}: {:.2f}".format(name, elapsed_time)
        log_dist(string, ranks=ranks or [0])

    def get_mean(self, names, normalizer=1.0, reset=True):
        """Get the mean of a group of timers."""
        assert normalizer > 0.0
        means = {}
        for name in names:
            if name in self.timers:
                means[name] = self.timers[name].mean() * 1000.0 / normalizer
        return means
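
# Illustrative usage (not part of the module): timers are created lazily on
# first lookup, and each start/stop pair appends one record. The name below
# reuses a module-level constant; any string works.
#
#   timers = SynchronizedWallClockTimer()
#   timers(FORWARD_GLOBAL_TIMER).start()
#   ...                                   # forward pass
#   timers(FORWARD_GLOBAL_TIMER).stop()
#   timers.log([FORWARD_GLOBAL_TIMER])    # e.g. "time (ms) | fwd: 12.34" on rank 0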


class NoopTimer:

    class Timer:

        def start(self):
            ...

        def reset(self):
            ...

        def stop(self, **kwargs):
            ...

        def elapsed(self, **kwargs):
            return 0

        def mean(self):
            return 0

    def __init__(self):
        self.timer = self.Timer()

    def __call__(self, name):
        return self.timer

    def get_timers(self):
        return {}

    def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False, ranks=None):
        ...

    def get_mean(self, names, normalizer=1.0, reset=True):
        ...
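
# Illustrative note (not part of the module): NoopTimer mirrors the
# SynchronizedWallClockTimer interface, so callers can disable timing without
# branching at every call site. A hypothetical wiring:
#
#   timers = SynchronizedWallClockTimer() if wall_clock_breakdown else NoopTimer()
#   timers('fwd').start()   # no-op when timing is disabled
#   timers('fwd').stop()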


class ThroughputTimer:

    def __init__(
        self,
        batch_size,
        start_step=2,
        steps_per_output=50,
        monitor_memory=False,
        logging_fn=None,
    ):
        from deepspeed.utils import logger
        self.start_time = 0
        self.end_time = 0
        self.started = False
        self.batch_size = 1 if batch_size is None else batch_size
        self.start_step = start_step
        self.epoch_count = 0
        self.micro_step_count = 0
        self.global_step_count = 0
        self.total_elapsed_time = 0
        self.step_elapsed_time = 0
        self.steps_per_output = steps_per_output
        self.monitor_memory = monitor_memory
        self.logging = logging_fn
        if self.logging is None:
            self.logging = logger.info
        self.initialized = False
        if self.monitor_memory and not PSUTILS_INSTALLED:
            raise ImportError("Unable to import 'psutil', please install the package")

    def update_epoch_count(self):
        self.epoch_count += 1
        self.micro_step_count = 0

    def _init_timer(self):
        self.initialized = True

    def start(self):
        self._init_timer()
        self.started = True
        # Only time steps past the warm-up window.
        if self.global_step_count >= self.start_step:
            get_accelerator().synchronize()
            self.start_time = time.time()

    def stop(self, global_step=False, report_speed=True):
        if not self.started:
            return
        self.started = False
        self.micro_step_count += 1
        if global_step:
            self.global_step_count += 1
        if self.start_time > 0:
            get_accelerator().synchronize()
            self.end_time = time.time()
            duration = self.end_time - self.start_time
            self.total_elapsed_time += duration
            self.step_elapsed_time += duration
            if global_step:
                if report_speed and self.global_step_count % self.steps_per_output == 0:
                    self.logging(
                        "epoch={}/micro_step={}/global_step={}, RunningAvgSamplesPerSec={}, CurrSamplesPerSec={}, "
                        "MemAllocated={}GB, MaxMemAllocated={}GB".format(
                            self.epoch_count,
                            self.micro_step_count,
                            self.global_step_count,
                            self.avg_samples_per_sec(),
                            self.batch_size / self.step_elapsed_time,
                            round(get_accelerator().memory_allocated() / 1024**3, 2),
                            round(get_accelerator().max_memory_allocated() / 1024**3, 2),
                        ))
                    if self.monitor_memory:
                        virt_mem = psutil.virtual_memory()
                        swap = psutil.swap_memory()
                        self.logging("epoch={}/micro_step={}/global_step={}, vm %: {}, swap %: {}".format(
                            self.epoch_count,
                            self.micro_step_count,
                            self.global_step_count,
                            virt_mem.percent,
                            swap.percent,
                        ))
                self.step_elapsed_time = 0

    def avg_samples_per_sec(self):
        if self.global_step_count > 0:
            total_step_offset = self.global_step_count - self.start_step
            avg_time_per_step = self.total_elapsed_time / total_step_offset
            # training samples per second
            return self.batch_size / avg_time_per_step
        return float("-inf")
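
# Illustrative usage (not part of the module): steps before `start_step` are
# excluded from the running average, so one-time CUDA context and allocator
# warm-up costs do not skew the reported throughput. Values are hypothetical.
#
#   tput = ThroughputTimer(batch_size=32, start_step=2, steps_per_output=50)
#   for step in range(num_steps):
#       tput.start()
#       ...                           # one training step
#       tput.stop(global_step=True)   # logs every 50th global step
#   print(tput.avg_samples_per_sec())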


def trim_mean(data, trim_percent):
    """Compute the trimmed mean of a list of numbers.

    Args:
        data (list): List of numbers.
        trim_percent (float): Fraction of data to trim from each end.

    Returns:
        float: Trimmed mean.
    """
    assert 0.0 <= trim_percent <= 1.0
    n = len(data)
    # Account for the edge case of an empty list.
    if n == 0:
        return 0
    data.sort()
    k = int(round(n * trim_percent))
    return mean(data[k:n - k])
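
# Worked example (illustrative): with trim_percent=0.1 and 10 samples,
# k = round(10 * 0.1) = 1, so the single smallest and largest records are
# dropped and the mean is taken over the middle 8 values. This is what
# Timer.mean() relies on to discard outlier step timings.
#
#   trim_mean([5, 1, 2, 3, 4, 6, 7, 8, 9, 100], 0.1)
#   # sorted -> [1, 2, ..., 9, 100]; averages [2..9] -> 5.5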