1234567891011121314151617181920212223242526272829303132 |
- '''Copyright The Microsoft DeepSpeed Team'''
- """
- Copyright (c) Microsoft Corporation
- Licensed under the MIT license.
- """
- from pydantic import BaseModel
- from .constants import *
class CommsConfig(BaseModel):
    """Shared pydantic base class for DeepSpeed communication config models.

    Subclasses inherit the strict validation settings declared in ``Config``.
    """

    class Config:
        # pydantic v1 model options:
        # reject any key that is not a declared field on the model
        extra = 'forbid'
        # store enum fields by their underlying value instead of the enum member
        use_enum_values = True
        # re-run validation when an attribute is assigned after construction
        validate_assignment = True
        # validate default values as well as user-supplied ones
        validate_all = True
class CommsLoggerConfig(CommsConfig):
    """Schema for the ``'comms_logger'`` section of a DeepSpeed config.

    Every field has a default taken from the ``COMMS_LOGGER_*_DEFAULT``
    constants star-imported from ``.constants``. The base class sets
    ``extra = 'forbid'``, so unknown keys in the section are rejected.
    """
    # Field order and annotations define the pydantic schema; keep them stable.
    enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
    # presumably: profile every communication op — confirm against the logger implementation
    prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
    # presumably: names of specific ops to profile when not profiling all — confirm
    prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
    verbose: bool = COMMS_LOGGER_VERBOSE_DEFAULT
    debug: bool = COMMS_LOGGER_DEBUG_DEFAULT
class DeepSpeedCommsConfig:
    """Communication-related settings extracted from a DeepSpeed config mapping.

    Attributes:
        comms_logger_enabled: True when ``ds_config`` contains a
            ``'comms_logger'`` key. This reflects only the presence of the
            section, not the value of its ``enabled`` field.
        comms_logger: The ``CommsLoggerConfig`` built from that section, or
            ``None`` when the section is absent.
    """

    def __init__(self, ds_config):
        """Parse the optional ``'comms_logger'`` section of ``ds_config``.

        Args:
            ds_config: Mapping holding the DeepSpeed configuration
                (presumably the parsed JSON config dict — confirm with caller).

        Raises:
            pydantic's validation error if the ``'comms_logger'`` section
            contains invalid values or keys unknown to ``CommsLoggerConfig``.
        """
        self.comms_logger_enabled = 'comms_logger' in ds_config
        # Always define the attribute so code that reads it without first
        # checking ``comms_logger_enabled`` sees None, not an AttributeError.
        self.comms_logger = None
        if self.comms_logger_enabled:
            self.comms_logger = CommsLoggerConfig(**ds_config['comms_logger'])
|