timer.py

  1. """
  2. Copyright 2019 The Microsoft DeepSpeed Team
  3. """
  4. import time
  5. from numpy import mean
  6. from deepspeed.utils.logging import log_dist
  7. from deepspeed.accelerator import get_accelerator
  8. from deepspeed import comm as dist
  9. try:
  10. import psutil
  11. PSUTILS_INSTALLED = True
  12. except ImportError:
  13. PSUTILS_INSTALLED = False
  14. pass
class CudaEventTimer(object):
    """Measures elapsed device time between a pair of accelerator timing events."""
    def __init__(self,
                 start_event: get_accelerator().Event,
                 end_event: get_accelerator().Event):
        self.start_event = start_event
        self.end_event = end_event

    def get_elapsed_msec(self):
        get_accelerator().current_stream().wait_event(self.end_event)
        self.end_event.synchronize()
        return self.start_event.elapsed_time(self.end_event)
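
# Usage sketch (added for illustration, not part of the original module): how a
# CudaEventTimer can be driven directly. Assumes a CUDA-like accelerator whose
# events support timing; the device work being timed is left as a placeholder.
def _example_cuda_event_timer():
    start_event = get_accelerator().Event(enable_timing=True)
    end_event = get_accelerator().Event(enable_timing=True)
    start_event.record()
    ...  # device work to be timed would go here
    end_event.record()
    return CudaEventTimer(start_event, end_event).get_elapsed_msec()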
class SynchronizedWallClockTimer:
    """Group of timers. Borrowed from Nvidia Megatron code"""

    class Timer:
        """Timer."""
        def __init__(self, name):
            self.name_ = name
            self.started_ = False
            self.event_timers = []
            self.start_event = None
            self.elapsed_records = None

        def start(self):
            """Start the timer."""
            assert not self.started_, f"{self.name_} timer has already been started"
            self.start_event = get_accelerator().Event(enable_timing=True)
            self.start_event.record()
            self.started_ = True

        def stop(self, reset=False, record=False):
            """Stop the timer."""
            assert self.started_, "timer is not started"
            end_event = get_accelerator().Event(enable_timing=True)
            end_event.record()
            self.event_timers.append(CudaEventTimer(self.start_event, end_event))
            self.start_event = None
            self.started_ = False

        def _get_elapsed_msec(self):
            self.elapsed_records = [et.get_elapsed_msec() for et in self.event_timers]
            self.event_timers.clear()
            return sum(self.elapsed_records)

        def reset(self):
            """Reset timer."""
            self.started_ = False
            self.start_event = None
            self.elapsed_records = None
            self.event_timers.clear()
        def elapsed(self, reset=True):
            """Calculate the elapsed time."""
            started_ = self.started_
            # If timing is in progress, end it first.
            if self.started_:
                self.stop()
            # Get the elapsed time.
            elapsed_ = self._get_elapsed_msec()
            # Reset the elapsed time.
            if reset:
                self.reset()
            # If timing was in progress, restart it.
            if started_:
                self.start()
            return elapsed_

        def mean(self):
            """Return the trimmed mean (in msec) of the recorded intervals."""
            self.elapsed(reset=False)
            return trim_mean(self.elapsed_records, 0.1)
    def __init__(self):
        self.timers = {}

    def get_timers(self):
        return self.timers

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = self.Timer(name)
        return self.timers[name]

    @staticmethod
    def memory_usage():
        alloc = "mem_allocated: {:.4f} GB".format(get_accelerator().memory_allocated() /
                                                  (1024 * 1024 * 1024))
        max_alloc = "max_mem_allocated: {:.4f} GB".format(
            get_accelerator().max_memory_allocated() / (1024 * 1024 * 1024))
        cache = "cache_allocated: {:.4f} GB".format(get_accelerator().memory_cached() /
                                                    (1024 * 1024 * 1024))
        max_cache = "max_cache_allocated: {:.4f} GB".format(
            get_accelerator().max_memory_cached() / (1024 * 1024 * 1024))
        return " | {} | {} | {} | {}".format(alloc, max_alloc, cache, max_cache)
    def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False, ranks=None):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = f"rank={dist.get_rank()} time (ms)"
        for name in names:
            if name in self.timers:
                elapsed_time = (self.timers[name].elapsed(reset=reset) / normalizer)
                string += " | {}: {:.2f}".format(name, elapsed_time)
        log_dist(string, ranks=ranks or [0])

    def get_mean(self, names, normalizer=1.0, reset=True):
        """Get the mean of a group of timers."""
        assert normalizer > 0.0
        means = {}
        for name in names:
            if name in self.timers:
                elapsed_time = (self.timers[name].mean() * 1000.0 / normalizer)
                means[name] = elapsed_time
        return means
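
# Usage sketch (illustrative only): timing a named region with the timer group.
# The timer name "forward" and the placeholder workload are assumptions for the
# example; logging assumes deepspeed.comm has been initialized.
def _example_wall_clock_timers():
    timers = SynchronizedWallClockTimer()
    timers("forward").start()
    ...  # workload to be timed would go here
    timers("forward").stop()
    # Logs "rank=... time (ms) | forward: ..." on rank 0 and resets the timer.
    timers.log(["forward"])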
class ThroughputTimer:
    """Tracks training throughput (samples/sec) across steps, with optional host-memory monitoring."""
    def __init__(
        self,
        batch_size,
        start_step=2,
        steps_per_output=50,
        monitor_memory=False,
        logging_fn=None,
    ):
        from deepspeed.utils import logger
        self.start_time = 0
        self.end_time = 0
        self.started = False
        self.batch_size = 1 if batch_size is None else batch_size
        self.start_step = start_step
        self.epoch_count = 0
        self.micro_step_count = 0
        self.global_step_count = 0
        self.total_elapsed_time = 0
        self.step_elapsed_time = 0
        self.steps_per_output = steps_per_output
        self.monitor_memory = monitor_memory
        self.logging = logging_fn
        if self.logging is None:
            self.logging = logger.info
        self.initialized = False
        if self.monitor_memory and not PSUTILS_INSTALLED:
            raise ImportError("Unable to import 'psutil', please install package")

    def update_epoch_count(self):
        self.epoch_count += 1
        self.micro_step_count = 0

    def _init_timer(self):
        self.initialized = True

    def start(self):
        self._init_timer()
        self.started = True
        if self.global_step_count >= self.start_step:
            get_accelerator().synchronize()
            self.start_time = time.time()

    def stop(self, global_step=False, report_speed=True):
        if not self.started:
            return
        self.started = False
        self.micro_step_count += 1
        if global_step:
            self.global_step_count += 1
        if self.start_time > 0:
            get_accelerator().synchronize()
            self.end_time = time.time()
            duration = self.end_time - self.start_time
            self.total_elapsed_time += duration
            self.step_elapsed_time += duration
            if global_step:
                if report_speed and self.global_step_count % self.steps_per_output == 0:
                    self.logging(
                        "epoch={}/micro_step={}/global_step={}, RunningAvgSamplesPerSec={}, CurrSamplesPerSec={}, "
                        "MemAllocated={}GB, MaxMemAllocated={}GB".format(
                            self.epoch_count,
                            self.micro_step_count,
                            self.global_step_count,
                            self.avg_samples_per_sec(),
                            self.batch_size / self.step_elapsed_time,
                            round(get_accelerator().memory_allocated() / 1024**3, 2),
                            round(get_accelerator().max_memory_allocated() / 1024**3, 2),
                        ))
                    if self.monitor_memory:
                        virt_mem = psutil.virtual_memory()
                        swap = psutil.swap_memory()
                        self.logging(
                            "epoch={}/micro_step={}/global_step={}, vm %: {}, swap %: {}".format(
                                self.epoch_count,
                                self.micro_step_count,
                                self.global_step_count,
                                virt_mem.percent,
                                swap.percent,
                            ))
                self.step_elapsed_time = 0

    def avg_samples_per_sec(self):
        if self.global_step_count > 0:
            total_step_offset = self.global_step_count - self.start_step
            avg_time_per_step = self.total_elapsed_time / total_step_offset
            # training samples per second
            return self.batch_size / avg_time_per_step
        return float("-inf")
def trim_mean(data, trim_percent):
    """Compute the trimmed mean of a list of numbers.

    Args:
        data (list): List of numbers.
        trim_percent (float): Percentage of data to trim.

    Returns:
        float: Trimmed mean.
    """
    assert trim_percent >= 0.0 and trim_percent <= 1.0
    n = len(data)
    # Account for edge case of empty list
    if len(data) == 0:
        return 0
    data.sort()
    k = int(round(n * (trim_percent)))
    return mean(data[k:n - k])
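
# Worked example (assumed values): with ten samples and trim_percent=0.1,
# k = round(10 * 0.1) = 1, so the smallest and largest samples are dropped
# before averaging, which keeps a single outlier from skewing the result.
def _example_trim_mean():
    samples = [5.0, 1.0, 2.0, 2.5, 3.0, 2.0, 2.2, 2.1, 100.0, 2.4]
    return trim_mean(samples, 0.1)  # mean of the middle eight values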