logging.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import logging
  2. import sys
  3. import os
  4. import torch.distributed as dist
  5. log_levels = {
  6. "debug": logging.DEBUG,
  7. "info": logging.INFO,
  8. "warning": logging.WARNING,
  9. "error": logging.ERROR,
  10. "critical": logging.CRITICAL,
  11. }
  12. class LoggerFactory:
  13. @staticmethod
  14. def create_logger(name=None, level=logging.INFO):
  15. """create a logger
  16. Args:
  17. name (str): name of the logger
  18. level: level of logger
  19. Raises:
  20. ValueError is name is None
  21. """
  22. if name is None:
  23. raise ValueError("name for logger cannot be None")
  24. formatter = logging.Formatter(
  25. "[%(asctime)s] [%(levelname)s] "
  26. "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")
  27. logger_ = logging.getLogger(name)
  28. logger_.setLevel(level)
  29. logger_.propagate = False
  30. ch = logging.StreamHandler(stream=sys.stdout)
  31. ch.setLevel(level)
  32. ch.setFormatter(formatter)
  33. logger_.addHandler(ch)
  34. return logger_
  35. logger = LoggerFactory.create_logger(name="DeepSpeed", level=logging.INFO)
  36. def log_dist(message, ranks=None, level=logging.INFO):
  37. """Log message when one of following condition meets
  38. + not dist.is_initialized()
  39. + dist.get_rank() in ranks if ranks is not None or ranks = [-1]
  40. Args:
  41. message (str)
  42. ranks (list)
  43. level (int)
  44. """
  45. should_log = not dist.is_initialized()
  46. ranks = ranks or []
  47. my_rank = dist.get_rank() if dist.is_initialized() else -1
  48. if ranks and not should_log:
  49. should_log = ranks[0] == -1
  50. should_log = should_log or (my_rank in set(ranks))
  51. if should_log:
  52. final_message = "[Rank {}] {}".format(my_rank, message)
  53. logger.log(level, final_message)
  54. def print_json_dist(message, ranks=None, path=None):
  55. """Print message when one of following condition meets
  56. + not dist.is_initialized()
  57. + dist.get_rank() in ranks if ranks is not None or ranks = [-1]
  58. Args:
  59. message (str)
  60. ranks (list)
  61. path (str)
  62. """
  63. should_log = not dist.is_initialized()
  64. ranks = ranks or []
  65. my_rank = dist.get_rank() if dist.is_initialized() else -1
  66. if ranks and not should_log:
  67. should_log = ranks[0] == -1
  68. should_log = should_log or (my_rank in set(ranks))
  69. if should_log:
  70. message['rank'] = my_rank
  71. import json
  72. with open(path, 'w') as outfile:
  73. json.dump(message, outfile)
  74. os.fsync(outfile)
  75. def get_current_level():
  76. """
  77. Return logger's current log level
  78. """
  79. return logger.getEffectiveLevel()
  80. def should_log_le(max_log_level_str):
  81. """
  82. Args:
  83. max_log_level_str: maximum log level as a string
  84. Returns ``True`` if the current log_level is less or equal to the specified log level. Otherwise ``False``.
  85. Example:
  86. ``should_log_le("info")`` will return ``True`` if the current log level is either ``logging.INFO`` or ``logging.DEBUG``
  87. """
  88. if not isinstance(max_log_level_str, str):
  89. raise ValueError(f"{max_log_level_str} is not a string")
  90. max_log_level_str = max_log_level_str.lower()
  91. if max_log_level_str not in log_levels:
  92. raise ValueError(f"{max_log_level_str} is not one of the `logging` levels")
  93. return get_current_level() <= log_levels[max_log_level_str]