wandb.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. from .utils import check_wandb_availability
  3. from .monitor import Monitor
  4. import deepspeed.comm as dist
  class WandbMonitor(Monitor):
      """Monitor backend that forwards events to Weights & Biases (``wandb``).

      A run is initialised, and data is logged, only when the monitor is
      enabled and ``dist.get_rank()`` returns 0; otherwise ``log`` and
      ``write_events`` do nothing.
      """

      def __init__(self, wandb_config):
          """Store the W&B settings and start a run on rank 0 when enabled.

          Args:
              wandb_config: Configuration object exposing the ``enabled``,
                  ``group``, ``team`` and ``project`` attributes read below.
          """
          super().__init__(wandb_config)
          check_wandb_availability()
          # Imported here, after the availability check, rather than at module
          # level.
          import wandb

          self.enabled = wandb_config.enabled
          self.group = wandb_config.group
          self.team = wandb_config.team
          self.project = wandb_config.project

          if self.enabled and dist.get_rank() == 0:
              # The configured ``team`` is passed to wandb as the ``entity``.
              wandb.init(project=self.project,
                         group=self.group,
                         entity=self.team)

      def log(self, data, step=None, commit=None, sync=None):
          """Forward ``data`` to ``wandb.log`` when enabled and on rank 0.

          Args:
              data: Passed to ``wandb.log`` as its first argument.
              step: Forwarded unchanged as ``step``.
              commit: Forwarded unchanged as ``commit``.
              sync: Kept for backward compatibility. It is forwarded only when
                  it is not ``None``, so that wandb releases whose ``log`` no
                  longer accepts ``sync`` keep working with the default call.

          Returns:
              The return value of ``wandb.log``, or ``None`` when the monitor
              is disabled or the caller is not rank 0.
          """
          if not (self.enabled and dist.get_rank() == 0):
              return None

          import wandb

          log_kwargs = {"step": step, "commit": commit}
          if sync is not None:
              log_kwargs["sync"] = sync
          return wandb.log(data, **log_kwargs)

      def write_events(self, event_list):
          """Log every event in ``event_list`` as a single-key entry.

          Args:
              event_list: Iterable of indexable events where index 0 is the
                  label, index 1 the value and index 2 the step.
          """
          if not (self.enabled and dist.get_rank() == 0):
              return

          for event in event_list:
              label, value, step = event[0], event[1], event[2]
              self.log({label: value}, step=step)