logging.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. import logging
  3. import sys
  4. import os
  5. log_levels = {
  6. "debug": logging.DEBUG,
  7. "info": logging.INFO,
  8. "warning": logging.WARNING,
  9. "error": logging.ERROR,
  10. "critical": logging.CRITICAL,
  11. }
  12. class LoggerFactory:
  13. @staticmethod
  14. def create_logger(name=None, level=logging.INFO):
  15. """create a logger
  16. Args:
  17. name (str): name of the logger
  18. level: level of logger
  19. Raises:
  20. ValueError is name is None
  21. """
  22. if name is None:
  23. raise ValueError("name for logger cannot be None")
  24. formatter = logging.Formatter(
  25. "[%(asctime)s] [%(levelname)s] "
  26. "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")
  27. logger_ = logging.getLogger(name)
  28. logger_.setLevel(level)
  29. logger_.propagate = False
  30. ch = logging.StreamHandler(stream=sys.stdout)
  31. ch.setLevel(level)
  32. ch.setFormatter(formatter)
  33. logger_.addHandler(ch)
  34. return logger_
  35. logger = LoggerFactory.create_logger(name="DeepSpeed", level=logging.INFO)
  36. def print_configuration(args, name):
  37. logger.info("{}:".format(name))
  38. for arg in sorted(vars(args)):
  39. dots = "." * (29 - len(arg))
  40. logger.info(" {} {} {}".format(arg, dots, getattr(args, arg)))
  41. def log_dist(message, ranks=None, level=logging.INFO):
  42. from deepspeed import comm as dist
  43. """Log message when one of following condition meets
  44. + not dist.is_initialized()
  45. + dist.get_rank() in ranks if ranks is not None or ranks = [-1]
  46. Args:
  47. message (str)
  48. ranks (list)
  49. level (int)
  50. """
  51. should_log = not dist.is_initialized()
  52. ranks = ranks or []
  53. my_rank = dist.get_rank() if dist.is_initialized() else -1
  54. if ranks and not should_log:
  55. should_log = ranks[0] == -1
  56. should_log = should_log or (my_rank in set(ranks))
  57. if should_log:
  58. final_message = "[Rank {}] {}".format(my_rank, message)
  59. logger.log(level, final_message)
  60. def print_json_dist(message, ranks=None, path=None):
  61. from deepspeed import comm as dist
  62. """Print message when one of following condition meets
  63. + not dist.is_initialized()
  64. + dist.get_rank() in ranks if ranks is not None or ranks = [-1]
  65. Args:
  66. message (str)
  67. ranks (list)
  68. path (str)
  69. """
  70. should_log = not dist.is_initialized()
  71. ranks = ranks or []
  72. my_rank = dist.get_rank() if dist.is_initialized() else -1
  73. if ranks and not should_log:
  74. should_log = ranks[0] == -1
  75. should_log = should_log or (my_rank in set(ranks))
  76. if should_log:
  77. message['rank'] = my_rank
  78. import json
  79. with open(path, 'w') as outfile:
  80. json.dump(message, outfile)
  81. os.fsync(outfile)
  82. def get_current_level():
  83. """
  84. Return logger's current log level
  85. """
  86. return logger.getEffectiveLevel()
  87. def should_log_le(max_log_level_str):
  88. """
  89. Args:
  90. max_log_level_str: maximum log level as a string
  91. Returns ``True`` if the current log_level is less or equal to the specified log level. Otherwise ``False``.
  92. Example:
  93. ``should_log_le("info")`` will return ``True`` if the current log level is either ``logging.INFO`` or ``logging.DEBUG``
  94. """
  95. if not isinstance(max_log_level_str, str):
  96. raise ValueError(f"{max_log_level_str} is not a string")
  97. max_log_level_str = max_log_level_str.lower()
  98. if max_log_level_str not in log_levels:
  99. raise ValueError(f"{max_log_level_str} is not one of the `logging` levels")
  100. return get_current_level() <= log_levels[max_log_level_str]