schedulers.py

import numpy as np

from utils.commons.hparams import hparams


class NoneSchedule(object):
    def __init__(self, optimizer, lr):
        self.optimizer = optimizer
        self.constant_lr = lr
        self.step(0)

    def step(self, num_updates):
        self.lr = self.constant_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr

    def get_lr(self):
        return self.optimizer.param_groups[0]['lr']

    def get_last_lr(self):
        return self.get_lr()


class RSQRTSchedule(NoneSchedule):
    def __init__(self, optimizer, lr, warmup_updates, hidden_size):
        self.optimizer = optimizer
        self.constant_lr = lr
        self.warmup_updates = warmup_updates
        self.hidden_size = hidden_size
        self.lr = lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.lr
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        warmup = min(num_updates / self.warmup_updates, 1.0)
        rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5
        rsqrt_hidden = self.hidden_size ** -0.5
        self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr
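
# NOTE: RSQRTSchedule follows the transformer-style inverse-square-root ("Noam") warmup.
# In closed form,
#   lr(t) = constant_lr * hidden_size**-0.5 * min(t / warmup_updates, 1.0) * max(t, warmup_updates)**-0.5
# e.g. with constant_lr=1.0, hidden_size=256, warmup_updates=4000, the peak lr at t=4000 is
#   1.0 * 256**-0.5 * 4000**-0.5 ≈ 9.9e-4, after which lr decays proportionally to t**-0.5.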


class WarmupSchedule(NoneSchedule):
    def __init__(self, optimizer, lr, warmup_updates):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr
        self.warmup_updates = warmup_updates
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.lr
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        warmup = min(num_updates / self.warmup_updates, 1.0)
        self.lr = max(constant_lr * warmup, 1e-7)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr


class ExponentialSchedule(NoneSchedule):
    def __init__(self, optimizer, lr, warmup_updates):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr
        self.warmup_updates = warmup_updates
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.lr
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            self.lr = max(constant_lr * warmup, 1e-7)
        else:
            new_lrate = constant_lr * (0.1 ** (num_updates / 250_000))  # decay by 0.1x for every 250k steps
            self.lr = max(new_lrate, hparams.get("min_lr", 1e-6))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr


class ExponentialScheduleWithAudattNet(NoneSchedule):
    """
    Default scheduler in AD-NeRF.
    The audatt net only starts training at 200k steps, so its lr is enlarged;
    the optimizer is expected to place the audatt net's params in param_groups[1].
    """
    def __init__(self, optimizer, lr, warmup_updates=0):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr
        self.warmup_updates = warmup_updates
        optimizer.param_groups[0]['lr'] = self.lr
        optimizer.param_groups[1]['lr'] = self.lr * 5
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            self.lr = max(constant_lr * warmup, 1e-7)
        else:
            new_lrate = constant_lr * (0.1 ** (num_updates / 250_000))  # decay by 0.1x for every 250k steps
            self.lr = max(new_lrate, 1e-7)
        self.optimizer.param_groups[0]['lr'] = self.lr
        self.optimizer.param_groups[1]['lr'] = self.lr * 5
        return self.lr
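
# Illustrative sketch (an assumption, not code from this file): the scheduler above expects
# the optimizer to be built with two param groups so that param_groups[1] holds the
# audio-attention (audatt) net and receives the 5x larger lr, e.g.:
#
#   optimizer = torch.optim.Adam([
#       {'params': nerf_net.parameters()},    # param_groups[0] -> lr
#       {'params': audatt_net.parameters()},  # param_groups[1] -> lr * 5
#   ], lr=hparams['lr'])
#
# `nerf_net` and `audatt_net` are hypothetical names used only for illustration.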


class ExponentialScheduleForRADNeRF(NoneSchedule):
    """
    Default scheduler in RAD-NeRF.
    RAD-NeRF uses three param groups with different lr:
      - param_groups[0]: other network params, lr=5e-4
      - param_groups[1]: tileGrid embedding, lr=5e-3
      - param_groups[2]: attention net, lr=2.5e-3
    """
    def __init__(self, optimizer, lr, warmup_updates=0):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr  # 0.0005
        self.warmup_updates = warmup_updates
        self.finetune_lips = hparams['finetune_lips']
        self.finetune_lips_start_iter = hparams['finetune_lips_start_iter']
        optimizer.param_groups[0]['lr'] = self.lr  # net params in RAD-NeRF, lr starts from 0.0005
        optimizer.param_groups[1]['lr'] = self.lr * 10  # tileGrid, lr starts from 0.005
        optimizer.param_groups[2]['lr'] = self.lr * 5  # att net, lr starts from 0.0025
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            self.lr = max(constant_lr * warmup, 1e-5)
        else:
            # NOTE: both branches currently apply the same decay (0.1x every 250k steps).
            if self.finetune_lips and num_updates > self.finetune_lips_start_iter:
                new_lrate = constant_lr * (0.1 ** (num_updates / 250_000))
            else:
                new_lrate = constant_lr * (0.1 ** (num_updates / 250_000))
            self.lr = max(new_lrate, 1e-5)
        self.optimizer.param_groups[0]['lr'] = self.lr
        self.optimizer.param_groups[1]['lr'] = self.lr * 10
        self.optimizer.param_groups[2]['lr'] = self.lr * 5
        return self.lr


class ExponentialScheduleForRADNeRFTorso(NoneSchedule):
    """
    Default scheduler for the torso stage in RAD-NeRF.
    It uses two param groups with different lr:
      - param_groups[0]: other network params, lr=5e-4
      - param_groups[1]: tileGrid embedding, lr=5e-3
    """
    def __init__(self, optimizer, lr, warmup_updates=0):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr  # 0.0005
        self.warmup_updates = warmup_updates
        optimizer.param_groups[0]['lr'] = self.lr  # net params in RAD-NeRF, lr starts from 0.0005
        optimizer.param_groups[1]['lr'] = self.lr * 10  # tileGrid, lr starts from 0.005
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            self.lr = max(constant_lr * warmup, 1e-5)
        else:
            new_lrate = constant_lr * (0.1 ** (num_updates / 250_000))  # decay by 0.1x for every 250k steps
            self.lr = max(new_lrate, 1e-5)
        self.optimizer.param_groups[0]['lr'] = self.lr
        self.optimizer.param_groups[1]['lr'] = self.lr * 10
        return self.lr


class CosineSchedule(NoneSchedule):
    def __init__(self, optimizer, lr, warmup_updates, total_updates):
        self.optimizer = optimizer
        self.constant_lr = lr
        self.warmup_updates = warmup_updates
        self.total_updates = total_updates
        self.lr = lr
        self.assign_learning_rate(self.optimizer, self.lr)
        self.step(0)

    def assign_learning_rate(self, optimizer, new_lr):
        for param_group in optimizer.param_groups:
            param_group["lr"] = new_lr

    def _warmup_lr(self, base_lr, warmup_length, step):
        return base_lr * (step + 1) / warmup_length

    def step(self, num_updates):
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            lr = self._warmup_lr(self.lr, self.warmup_updates, num_updates)
        elif num_updates <= self.total_updates:
            e = num_updates - self.warmup_updates
            es = self.total_updates - self.warmup_updates
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * self.lr
        else:
            lr = 1e-5
        lr = max(1e-5, lr)
        self.assign_learning_rate(self.optimizer, lr)
        return lr
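

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the original training pipeline):
    # build a dummy optimizer, attach a scheduler, and call step(num_updates) once per
    # training update; step() returns the lr it just assigned to all param groups.
    import torch

    dummy_param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([dummy_param], lr=5e-4)
    scheduler = CosineSchedule(optimizer, lr=5e-4, warmup_updates=2000, total_updates=100_000)

    for num_updates in (0, 1000, 2000, 50_000, 100_000, 120_000):
        lr = scheduler.step(num_updates)
        print(f'num_updates={num_updates:>7} lr={lr:.3e}')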