config.py 879 B

1234567891011121314151617181920212223242526272829303132
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. """
  3. Copyright (c) Microsoft Corporation
  4. Licensed under the MIT license.
  5. """
  6. from pydantic import BaseModel
  7. from .constants import *
  8. class CommsConfig(BaseModel):
  9. class Config:
  10. validate_all = True
  11. validate_assignment = True
  12. use_enum_values = True
  13. extra = 'forbid'
class CommsLoggerConfig(CommsConfig):
    """Validated options for the communication logger.

    Built from the ``comms_logger`` section of a DeepSpeed config dict, with
    that section's keys passed as keyword arguments. Because it derives from
    ``CommsConfig`` (``extra = 'forbid'``), any key not declared below is
    rejected at construction time.

    All defaults come from the ``.constants`` module, which is not visible
    here. Confirm the actual default values there.
    """

    # NOTE(review): presumably the master switch for comms logging — confirm in the logger.
    enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
    # NOTE(review): presumably "profile every comm op" — confirm in the logger.
    prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
    # NOTE(review): presumably the op names to profile selectively — element type
    # is not constrained here (bare ``list``); verify what the logger expects.
    prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
    # NOTE(review): presumably controls per-op log output verbosity — confirm.
    verbose: bool = COMMS_LOGGER_VERBOSE_DEFAULT
    # NOTE(review): presumably enables extra debug output — confirm.
    debug: bool = COMMS_LOGGER_DEBUG_DEFAULT
  20. class DeepSpeedCommsConfig:
  21. def __init__(self, ds_config):
  22. self.comms_logger_enabled = 'comms_logger' in ds_config
  23. if self.comms_logger_enabled:
  24. self.comms_logger = CommsLoggerConfig(**ds_config['comms_logger'])