partitioned_param_profiler.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from dataclasses import dataclass
  5. from deepspeed.utils import log_dist
  6. class PartitionedParameterProfiler(object):
  7. @dataclass
  8. class EventCounter:
  9. name: str
  10. count: int
  11. num_elem: int
  12. def reset(self):
  13. self.count = 0
  14. self.num_elem = 0
  15. def increment(self, numel):
  16. self.count += 1
  17. self.num_elem += numel
  18. def __init__(self, timers):
  19. self.timers = timers
  20. self.event_counters = {}
  21. def reset_events(self):
  22. for event_ctr in self.event_counters.values():
  23. event_ctr.reset()
  24. def start_event(self, name):
  25. if self.timers is None:
  26. return
  27. if name not in self.event_counters:
  28. self.event_counters[name] = __class__.EventCounter(name=name, count=0, num_elem=0)
  29. self.timers(name).start()
  30. def stop_event(self, name, num_elem):
  31. if self.timers is None:
  32. return
  33. assert name in self.event_counters, f'unknown event {name}'
  34. self.event_counters[name].increment(num_elem)
  35. self.timers(name).stop()
  36. def _log_timers(self):
  37. if self.timers is None:
  38. return
  39. self.timers.log(names=list(self.event_counters.keys()))
  40. def _log_event_counters(self):
  41. for event_ctr in self.event_counters.values():
  42. log_dist(
  43. f'{event_ctr.name}: count = {event_ctr.count}, numel = {event_ctr.num_elem}',
  44. #f'{event_ctr.name}: time = {self._log_timers()},count = {event_ctr.count}, numel = {event_ctr.num_elem}',
  45. ranks=[0])
  46. def log_events(self):
  47. self._log_event_counters()
  48. self._log_timers()