timer.py

  1. """
  2. Copyright 2019 The Microsoft DeepSpeed Team
  3. """
  4. from numpy.core.numeric import count_nonzero
  5. from deepspeed.elasticity.elasticity import compute_elastic_config
  6. import time
  7. import torch
  8. from numpy import mean
  9. from deepspeed.utils.logging import log_dist
  10. from deepspeed.utils import logger
  11. try:
  12. import psutil
  13. PSUTILS_INSTALLED = True
  14. except ImportError:
  15. PSUTILS_INSTALLED = False
  16. pass
class SynchronizedWallClockTimer:
    """Group of timers. Borrowed from Nvidia Megatron code"""
    class Timer:
        """Timer."""
        def __init__(self, name):
            self.name_ = name
            self.elapsed_ = 0.0
            self.started_ = False
            self.start_time = time.time()
            self.records = []

        def start(self):
            """Start the timer."""
            assert not self.started_, "timer has already been started"
            torch.cuda.synchronize()
            self.start_time = time.time()
            self.started_ = True

        def stop(self, reset=False, record=False):
            """Stop the timer."""
            assert self.started_, "timer is not started"
            torch.cuda.synchronize()
            if reset:
                self.elapsed_ = time.time() - self.start_time
            else:
                self.elapsed_ += time.time() - self.start_time
            self.started_ = False
            if record:
                self.records.append(self.elapsed_)

        def reset(self):
            """Reset timer."""
            self.elapsed_ = 0.0
            self.started_ = False
            self.acc_ = 0.0
            self.cnt_ = 0

        def elapsed(self, reset=True):
            """Calculate the elapsed time."""
            started_ = self.started_
            # If timing is in progress, end it first.
            if self.started_:
                self.stop()
            # Get the elapsed time.
            elapsed_ = self.elapsed_
            # Reset the elapsed time.
            if reset:
                self.reset()
            # If timing was in progress, restart it.
            if started_:
                self.start()
            return elapsed_

        def mean(self):
            """Trimmed mean of the recorded elapsed times."""
            return trim_mean(self.records, 0.1)

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = self.Timer(name)
        return self.timers[name]

    @staticmethod
    def memory_usage():
        alloc = "mem_allocated: {:.4f} GB".format(torch.cuda.memory_allocated() /
                                                  (1024 * 1024 * 1024))
        max_alloc = "max_mem_allocated: {:.4f} GB".format(
            torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024))
        cache = "cache_allocated: {:.4f} GB".format(torch.cuda.memory_cached() /
                                                    (1024 * 1024 * 1024))
        max_cache = "max_cache_allocated: {:.4f} GB".format(
            torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))
        return " | {} | {} | {} | {}".format(alloc, max_alloc, cache, max_cache)

    def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False, ranks=None):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = f"rank={torch.distributed.get_rank()} time (ms)"
        for name in names:
            if name in self.timers:
                elapsed_time = (self.timers[name].elapsed(reset=reset) * 1000.0 /
                                normalizer)
                string += " | {}: {:.2f}".format(name, elapsed_time)
        log_dist(string, ranks=ranks or [0])

    def get_mean(self, names, normalizer=1.0, reset=True):
        """Get the mean of a group of timers."""
        assert normalizer > 0.0
        means = {}
        for name in names:
            if name in self.timers:
                elapsed_time = self.timers[name].mean() * 1000.0 / normalizer
                means[name] = elapsed_time
        return means
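
# Illustrative usage sketch (not part of the original file): the timer group is
# indexed by name via __call__, and log() assumes torch.distributed has been
# initialized, since it prefixes the output with the caller's rank.
#
#     timers = SynchronizedWallClockTimer()
#     timers("forward").start()
#     ...                             # timed region
#     timers("forward").stop(record=True)
#     timers.log(["forward"])         # elapsed time in ms, logged on rank 0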
class ThroughputTimer:
    def __init__(
        self,
        batch_size,
        num_workers,
        start_step=2,
        steps_per_output=50,
        monitor_memory=False,
        logging_fn=None,
    ):
        self.start_time = 0
        self.end_time = 0
        self.started = False
        self.batch_size = batch_size
        if batch_size is None:
            self.batch_size = 1
        self.num_workers = num_workers
        self.start_step = start_step
        self.epoch_count = 0
        self.local_step_count = 0
        self.total_step_count = 0
        self.total_elapsed_time = 0
        self.steps_per_output = steps_per_output
        self.monitor_memory = monitor_memory
        self.logging = logging_fn
        if self.logging is None:
            self.logging = logger.info
        self.initialized = False
        if self.monitor_memory and not PSUTILS_INSTALLED:
            raise ImportError("Unable to import 'psutil', please install the package")

    def update_epoch_count(self):
        self.epoch_count += 1
        self.local_step_count = 0

    def _init_timer(self):
        self.initialized = True

    def start(self):
        self._init_timer()
        self.started = True
        if self.total_step_count >= self.start_step:
            torch.cuda.synchronize()
            self.start_time = time.time()

    def stop(self, report_speed=True):
        if not self.started:
            return
        self.started = False
        self.total_step_count += 1
        self.local_step_count += 1
        if self.total_step_count > self.start_step:
            torch.cuda.synchronize()
            self.end_time = time.time()
            duration = self.end_time - self.start_time
            self.total_elapsed_time += duration
            if self.local_step_count % self.steps_per_output == 0:
                if report_speed:
                    self.logging("{}/{}, SamplesPerSec={}".format(
                        self.epoch_count,
                        self.local_step_count,
                        self.avg_samples_per_sec(),
                    ))
                if self.monitor_memory:
                    virt_mem = psutil.virtual_memory()
                    swap = psutil.swap_memory()
                    self.logging("{}/{}, vm percent: {}, swap percent: {}".format(
                        self.epoch_count,
                        self.local_step_count,
                        virt_mem.percent,
                        swap.percent,
                    ))

    def avg_samples_per_sec(self):
        # Only steps past start_step accumulate elapsed time, so guard on that
        # bound to avoid dividing by a zero or negative step offset.
        if self.total_step_count > self.start_step:
            samples_per_step = self.batch_size * self.num_workers
            total_step_offset = self.total_step_count - self.start_step
            avg_time_per_step = self.total_elapsed_time / total_step_offset
            # training samples per second
            return samples_per_step / avg_time_per_step
        return float("-inf")
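
# Illustrative usage sketch (not part of the original file): wrap each training
# step with start()/stop(). Throughput is averaged only over steps after
# `start_step`, so warmup iterations are excluded. `dataloader` and
# `world_size` below are placeholders, not names defined in this module.
#
#     tput = ThroughputTimer(batch_size=32, num_workers=world_size)
#     for batch in dataloader:
#         tput.start()
#         ...                 # forward/backward/optimizer step
#         tput.stop()         # logs SamplesPerSec every `steps_per_output` steps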
def trim_mean(data, trim_percent):
    """Compute the trimmed mean of a list of numbers.

    Args:
        data (list): List of numbers.
        trim_percent (float): Fraction of the smallest and largest values to
            trim, between 0.0 and 1.0.

    Returns:
        float: Trimmed mean.
    """
    assert 0.0 <= trim_percent <= 1.0
    n = len(data)
    # Sort a copy so the caller's list (e.g. Timer.records) is not reordered.
    data = sorted(data)
    k = int(round(n * trim_percent))
    return mean(data[k:n - k])
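
# Worked example (illustrative): with ten records and trim_percent=0.1,
# k = round(10 * 0.1) = 1, so the smallest and largest values are dropped
# before averaging:
#
#     trim_mean([5, 1, 2, 3, 4, 6, 7, 8, 9, 100], 0.1)  # mean of 2..9 == 5.5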