base_task.py

import logging
import os
import time
import random
import subprocess
import sys
from datetime import datetime

import numpy as np
import torch.utils.data
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from utils.commons.dataset_utils import data_loader
from utils.commons.hparams import hparams
from utils.commons.meters import AvgrageMeter
from utils.commons.tensor_utils import tensors_to_scalars
from utils.commons.trainer import Trainer
from utils.nn.model_utils import print_arch, num_params

torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')


class BaseTask(nn.Module):
    def __init__(self, *args, **kwargs):
        super(BaseTask, self).__init__()
        self.current_epoch = 0
        self.global_step = 0
        self.trainer = None
        self.use_ddp = False
        self.gradient_clip_norm = hparams['clip_grad_norm']
        self.gradient_clip_val = hparams.get('clip_grad_value', 0)
        self.model = None
        self.epoch_training_losses_meter = None
        self.logger: SummaryWriter = None

    ######################
    # build model, dataloaders, optimizer, scheduler and tensorboard
    ######################
    def build_model(self):
        raise NotImplementedError

    @data_loader
    def train_dataloader(self):
        raise NotImplementedError

    @data_loader
    def test_dataloader(self):
        raise NotImplementedError

    @data_loader
    def val_dataloader(self):
        raise NotImplementedError

    def build_scheduler(self, optimizer):
        return None

    def build_optimizer(self, model):
        raise NotImplementedError

    def configure_optimizers(self):
        optm = self.build_optimizer(self.model)
        self.scheduler = self.build_scheduler(optm)
        if isinstance(optm, (list, tuple)):
            return optm
        return [optm]

    def build_tensorboard(self, save_dir, name, **kwargs):
        log_dir = os.path.join(save_dir, name)
        os.makedirs(log_dir, exist_ok=True)
        self.logger = SummaryWriter(log_dir=log_dir, **kwargs)
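
    # Illustrative sketch (not part of the original file): a typical override of
    # build_optimizer/build_scheduler in a concrete subclass. The hparams keys
    # `lr` and `lr_decay_steps` below are hypothetical placeholders.
    #
    #     def build_optimizer(self, model):
    #         return torch.optim.AdamW(model.parameters(), lr=hparams['lr'])
    #
    #     def build_scheduler(self, optimizer):
    #         return torch.optim.lr_scheduler.StepLR(
    #             optimizer, step_size=hparams['lr_decay_steps'], gamma=0.5)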

    ######################
    # training
    ######################
    def on_train_start(self):
        for n, m in self.model.named_children():
            num_params(m, model_name=n)
        if torch.__version__.split(".")[0] == '2' and hparams.get("torch_compile", False):
            self.model = torch.compile(self.model, mode='default')

    def on_train_end(self):
        pass

    def on_epoch_start(self):
        self.epoch_training_losses_meter = {'total_loss': AvgrageMeter()}

    def on_epoch_end(self):
        loss_outputs = {k: v.avg for k, v in self.epoch_training_losses_meter.items()}
        print(f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}")
        loss_outputs = {"epoch_mean/" + k: v for k, v in loss_outputs.items()}
        return loss_outputs

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """
        :param sample:
        :param batch_idx:
        :param optimizer_idx:
        :return: total_loss: torch.Tensor, loss_log: dict
        """
        raise NotImplementedError

    def training_step(self, sample, batch_idx, optimizer_idx=-1):
        """
        :param sample:
        :param batch_idx:
        :param optimizer_idx:
        :return: {'loss': torch.Tensor, 'progress_bar': dict, 'tb_log': dict}
        """
        # perform the main training step of the specific task
        loss_ret = self._training_step(sample, batch_idx, optimizer_idx)
        if loss_ret is None:
            return {'loss': None}
        total_loss, log_outputs = loss_ret
        log_outputs = tensors_to_scalars(log_outputs)
        # add to epoch meter
        for k, v in log_outputs.items():
            if '/' in k:
                k_split = k.split("/")
                assert len(k_split) == 2, "we only support one `/` in tag_name, i.e., `<tag>/<sub_tag>`"
                k = k.replace("/", "_")
            if k not in self.epoch_training_losses_meter:
                self.epoch_training_losses_meter[k] = AvgrageMeter()
            if not np.isnan(v):
                self.epoch_training_losses_meter[k].update(v)
        # log the learning rate of each param group
        if optimizer_idx >= 0:
            for params_group_i in range(len(self.trainer.optimizers[optimizer_idx].param_groups)):
                log_outputs[f'lr/optimizer{optimizer_idx}_params_group{params_group_i}'] = \
                    self.trainer.optimizers[optimizer_idx].param_groups[params_group_i]['lr']
        # add to progress bar
        progress_bar_log = {}
        for k, v in log_outputs.items():
            if '/' in k:
                k_split = k.split("/")
                assert len(k_split) == 2, "we only support one `/` in tag_name, i.e., `<tag>/<sub_tag>`"
                k = k.replace("/", "_")
            assert k not in progress_bar_log, f"we got duplicate tags in log_outputs, check this `{k}`"
            progress_bar_log[k] = v
        # add to tensorboard log
        tb_log = {}
        for k, v in log_outputs.items():
            if '/' in k:
                tb_log[k] = v
            else:
                tb_log[f'tr/{k}'] = v
        if not isinstance(total_loss, torch.Tensor):
            return {'loss': None}
        self.epoch_training_losses_meter['total_loss'].update(total_loss.item())
        return {
            'loss': total_loss,
            'progress_bar': progress_bar_log,
            'tb_log': tb_log
        }

    def on_before_optimization(self, opt_idx):
        if self.gradient_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.parameters(), self.gradient_clip_norm)
        if self.gradient_clip_val > 0:
            torch.nn.utils.clip_grad_value_(self.parameters(), self.gradient_clip_val)

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        if self.scheduler is not None:
            # step the scheduler once per effective (gradient-accumulated) update
            self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])

    ######################
    # validation
    ######################
    def validation_start(self):
        pass

    def validation_step(self, sample, batch_idx):
        """
        :param sample:
        :param batch_idx:
        :return: output: {"losses": {...}, "total_loss": float, ...} or (total_loss: torch.Tensor, loss_log: dict)
        """
        raise NotImplementedError

    def validation_end(self, outputs):
        """
        :param outputs:
        :return: loss_output: dict
        """
        all_losses_meter = {'total_loss': AvgrageMeter()}
        for output in outputs:
            if output is None or len(output) == 0:
                continue
            if isinstance(output, dict):
                assert 'losses' in output, 'Key "losses" should exist in validation output.'
                n = output.pop('nsamples', 1)
                losses = tensors_to_scalars(output['losses'])
                total_loss = output.get('total_loss', sum(losses.values()))
            else:
                assert len(output) == 2, 'Validation output should only consist of two elements: (total_loss, losses)'
                n = 1
                total_loss, losses = output
                losses = tensors_to_scalars(losses)
            if isinstance(total_loss, torch.Tensor):
                total_loss = total_loss.item()
            for k, v in losses.items():
                if k not in all_losses_meter:
                    all_losses_meter[k] = AvgrageMeter()
                all_losses_meter[k].update(v, n)
            all_losses_meter['total_loss'].update(total_loss, n)
        loss_output = {k: round(v.avg, 10) for k, v in all_losses_meter.items()}
        print(f"| Validation results@{self.global_step}: {loss_output}")
        return {
            'tb_log': {f'val/{k}': v for k, v in loss_output.items()},
            'val_loss': loss_output['total_loss']
        }
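
    # Illustrative note (not from the original code): validation_step may return
    # either the dict form, e.g.
    #     {'losses': {'mel': tensor(0.1)}, 'total_loss': 0.1, 'nsamples': 8}
    # or the tuple form, e.g.
    #     (total_loss_tensor, {'mel': tensor(0.1)})
    # both of which validation_end() above reduces into per-key running averages.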

    ######################
    # testing
    ######################
    def test_start(self):
        pass

    def test_step(self, sample, batch_idx):
        return self.validation_step(sample, batch_idx)

    def test_end(self, outputs):
        return self.validation_end(outputs)

    ######################
    # start training/testing
    ######################
    @classmethod
    def start(cls):
        def is_port_in_use(port: int) -> bool:
            import socket
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                return s.connect_ex(('localhost', port)) == 0

        # pick a free port for (potential) DDP communication
        os.environ['MASTER_PORT'] = str(random.randint(10000, 11000))
        while is_port_in_use(int(os.environ['MASTER_PORT'])):
            print(f"| Port {os.environ['MASTER_PORT']} is in use. Trying another port...")
            os.environ['MASTER_PORT'] = str(random.randint(10000, 11000))
            time.sleep(1)
        random.seed(hparams['seed'])
        np.random.seed(hparams['seed'])
        work_dir = hparams['work_dir']
        trainer = Trainer(
            work_dir=work_dir,
            val_check_interval=hparams['val_check_interval'],
            tb_log_interval=hparams['tb_log_interval'],
            max_updates=hparams['max_updates'],
            num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams['validate'] else 10000,
            accumulate_grad_batches=hparams['accumulate_grad_batches'],
            print_nan_grads=hparams['print_nan_grads'],
            resume_from_checkpoint=hparams.get('resume_from_checkpoint', 0),
            amp=hparams['amp'],
            monitor_key=hparams['valid_monitor_key'],
            monitor_mode=hparams['valid_monitor_mode'],
            num_ckpt_keep=hparams['num_ckpt_keep'],
            save_best=hparams['save_best'],
            seed=hparams['seed'],
            debug=hparams['debug']
        )
        if not hparams['infer']:  # train
            trainer.fit(cls)
        else:  # infer/test
            trainer.test(cls)

    def on_keyboard_interrupt(self):
        pass
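
# ----------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): a minimal
# concrete task. `MyModel`, `MyDataset`, and the hparams keys `lr` and
# `max_sentences` are hypothetical placeholders for whatever the actual
# task defines.
#
#     class MyTask(BaseTask):
#         def build_model(self):
#             self.model = MyModel(hparams)
#             return self.model
#
#         def build_optimizer(self, model):
#             return torch.optim.Adam(model.parameters(), lr=hparams['lr'])
#
#         @data_loader
#         def train_dataloader(self):
#             return torch.utils.data.DataLoader(
#                 MyDataset('train'), batch_size=hparams['max_sentences'])
#
#         def _training_step(self, sample, batch_idx, optimizer_idx):
#             total_loss = self.model(sample)  # assumed to return a scalar loss
#             return total_loss, {'loss': total_loss}
#
#     if __name__ == '__main__':
#         MyTask.start()
# ----------------------------------------------------------------------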