wandb.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from .utils import check_wandb_availability
  5. from .monitor import Monitor
  6. import deepspeed.comm as dist
  7. class WandbMonitor(Monitor):
  8. def __init__(self, wandb_config):
  9. super().__init__(wandb_config)
  10. check_wandb_availability()
  11. import wandb
  12. self.enabled = wandb_config.enabled
  13. self.group = wandb_config.group
  14. self.team = wandb_config.team
  15. self.project = wandb_config.project
  16. if self.enabled and dist.get_rank() == 0:
  17. wandb.init(project=self.project, group=self.group, entity=self.team)
  18. def log(self, data, step=None, commit=None, sync=None):
  19. if self.enabled and dist.get_rank() == 0:
  20. import wandb
  21. return wandb.log(data, step=step, commit=commit, sync=sync)
  22. def write_events(self, event_list):
  23. if self.enabled and dist.get_rank() == 0:
  24. for event in event_list:
  25. label = event[0]
  26. value = event[1]
  27. step = event[2]
  28. self.log({label: value}, step=step)