123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- import logging
- import sys
- import os
- import torch.distributed as dist
# Lowercase level names mapped to the corresponding ``logging`` constants.
log_levels = dict(
    debug=logging.DEBUG,
    info=logging.INFO,
    warning=logging.WARNING,
    error=logging.ERROR,
    critical=logging.CRITICAL,
)
class LoggerFactory:
    """Factory for named loggers that print formatted records to stdout."""

    @staticmethod
    def create_logger(name=None, level=logging.INFO):
        """Create (or reconfigure) a logger that writes to ``sys.stdout``.

        Calling this again with the same ``name`` returns the same logger
        object and updates its level and format instead of attaching a second
        stdout handler, so records are never emitted more than once.

        Args:
            name (str): name of the logger
            level: level applied to both the logger and its stdout handler

        Returns:
            The configured ``logging.Logger`` (propagation to parent loggers
            is disabled).

        Raises:
            ValueError: if ``name`` is None
        """
        if name is None:
            raise ValueError("name for logger cannot be None")

        formatter = logging.Formatter(
            "[%(asctime)s] [%(levelname)s] "
            "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")

        logger_ = logging.getLogger(name)
        logger_.setLevel(level)
        logger_.propagate = False

        # logging.getLogger() hands back the same object for a given name, so
        # look for a stdout handler left by an earlier call before adding one.
        ch = next(
            (handler for handler in logger_.handlers
             if isinstance(handler, logging.StreamHandler)
             and getattr(handler, "stream", None) is sys.stdout),
            None,
        )
        if ch is None:
            ch = logging.StreamHandler(stream=sys.stdout)
            logger_.addHandler(ch)
        ch.setLevel(level)
        ch.setFormatter(formatter)
        return logger_
# Module-wide logger used by the helper functions in this file.
logger = LoggerFactory.create_logger("DeepSpeed", logging.INFO)
def log_dist(message, ranks=None, level=logging.INFO):
    """Emit ``message`` through the module logger on the selected ranks.

    The message is logged when any of the following holds:
        + ``dist.is_initialized()`` is False
        + ``ranks`` is non-empty and its first element is -1
        + ``dist.get_rank()`` is contained in ``ranks``

    The logged text is prefixed with ``[Rank N]``, where N is the current
    rank, or -1 when torch.distributed is not initialized.

    Args:
        message (str): text to log
        ranks (list): ranks allowed to log; ``None`` is treated as empty
        level (int): ``logging`` level used for the record
    """
    initialized = dist.is_initialized()
    my_rank = dist.get_rank() if initialized else -1
    selected = ranks or []

    if initialized:
        # A leading -1 selects every rank; otherwise this rank must be listed.
        everyone = bool(selected) and selected[0] == -1
        if not (everyone or my_rank in set(selected)):
            return

    logger.log(level, "[Rank {}] {}".format(my_rank, message))
def print_json_dist(message, ranks=None, path=None):
    """Write ``message`` as JSON to ``path`` on the selected ranks.

    The file is written when any of the following holds:
        + ``dist.is_initialized()`` is False
        + ``ranks`` is non-empty and its first element is -1
        + ``dist.get_rank()`` is contained in ``ranks``

    Args:
        message (dict): JSON-serializable mapping. A ``'rank'`` entry holding
            the current rank (-1 when torch.distributed is not initialized) is
            added to it in place before it is written.
        ranks (list): ranks allowed to write; ``None`` is treated as empty
        path (str): output file, opened with mode ``'w'``
    """
    should_log = not dist.is_initialized()
    ranks = ranks or []
    my_rank = dist.get_rank() if dist.is_initialized() else -1
    if ranks and not should_log:
        should_log = ranks[0] == -1
    should_log = should_log or (my_rank in set(ranks))
    if should_log:
        message['rank'] = my_rank
        import json
        with open(path, 'w') as outfile:
            json.dump(message, outfile)
            # fsync only covers data already handed to the OS, so drain
            # Python's own write buffer first.
            outfile.flush()
            os.fsync(outfile.fileno())
def get_current_level():
    """Return the effective log level of the module-level ``logger``."""
    effective_level = logger.getEffectiveLevel()
    return effective_level
def should_log_le(max_log_level_str):
    """Tell whether the current log level is at or below a named level.

    Args:
        max_log_level_str: maximum log level as a string (case-insensitive),
            one of the keys of ``log_levels``

    Returns:
        ``True`` if the current log level is less than or equal to the
        specified log level, otherwise ``False``.

    Raises:
        ValueError: if the argument is not a string or does not name a known
            level.

    Example:
        ``should_log_le("info")`` returns ``True`` if the current log level is
        either ``logging.INFO`` or ``logging.DEBUG``
    """
    if not isinstance(max_log_level_str, str):
        raise ValueError(f"{max_log_level_str} is not a string")

    level_name = max_log_level_str.lower()
    if level_name in log_levels:
        return get_current_level() <= log_levels[level_name]

    raise ValueError(f"{level_name} is not one of the `logging` levels")
|