meters.py 1019 B

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import time
  2. import torch
  class AvgrageMeter(object):
      """Keep a running weighted average of the values passed to ``update``.

      Attributes:
          avg: Current mean, ``sum / cnt``; 0 until a non-zero weight is seen.
          sum: Weighted total of the values recorded since the last reset.
          cnt: Total weight recorded since the last reset.
      """

      def __init__(self):
          self.reset()

      def reset(self):
          """Set ``avg``, ``sum`` and ``cnt`` back to zero."""
          self.avg = 0
          self.sum = 0
          self.cnt = 0

      def update(self, val, n=1):
          """Fold ``val`` with weight ``n`` into the running statistics.

          Args:
              val: Value to record; it is multiplied by ``n`` before summing.
              n: Weight of ``val`` (how many samples it stands for).
          """
          self.sum += val * n
          self.cnt += n
          # With a total weight of zero (e.g. ``n=0`` on a fresh meter) the
          # division is undefined; keep the previous average instead of
          # raising ZeroDivisionError.
          if self.cnt:
              self.avg = self.sum / self.cnt
  class Timer:
      """Context manager that accumulates elapsed seconds per timer name.

      Totals live in the class-level ``timer_map`` and are therefore shared
      by every ``Timer`` created with the same name. A timer created with
      ``enable=False`` (the default) measures nothing and prints nothing.
      """

      # name -> total seconds accumulated by enabled timers of that name
      timer_map = {}

      def __init__(self, name, enable=False):
          """Register ``name`` in ``timer_map`` with a total of 0 if it is new.

          Args:
              name: Key under which elapsed time is accumulated.
              enable: When false, entering and exiting the timer is a no-op.
          """
          Timer.timer_map.setdefault(name, 0)
          self.name = name
          self.enable = enable
          # Start timestamp of the current measurement; None until entered.
          self.t = None

      def __enter__(self):
          """Record the start time (when enabled) and return this timer."""
          if self.enable:
              # Left disabled, as in the original:
              # if torch.cuda.is_available():
              #     torch.cuda.synchronize()
              # perf_counter is monotonic, so the measured interval cannot be
              # distorted by system clock adjustments.
              self.t = time.perf_counter()
          # Returning self makes ``with Timer(...) as t`` bind the timer.
          return self

      def __exit__(self, exc_type, exc_val, exc_tb):
          """Add the elapsed time to the shared total and print that total.

          Returns None, so an exception raised inside the ``with`` body is
          never suppressed.
          """
          if not self.enable:
              return
          # Left disabled, as in the original:
          # if torch.cuda.is_available():
          #     torch.cuda.synchronize()
          # ``get`` tolerates ``timer_map`` having been cleared after this
          # timer was constructed.
          total = Timer.timer_map.get(self.name, 0)
          # ``t`` is still None if ``enable`` was switched on after entering;
          # there is no start time to measure from in that case.
          if self.t is not None:
              total += time.perf_counter() - self.t
              Timer.timer_map[self.name] = total
          print(f'[Timer] {self.name}: {total}')