12345678910111213141516171819202122232425262728293031323334 |
- '''Copyright The Microsoft DeepSpeed Team'''
- from .utils import check_wandb_availability
- from .monitor import Monitor
- import deepspeed.comm as dist
class WandbMonitor(Monitor):
    """Monitor backend that reports events through the ``wandb`` package.

    Calls into ``wandb`` are made only when the monitor is enabled and
    ``dist.get_rank()`` is 0; under any other condition the reporting
    methods return without touching ``wandb``.
    """

    def __init__(self, wandb_config):
        """Store the W&B settings and, on rank 0, start a run.

        Args:
            wandb_config: Configuration object exposing ``enabled``,
                ``group``, ``team`` and ``project`` attributes. It is also
                handed to the ``Monitor`` base-class initializer.
        """
        super().__init__(wandb_config)
        # NOTE(review): presumably raises when wandb is not installed --
        # confirm against .utils. It runs regardless of `enabled`.
        check_wandb_availability()
        import wandb

        self.enabled = wandb_config.enabled
        self.group = wandb_config.group
        self.team = wandb_config.team
        self.project = wandb_config.project

        # Only an enabled monitor on rank 0 opens a run.
        if not self.enabled or dist.get_rank() != 0:
            return
        # The configured team is passed to wandb as the `entity`.
        wandb.init(project=self.project, group=self.group, entity=self.team)

    def log(self, data, step=None, commit=None, sync=None):
        """Forward ``data`` to ``wandb.log`` when enabled on rank 0.

        Args:
            data: Payload passed to ``wandb.log`` unchanged.
            step: Passed through as the ``step`` keyword.
            commit: Passed through as the ``commit`` keyword.
            sync: Passed through as the ``sync`` keyword.

        Returns:
            Whatever ``wandb.log`` returns, or ``None`` when the monitor is
            disabled or the caller is not rank 0.
        """
        if not self.enabled or dist.get_rank() != 0:
            return None
        import wandb

        return wandb.log(data, step=step, commit=commit, sync=sync)

    def write_events(self, event_list):
        """Log every event in ``event_list`` as a single-key record.

        Args:
            event_list: Iterable of indexable events; index 0 is used as
                the label, index 1 as the value and index 2 as the step.
        """
        if not self.enabled or dist.get_rank() != 0:
            return
        for record in event_list:
            name, scalar, global_step = record[0], record[1], record[2]
            self.log({name: scalar}, step=global_step)
|