tensorboard.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from .utils import check_tb_availability
  5. from .monitor import Monitor
  6. import os
  7. import deepspeed.comm as dist
  8. class TensorBoardMonitor(Monitor):
  9. def __init__(self, tensorboard_config):
  10. super().__init__(tensorboard_config)
  11. check_tb_availability()
  12. self.summary_writer = None
  13. self.enabled = tensorboard_config.enabled
  14. self.output_path = tensorboard_config.output_path
  15. self.job_name = tensorboard_config.job_name
  16. if self.enabled and dist.get_rank() == 0:
  17. self.get_summary_writer()
  18. def get_summary_writer(self, base=os.path.join(os.path.expanduser("~"), "tensorboard")):
  19. if self.enabled and dist.get_rank() == 0:
  20. from torch.utils.tensorboard import SummaryWriter
  21. if self.output_path is not None:
  22. log_dir = os.path.join(self.output_path, self.job_name)
  23. # NOTE: This code path currently is never used since the default output_path is an empty string and not None. Saving it in case we want this functionality in the future.
  24. else:
  25. if "DLWS_JOB_ID" in os.environ:
  26. infra_job_id = os.environ["DLWS_JOB_ID"]
  27. elif "DLTS_JOB_ID" in os.environ:
  28. infra_job_id = os.environ["DLTS_JOB_ID"]
  29. else:
  30. infra_job_id = "unknown-job-id"
  31. summary_writer_dir_name = os.path.join(infra_job_id, "logs")
  32. log_dir = os.path.join(base, summary_writer_dir_name, self.output_path)
  33. os.makedirs(log_dir, exist_ok=True)
  34. self.summary_writer = SummaryWriter(log_dir=log_dir)
  35. return self.summary_writer
  36. def write_events(self, event_list, flush=True):
  37. if self.enabled and self.summary_writer is not None and dist.get_rank() == 0:
  38. for event in event_list:
  39. self.summary_writer.add_scalar(*event)
  40. if flush:
  41. self.summary_writer.flush()
  42. def flush(self):
  43. if self.enabled and self.summary_writer is not None and dist.get_rank() == 0:
  44. self.summary_writer.flush()