config.py 860 B

123456789101112131415161718192021222324252627282930313233
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from .constants import *
  5. from ..pydantic_v1 import BaseModel
  6. class CommsConfig(BaseModel):
  7. class Config:
  8. validate_all = True
  9. validate_assignment = True
  10. use_enum_values = True
  11. extra = 'forbid'
class CommsLoggerConfig(CommsConfig):
    """Validated contents of the ``'comms_logger'`` section of a DeepSpeed config.

    Inherits ``CommsConfig.Config``, so keys other than the fields below are
    rejected at construction time. Every default comes from a
    ``COMMS_LOGGER_*_DEFAULT`` constant imported from ``.constants``; the
    concrete default values are not visible in this module.
    """
    # Presumably the master on/off switch for the comms logger — TODO confirm.
    enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
    # Presumably "profile all communication ops" — TODO confirm against the logger.
    prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
    # Presumably the explicit list of ops to profile. Annotated as a bare
    # ``list``, so element types are not constrained by this annotation.
    # NOTE(review): the default is a module-level list; pydantic v1 copies field
    # defaults per instance, so it should not be shared — confirm for the
    # pinned pydantic version.
    prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
    # Presumably controls how much the logger prints — TODO confirm.
    verbose: bool = COMMS_LOGGER_VERBOSE_DEFAULT
    # Presumably enables extra debug information in the logger — TODO confirm.
    debug: bool = COMMS_LOGGER_DEBUG_DEFAULT
  18. class DeepSpeedCommsConfig:
  19. def __init__(self, ds_config):
  20. self.comms_logger_enabled = 'comms_logger' in ds_config
  21. if self.comms_logger_enabled:
  22. self.comms_logger = CommsLoggerConfig(**ds_config['comms_logger'])