123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277 |
- import logging
- import os
- import time
- import random
- import subprocess
- import sys
- from datetime import datetime
- import numpy as np
- import torch.utils.data
- from torch import nn
- from torch.utils.tensorboard import SummaryWriter
- from utils.commons.dataset_utils import data_loader
- from utils.commons.hparams import hparams
- from utils.commons.meters import AvgrageMeter
- from utils.commons.tensor_utils import tensors_to_scalars
- from utils.commons.trainer import Trainer
- from utils.nn.model_utils import print_arch, num_params
# Tensor-sharing strategy for DataLoader workers; 'file_system' (the default
# here) avoids "too many open files" errors from file-descriptor sharing.
torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))
# Plain timestamped logging to stdout for the whole training process.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
class BaseTask(nn.Module):
    """Skeleton of a trainable task driven by `Trainer`.

    A task owns the model and knows how to build its dataloaders,
    optimizer(s) and scheduler, and implements the train/val/test hooks the
    trainer calls. Subclasses must implement `build_model`,
    `build_optimizer`, `_training_step`, `validation_step` and the three
    dataloader factories.
    """

    def __init__(self, *args, **kwargs):
        super(BaseTask, self).__init__()
        self.current_epoch = 0
        self.global_step = 0
        self.trainer = None
        self.use_ddp = False
        self.gradient_clip_norm = hparams['clip_grad_norm']
        self.gradient_clip_val = hparams.get('clip_grad_value', 0)
        self.model = None
        self.epoch_training_losses_meter = None
        # Fix: initialize here so on_after_optimization() cannot raise
        # AttributeError if it is ever reached before configure_optimizers().
        self.scheduler = None
        self.logger: SummaryWriter = None

    ######################
    # build model, dataloaders, optimizer, scheduler and tensorboard
    ######################
    def build_model(self):
        raise NotImplementedError

    @data_loader
    def train_dataloader(self):
        raise NotImplementedError

    @data_loader
    def test_dataloader(self):
        raise NotImplementedError

    @data_loader
    def val_dataloader(self):
        raise NotImplementedError

    def build_scheduler(self, optimizer):
        # Optional hook: subclasses may return an LR scheduler for `optimizer`.
        return None

    def build_optimizer(self, model):
        raise NotImplementedError

    def configure_optimizers(self):
        """Build the optimizer(s) and scheduler.

        :return: list of optimizers (a single optimizer is wrapped in a list)
        """
        optm = self.build_optimizer(self.model)
        self.scheduler = self.build_scheduler(optm)
        if isinstance(optm, (list, tuple)):
            return optm
        return [optm]

    def build_tensorboard(self, save_dir, name, **kwargs):
        """Create a SummaryWriter logging to `<save_dir>/<name>`."""
        log_dir = os.path.join(save_dir, name)
        os.makedirs(log_dir, exist_ok=True)
        self.logger = SummaryWriter(log_dir=log_dir, **kwargs)

    ######################
    # training
    ######################
    def on_train_start(self):
        # Report parameter counts of each top-level submodule.
        for n, m in self.model.named_children():
            num_params(m, model_name=n)
        # Optionally wrap the model with torch.compile on PyTorch 2.x.
        if torch.__version__.split(".")[0] == '2' and hparams.get("torch_compile", False):
            self.model = torch.compile(self.model, mode='default')

    def on_train_end(self):
        pass

    def on_epoch_start(self):
        # Reset per-epoch running averages of the training losses.
        self.epoch_training_losses_meter = {'total_loss': AvgrageMeter()}

    def on_epoch_end(self):
        """Report and return the per-epoch mean losses under `epoch_mean/`."""
        loss_outputs = {k: v.avg for k, v in self.epoch_training_losses_meter.items()}
        print(f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}")
        loss_outputs = {"epoch_mean/" + k: v for k, v in loss_outputs.items()}
        return loss_outputs

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Task-specific training step.

        :param sample: one training batch
        :param batch_idx: index of the batch within the epoch
        :param optimizer_idx: index of the active optimizer
        :return: (total loss: torch.Tensor, loss_log: dict), or None to skip
        """
        raise NotImplementedError

    def training_step(self, sample, batch_idx, optimizer_idx=-1):
        """Run one training step and assemble the logging payloads.

        :param sample: one training batch
        :param batch_idx: index of the batch within the epoch
        :param optimizer_idx: index of the active optimizer (-1: single optimizer)
        :return: {'loss': torch.Tensor, 'progress_bar': dict, 'tb_log': dict}
        """
        # perform the main training step in a specific task
        loss_ret = self._training_step(sample, batch_idx, optimizer_idx)
        if loss_ret is None:
            return {'loss': None}
        total_loss, log_outputs = loss_ret
        log_outputs = tensors_to_scalars(log_outputs)
        # add to epoch meter; `<tag>/<sub_tag>` names are flattened to
        # `<tag>_<sub_tag>`, NaNs are not accumulated
        for k, v in log_outputs.items():
            if '/' in k:
                k_split = k.split("/")
                assert len(k_split) == 2, "we only support one `/` in tag_name, i.e., `<tag>/<sub_tag>`"
                k = k.replace("/", "_")
            if k not in self.epoch_training_losses_meter:
                self.epoch_training_losses_meter[k] = AvgrageMeter()
            if not np.isnan(v):
                self.epoch_training_losses_meter[k].update(v)

        # expose the active optimizer's learning rate(s) in the logs
        if optimizer_idx >= 0:
            for params_group_i in range(len(self.trainer.optimizers[optimizer_idx].param_groups)):
                log_outputs[f'lr/optimizer{optimizer_idx}_params_group{params_group_i}'] = self.trainer.optimizers[optimizer_idx].param_groups[params_group_i]['lr']
        # add to progress bar (tags flattened the same way as the epoch meter)
        progress_bar_log = {}
        for k, v in log_outputs.items():
            if '/' in k:
                k_split = k.split("/")
                assert len(k_split) == 2, "we only support one `/` in tag_name, i.e., `<tag>/<sub_tag>`"
                k = k.replace("/", "_")
            assert k not in progress_bar_log, f"we got duplicate tags in log_outputs, check this `{k}`"
            progress_bar_log[k] = v
        # add to tensorboard log; untagged keys are grouped under `tr/`
        # (comment fixed: this builds tb_log, not the progress bar)
        tb_log = {}
        for k, v in log_outputs.items():
            if '/' in k:
                tb_log[k] = v
            else:
                tb_log[f'tr/{k}'] = v
        if not isinstance(total_loss, torch.Tensor):
            return {'loss': None}
        self.epoch_training_losses_meter['total_loss'].update(total_loss.item())
        return {
            'loss': total_loss,
            'progress_bar': progress_bar_log,
            'tb_log': tb_log
        }

    def on_before_optimization(self, opt_idx):
        # Gradient clipping: by global norm and/or by absolute value,
        # controlled via hparams (0 disables either).
        if self.gradient_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.parameters(), self.gradient_clip_norm)
        if self.gradient_clip_val > 0:
            torch.nn.utils.clip_grad_value_(self.parameters(), self.gradient_clip_val)

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        # Step the scheduler by effective update count (grad-accumulation aware).
        if self.scheduler is not None:
            self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])

    ######################
    # validation
    ######################
    def validation_start(self):
        pass

    def validation_step(self, sample, batch_idx):
        """
        :param sample: one validation batch
        :param batch_idx: index of the batch
        :return: output: {"losses": {...}, "total_loss": float, ...} or (total loss: torch.Tensor, loss_log: dict)
        """
        raise NotImplementedError

    def validation_end(self, outputs):
        """Aggregate per-batch validation outputs into mean losses.

        :param outputs: list of per-batch outputs from `validation_step`
        :return: {'tb_log': dict with `val/` keys, 'val_loss': float}
        """
        all_losses_meter = {'total_loss': AvgrageMeter()}
        for output in outputs:
            if output is None or len(output) == 0:
                continue
            if isinstance(output, dict):
                assert 'losses' in output, 'Key "losses" should exist in validation output.'
                n = output.pop('nsamples', 1)
                losses = tensors_to_scalars(output['losses'])
                total_loss = output.get('total_loss', sum(losses.values()))
            else:
                assert len(output) == 2, 'Validation output should only consist of two elements: (total_loss, losses)'
                n = 1
                total_loss, losses = output
                losses = tensors_to_scalars(losses)
            if isinstance(total_loss, torch.Tensor):
                total_loss = total_loss.item()
            for k, v in losses.items():
                if k not in all_losses_meter:
                    all_losses_meter[k] = AvgrageMeter()
                all_losses_meter[k].update(v, n)
            all_losses_meter['total_loss'].update(total_loss, n)
        loss_output = {k: round(v.avg, 10) for k, v in all_losses_meter.items()}
        print(f"| Validation results@{self.global_step}: {loss_output}")
        return {
            'tb_log': {f'val/{k}': v for k, v in loss_output.items()},
            'val_loss': loss_output['total_loss']
        }

    ######################
    # testing
    ######################
    def test_start(self):
        pass

    def test_step(self, sample, batch_idx):
        # By default, testing reuses the validation logic.
        return self.validation_step(sample, batch_idx)

    def test_end(self, outputs):
        return self.validation_end(outputs)

    ######################
    # start training/testing
    ######################
    @classmethod
    def start(cls):
        """Entry point: seed RNGs, pick a free DDP master port and run the
        trainer in train or test mode depending on `hparams['infer']`."""
        def is_port_in_use(port: int) -> bool:
            import socket
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                return s.connect_ex(('localhost', port)) == 0

        # Pick a random MASTER_PORT and retry until it is actually free.
        os.environ['MASTER_PORT'] = str(random.randint(10000, 11000))
        while is_port_in_use(int(os.environ['MASTER_PORT'])):
            print(f"| Port {os.environ['MASTER_PORT']} is in use. Change another port...")
            os.environ['MASTER_PORT'] = str(random.randint(10000, 11000))
            time.sleep(1)
        random.seed(hparams['seed'])
        np.random.seed(hparams['seed'])
        work_dir = hparams['work_dir']
        trainer = Trainer(
            work_dir=work_dir,
            val_check_interval=hparams['val_check_interval'],
            tb_log_interval=hparams['tb_log_interval'],
            max_updates=hparams['max_updates'],
            num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams['validate'] else 10000,
            accumulate_grad_batches=hparams['accumulate_grad_batches'],
            print_nan_grads=hparams['print_nan_grads'],
            resume_from_checkpoint=hparams.get('resume_from_checkpoint', 0),
            amp=hparams['amp'],
            monitor_key=hparams['valid_monitor_key'],
            monitor_mode=hparams['valid_monitor_mode'],
            num_ckpt_keep=hparams['num_ckpt_keep'],
            save_best=hparams['save_best'],
            seed=hparams['seed'],
            debug=hparams['debug']
        )
        if not hparams['infer']:  # train
            trainer.fit(cls)
        else:
            trainer.test(cls)

    def on_keyboard_interrupt(self):
        pass
|