123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- import importlib
- import torch
- from torch import optim
- import numpy as np
- from inspect import isfunction
- from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10):
    """Render each caption in ``xc`` as black text on a white RGB canvas.

    Args:
        wh: ``(width, height)`` of each rendered canvas in pixels (PIL order).
        xc: list of caption strings; one canvas is produced per caption.
        size: font size passed to ``ImageFont.truetype``.

    Returns:
        A float64 ``torch.Tensor`` of shape ``(len(xc), 3, height, width)``
        whose pixel values are rescaled from ``[0, 255]`` to ``[-1, 1]``.

    Raises:
        ValueError: if ``xc`` is empty (``np.stack`` needs at least one array).

    Note:
        The font is loaded from the relative path ``fonts/DejaVuSans.ttf``,
        so this presumably must run from a directory containing that file.
    """
    # The font is identical for every caption, so load/parse it only once.
    font = ImageFont.truetype('fonts/DejaVuSans.ttf', size=size)
    # Characters per wrapped line: 40 chars at 256 px width, scaled linearly.
    # Clamp to 1 so very narrow canvases don't give range() a zero step.
    nc = max(1, int(40 * (wh[0] / 256)))

    txts = []
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        # Hard-wrap the caption into fixed-width chunks, one per text line.
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: leave this canvas blank rather than abort logging.
            print("Can't encode string for logging. Skipping.")
        # HWC uint8 -> CHW float in [-1, 1].
        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    return torch.tensor(np.stack(txts))
def ismap(x):
    """Return True when ``x`` is a 4-D ``torch.Tensor`` whose dim 1 exceeds 3.

    Dim 1 is presumably the channel axis (NCHW), so this separates
    multi-channel maps from 1- or 3-channel images; anything that is not a
    tensor yields False.
    """
    is_tensor = isinstance(x, torch.Tensor)
    return is_tensor and x.dim() == 4 and x.shape[1] > 3
def isimage(x):
    """Return True when ``x`` is a 4-D ``torch.Tensor`` with 1 or 3 in dim 1.

    Dim 1 is presumably the channel axis (NCHW): 1 for grayscale, 3 for RGB.
    Non-tensors yield False.
    """
    if not isinstance(x, torch.Tensor):
        return False
    channels_ok = x.shape[1] in (1, 3) if x.dim() == 4 else False
    return channels_ok
def exists(x):
    """Return True for any value other than ``None`` (falsy values included)."""
    return not (x is None)
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise return the fallback ``d``.

    When ``d`` is a plain Python function (``inspect.isfunction``: true for
    ``def`` functions and lambdas, false for builtins, classes and partials),
    it is called with no arguments and its result is returned, so expensive
    defaults are only built when needed. Any other ``d`` is returned as-is.
    """
    if exists(val):
        return val
    if isfunction(d):
        return d()
    return d
def mean_flat(tensor):
    """Average ``tensor`` over every dimension except the first (batch) one.

    Adapted from
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    """
    reduce_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=reduce_dims)
def count_params(model, verbose=False):
    """Return the total number of elements across ``model.parameters()``.

    With ``verbose=True`` the count is also printed in millions, prefixed by
    the model's class name.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params
def instantiate_from_config(config):
    """Build the object described by ``config``.

    ``config["target"]`` is a dotted import path resolved via
    ``get_obj_from_str``; it is called with ``config["params"]`` as keyword
    arguments (no arguments if ``"params"`` is absent).

    The sentinel strings ``'__is_first_stage__'`` and ``"__is_unconditional__"``
    return ``None`` instead of building anything.

    Raises:
        KeyError: if ``config`` has no ``"target"`` and is not a sentinel.
    """
    if "target" not in config:
        if config in ('__is_first_stage__', "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    target = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return target(**kwargs)
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to that attribute.

    Everything before the last dot is imported as a module and the final
    component is looked up on it. With ``reload=True`` the module is reloaded
    (``importlib.reload``) before the lookup. A string without a dot raises
    ``ValueError`` from the tuple unpacking.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    module = importlib.import_module(module_name, package=None)
    return getattr(module, attr_name)
class AdamWwithEMAandWings(optim.Optimizer):
    """AdamW that also keeps an exponential moving average (EMA) of each parameter.

    Credit: https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298

    The EMA copy of parameter ``p`` is a float32 tensor stored at
    ``self.state[p]['param_exp_avg']``. After each optimizer step it is updated
    as ``ema = decay * ema + (1 - decay) * p`` with a warmed-up decay of
    ``min(ema_decay, 1 - step ** -ema_power)``. Because ``1 ** -x == 1``, the
    first step copies the parameter outright.

    The AdamW update itself is computed inline rather than through the
    private ``torch.optim._functional.adamw``. That private API changed in
    later PyTorch releases (it requires tensor step counts), which broke the
    previous call.
    """

    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,  # ema decay to match previous code
                 ema_power=1., param_names=()):
        """Validate hyperparameters and register them as param-group defaults.

        Args:
            params: iterable of parameters or param-group dicts.
            lr: learning rate (>= 0).
            betas: ``(beta1, beta2)`` moment coefficients, each in ``[0, 1)``.
            eps: term added to the denominator (>= 0).
            weight_decay: decoupled weight decay; each step scales the
                parameter by ``1 - lr * weight_decay`` (>= 0).
            amsgrad: use the running maximum of the second moment.
            ema_decay: upper bound on the EMA decay, in ``[0, 1]``.
            ema_power: exponent of the EMA warmup ``1 - step ** -ema_power``.
            param_names: stored in each param group; not read by this class.

        Raises:
            ValueError: if any validated hyperparameter is out of range.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
                        ema_power=ema_power, param_names=param_names)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore pickled state, back-filling ``amsgrad`` for older pickles."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single AdamW optimization step, then update the EMAs.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        import math  # only needed for the bias-correction square root

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            lr = group['lr']
            weight_decay = group['weight_decay']
            eps = group['eps']
            ema_decay = group['ema_decay']
            ema_power = group['ema_power']

            # Phase 1: gather state for every parameter that received a gradient.
            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of parameter values (kept in float32)
                    state['param_exp_avg'] = p.detach().float().clone()

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                ema_params_with_grad.append(state['param_exp_avg'])
                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

            # No gradients in this group: nothing to update. Without this guard
            # the EMA decay below would read an unbound or stale step count.
            if not params_with_grad:
                continue

            # Phase 2: AdamW update (decoupled weight decay + bias correction).
            for i, param in enumerate(params_with_grad):
                grad = grads[i]
                exp_avg = exp_avgs[i]
                exp_avg_sq = exp_avg_sqs[i]
                step = state_steps[i]

                # Perform stepweight decay
                param.mul_(1 - lr * weight_decay)

                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
                    # Use the max. for normalizing running avg. of gradient
                    second_moment = max_exp_avg_sqs[i]
                else:
                    second_moment = exp_avg_sq
                denom = (second_moment.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                step_size = lr / bias_correction1
                param.addcdiv_(exp_avg, denom, value=-step_size)

            # Phase 3: EMA of parameters. The decay warms up with the step count
            # of the last updated parameter in the group (as in the original).
            cur_ema_decay = min(ema_decay, 1 - state_steps[-1] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

        return loss
|