123456789101112131415161718192021222324252627282930313233 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- from pydantic import BaseModel
- from .constants import *
class CommsConfig(BaseModel):
    """Shared pydantic base for the communication-logging config models.

    Subclasses inherit the strict validation settings declared below.
    """

    # NOTE(review): this is the pydantic v1-style inner ``Config`` class.
    # In pydantic v2 ``class Config`` is deprecated in favour of
    # ``model_config = ConfigDict(...)`` and ``validate_all`` was renamed
    # ``validate_default``. Confirm which pydantic major version the project
    # pins before migrating.
    class Config:
        # Run validation on field defaults too, not only on supplied values.
        validate_all = True
        # Re-validate a field whenever it is assigned after construction.
        validate_assignment = True
        # Store the ``.value`` of Enum-typed fields rather than the Enum member.
        use_enum_values = True
        # Reject keys that are not declared fields (e.g. typos in user config).
        extra = 'forbid'
class CommsLoggerConfig(CommsConfig):
    """Validated settings for the ``comms_logger`` section of a DeepSpeed config.

    Inherits the strict ``CommsConfig`` settings, so unknown keys are rejected
    and values are validated on assignment. Every default comes from the
    matching ``COMMS_LOGGER_*_DEFAULT`` constant imported from ``.constants``.
    """

    # Whether the comms logger is turned on; interpreted by whatever consumes
    # this model (not shown here).
    enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
    # Presumably: profile all communication ops — TODO confirm against the logger.
    prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
    # Presumably the specific ops to profile; the bare ``list`` annotation does
    # not constrain the element type.
    prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
    # Presumably enables more detailed log output — verify against the logger.
    verbose: bool = COMMS_LOGGER_VERBOSE_DEFAULT
    # Presumably enables extra debug information — verify against the logger.
    debug: bool = COMMS_LOGGER_DEBUG_DEFAULT
class DeepSpeedCommsConfig:
    """Extract the communication-logger settings from a DeepSpeed config dict.

    Attributes:
        comms_logger_enabled: True when ``ds_config`` contains a
            ``'comms_logger'`` key. This reflects only the presence of the
            section, not the section's own ``enabled`` field.
        comms_logger: The ``CommsLoggerConfig`` built from that section, or
            ``None`` when the section is absent.
    """

    def __init__(self, ds_config):
        """Parse the optional ``'comms_logger'`` section of ``ds_config``.

        Args:
            ds_config: Mapping holding the user's DeepSpeed configuration.

        Raises:
            pydantic.ValidationError: If the section contains unknown keys or
                invalid values (``CommsLoggerConfig`` forbids extra fields).
            TypeError: If the section is not a mapping and cannot be unpacked
                into keyword arguments.
        """
        self.comms_logger_enabled = 'comms_logger' in ds_config
        # Always define the attribute so callers can check it for None instead
        # of hitting AttributeError when the section is missing.
        self.comms_logger = None
        if self.comms_logger_enabled:
            self.comms_logger = CommsLoggerConfig(**ds_config['comms_logger'])
|