123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- from dataclasses import dataclass
- from deepspeed.utils import log_dist
class PartitionedParameterProfiler(object):
    """Track named events with a timer plus per-event call and element counts.

    ``timers`` must be callable as ``timers(name)`` and return an object with
    ``start()`` and ``stop()``; it must also provide ``log(names=...)``.
    Passing ``None`` turns event start/stop and timer logging into no-ops.
    """

    @dataclass
    class EventCounter:
        """Running totals for a single named event."""

        name: str
        count: int
        num_elem: int

        def reset(self):
            """Zero both the invocation count and the element total."""
            self.count = self.num_elem = 0

        def increment(self, numel):
            """Record one more invocation that handled ``numel`` elements."""
            self.count += 1
            self.num_elem += numel

    def __init__(self, timers):
        """Store the timer collection and start with no known events."""
        self.timers = timers
        self.event_counters = {}

    def reset_events(self):
        """Reset every registered counter; the events stay registered."""
        for counter in self.event_counters.values():
            counter.reset()

    def start_event(self, name):
        """Register ``name`` on first use and start its timer.

        Does nothing when no timers were supplied.
        """
        if self.timers is not None:
            if name not in self.event_counters:
                self.event_counters[name] = PartitionedParameterProfiler.EventCounter(
                    name=name,
                    count=0,
                    num_elem=0,
                )
            self.timers(name).start()

    def stop_event(self, name, num_elem):
        """Add ``num_elem`` to the counter for ``name`` and stop its timer.

        Does nothing when no timers were supplied. The event must already have
        been registered through ``start_event``; otherwise the assertion fails.
        """
        if self.timers is not None:
            assert name in self.event_counters, f'unknown event {name}'
            # The counter is updated before the timer is stopped.
            self.event_counters[name].increment(num_elem)
            self.timers(name).stop()

    def _log_timers(self):
        """Ask the timer collection to log all registered event names."""
        if self.timers is not None:
            registered_names = list(self.event_counters)
            self.timers.log(names=registered_names)

    def _log_event_counters(self):
        """Emit one ``log_dist`` line per registered event, with ``ranks=[0]``."""
        for event_ctr in self.event_counters.values():
            message = f'{event_ctr.name}: count = {event_ctr.count}, numel = {event_ctr.num_elem}'
            log_dist(message, ranks=[0])

    def log_events(self):
        """Log the event counters first, then the timers."""
        self._log_event_counters()
        self._log_timers()
|