config.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
from typing import Optional

from pydantic import root_validator

from deepspeed.runtime.config_utils import DeepSpeedConfigModel
  6. def get_monitor_config(param_dict):
  7. monitor_dict = {key: param_dict.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor")}
  8. return DeepSpeedMonitorConfig(**monitor_dict)
  9. class TensorBoardConfig(DeepSpeedConfigModel):
  10. """Sets parameters for TensorBoard monitor."""
  11. enabled: bool = False
  12. """ Whether logging to Tensorboard is enabled. Requires `tensorboard` package is installed. """
  13. output_path: str = ""
  14. """
  15. Path to where the Tensorboard logs will be written. If not provided, the
  16. output path is set under the training script’s launching path.
  17. """
  18. job_name: str = "DeepSpeedJobName"
  19. """ Name for the current job. This will become a new directory inside `output_path`. """
  20. class WandbConfig(DeepSpeedConfigModel):
  21. """Sets parameters for WandB monitor."""
  22. enabled: bool = False
  23. """ Whether logging to WandB is enabled. Requires `wandb` package is installed. """
  24. group: str = None
  25. """ Name for the WandB group. This can be used to group together runs. """
  26. team: str = None
  27. """ Name for the WandB team. """
  28. project: str = "deepspeed"
  29. """ Name for the WandB project. """
  30. class CSVConfig(DeepSpeedConfigModel):
  31. """Sets parameters for CSV monitor."""
  32. enabled: bool = False
  33. """ Whether logging to local CSV files is enabled. """
  34. output_path: str = ""
  35. """
  36. Path to where the csv files will be written. If not provided, the output
  37. path is set under the training script’s launching path.
  38. """
  39. job_name: str = "DeepSpeedJobName"
  40. """ Name for the current job. This will become a new directory inside `output_path`. """
  41. class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
  42. """Sets parameters for various monitoring methods."""
  43. tensorboard: TensorBoardConfig = {}
  44. """ TensorBoard monitor, requires `tensorboard` package is installed. """
  45. wandb: WandbConfig = {}
  46. """ WandB monitor, requires `wandb` package is installed. """
  47. csv_monitor: CSVConfig = {}
  48. """ Local CSV output of monitoring data. """
  49. @root_validator
  50. def check_enabled(cls, values):
  51. values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get(
  52. "csv_monitor").enabled
  53. return values