123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- '''Copyright The Microsoft DeepSpeed Team'''
- """
- Copyright (c) Microsoft Corporation
- Licensed under the MIT license.
- """
- from pydantic import root_validator
- from deepspeed.runtime.config_utils import DeepSpeedConfigModel
- def get_monitor_config(param_dict):
- monitor_dict = {
- key: param_dict.get(key,
- {})
- for key in ("tensorboard",
- "wandb",
- "csv_monitor")
- }
- return DeepSpeedMonitorConfig(**monitor_dict)
- class TensorBoardConfig(DeepSpeedConfigModel):
- """Sets parameters for TensorBoard monitor."""
- enabled: bool = False
- """ Whether logging to Tensorboard is enabled. Requires `tensorboard` package is installed. """
- output_path: str = ""
- """
- Path to where the Tensorboard logs will be written. If not provided, the
- output path is set under the training script’s launching path.
- """
- job_name: str = "DeepSpeedJobName"
- """ Name for the current job. This will become a new directory inside `output_path`. """
- class WandbConfig(DeepSpeedConfigModel):
- """Sets parameters for WandB monitor."""
- enabled: bool = False
- """ Whether logging to WandB is enabled. Requires `wandb` package is installed. """
- group: str = None
- """ Name for the WandB group. This can be used to group together runs. """
- team: str = None
- """ Name for the WandB team. """
- project: str = "deepspeed"
- """ Name for the WandB project. """
- class CSVConfig(DeepSpeedConfigModel):
- """Sets parameters for CSV monitor."""
- enabled: bool = False
- """ Whether logging to local CSV files is enabled. """
- output_path: str = ""
- """
- Path to where the csv files will be written. If not provided, the output
- path is set under the training script’s launching path.
- """
- job_name: str = "DeepSpeedJobName"
- """ Name for the current job. This will become a new directory inside `output_path`. """
- class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
- """Sets parameters for various monitoring methods."""
- tensorboard: TensorBoardConfig = {}
- """ TensorBoard monitor, requires `tensorboard` package is installed. """
- wandb: WandbConfig = {}
- """ WandB monitor, requires `wandb` package is installed. """
- csv_monitor: CSVConfig = {}
- """ Local CSV output of monitoring data. """
- @root_validator
- def check_enabled(cls, values):
- values["enabled"] = False
- if (values.get("tensorboard").enabled or values.get("wandb").enabled
- or values.get("csv_monitor").enabled):
- values["enabled"] = True
- return values
|