config.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. """
  3. Copyright (c) Microsoft Corporation
  4. Licensed under the MIT license.
  5. """
from typing import Optional

from pydantic import root_validator

from deepspeed.runtime.config_utils import DeepSpeedConfigModel
  8. def get_monitor_config(param_dict):
  9. monitor_dict = {
  10. key: param_dict.get(key,
  11. {})
  12. for key in ("tensorboard",
  13. "wandb",
  14. "csv_monitor")
  15. }
  16. return DeepSpeedMonitorConfig(**monitor_dict)
  17. class TensorBoardConfig(DeepSpeedConfigModel):
  18. """Sets parameters for TensorBoard monitor."""
  19. enabled: bool = False
  20. """ Whether logging to Tensorboard is enabled. Requires `tensorboard` package is installed. """
  21. output_path: str = ""
  22. """
  23. Path to where the Tensorboard logs will be written. If not provided, the
  24. output path is set under the training script’s launching path.
  25. """
  26. job_name: str = "DeepSpeedJobName"
  27. """ Name for the current job. This will become a new directory inside `output_path`. """
  28. class WandbConfig(DeepSpeedConfigModel):
  29. """Sets parameters for WandB monitor."""
  30. enabled: bool = False
  31. """ Whether logging to WandB is enabled. Requires `wandb` package is installed. """
  32. group: str = None
  33. """ Name for the WandB group. This can be used to group together runs. """
  34. team: str = None
  35. """ Name for the WandB team. """
  36. project: str = "deepspeed"
  37. """ Name for the WandB project. """
  38. class CSVConfig(DeepSpeedConfigModel):
  39. """Sets parameters for CSV monitor."""
  40. enabled: bool = False
  41. """ Whether logging to local CSV files is enabled. """
  42. output_path: str = ""
  43. """
  44. Path to where the csv files will be written. If not provided, the output
  45. path is set under the training script’s launching path.
  46. """
  47. job_name: str = "DeepSpeedJobName"
  48. """ Name for the current job. This will become a new directory inside `output_path`. """
  49. class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
  50. """Sets parameters for various monitoring methods."""
  51. tensorboard: TensorBoardConfig = {}
  52. """ TensorBoard monitor, requires `tensorboard` package is installed. """
  53. wandb: WandbConfig = {}
  54. """ WandB monitor, requires `wandb` package is installed. """
  55. csv_monitor: CSVConfig = {}
  56. """ Local CSV output of monitoring data. """
  57. @root_validator
  58. def check_enabled(cls, values):
  59. values["enabled"] = False
  60. if (values.get("tensorboard").enabled or values.get("wandb").enabled
  61. or values.get("csv_monitor").enabled):
  62. values["enabled"] = True
  63. return values