radnerf_sr.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import random

from modules.radnerfs.encoders.encoding import get_encoder
from modules.radnerfs.renderer import NeRFRenderer
from modules.radnerfs.cond_encoder import AudioNet, AudioAttNet, MLP, HeatMapEncoder, HeatMapAttNet
from modules.radnerfs.utils import trunc_exp
from modules.eg3ds.models.superresolution import *


class Superresolution(torch.nn.Module):
    def __init__(self, channels=32, img_resolution=512, sr_antialias=True):
        super().__init__()
        assert img_resolution == 512
        block_kwargs = {'channel_base': 32768, 'channel_max': 512, 'fused_modconv_default': 'inference_only'}
        use_fp16 = True
        self.sr_antialias = sr_antialias
        self.input_resolution = 256
        # w_dim is only a placeholder; the ws passed to the blocks is a constant tensor
        self.w_dim = 16
        self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=self.w_dim, resolution=256,
                img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
        self.block1 = SynthesisBlock(128, 64, w_dim=self.w_dim, resolution=512,
                img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, **block_kwargs):
        x = rgb
        ws = torch.ones([rgb.shape[0], 14, self.w_dim], dtype=rgb.dtype, device=rgb.device)
        ws = ws[:, -1:, :].repeat(1, 3, 1)
        if x.shape[-1] < self.input_resolution:
            x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
                                                mode='bilinear', align_corners=False, antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
                                                  mode='bilinear', align_corners=False, antialias=self.sr_antialias)
        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb
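

# A minimal usage sketch for Superresolution, matching how it is instantiated
# below (channels=3, i.e. the input is a rendered 3-channel RGB map). The eg3d
# synthesis blocks may require their custom CUDA ops, so treat this as an
# illustration of the expected shapes rather than a guaranteed standalone test.
def _superresolution_sketch():
    sr = Superresolution(channels=3)
    low_res = torch.rand([1, 3, 256, 256])   # [B, C, H, W] RGB map at 256x256
    high_res = sr(low_res)                   # block0 keeps 256x256, block1 upsamples to 512x512
    return high_res                          # expected shape: [1, 3, 512, 512]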


class RADNeRFwithSR(NeRFRenderer):
    def __init__(self, hparams):
        super().__init__(hparams)
        self.hparams = copy.deepcopy(hparams)
        if hparams['cond_type'] == 'esperanto':
            self.cond_in_dim = 44
        elif hparams['cond_type'] == 'deepspeech':
            self.cond_in_dim = 29
        elif hparams['cond_type'] == 'idexp_lm3d_normalized':
            keypoint_mode = hparams.get("nerf_keypoint_mode", "lm68")
            if keypoint_mode == 'lm68':
                self.cond_in_dim = 68 * 3
            elif keypoint_mode == 'lm131':
                self.cond_in_dim = 131 * 3
            elif keypoint_mode == 'lm468':
                self.cond_in_dim = 468 * 3
            else:
                raise NotImplementedError()
        else:
            raise NotImplementedError()

        # a prenet that processes the raw condition
        self.cond_out_dim = hparams['cond_out_dim'] // 2 * 2
        self.cond_win_size = hparams['cond_win_size']
        self.smo_win_size = hparams['smo_win_size']
        self.cond_prenet = AudioNet(self.cond_in_dim, self.cond_out_dim, win_size=self.cond_win_size)

        if hparams.get("add_eye_blink_cond", False):
            self.blink_embedding = nn.Embedding(1, self.cond_out_dim // 2)
            self.blink_encoder = nn.Sequential(
                *[
                    nn.Linear(self.cond_out_dim // 2, self.cond_out_dim // 2),
                    nn.Linear(self.cond_out_dim // 2, hparams['eye_blink_dim']),
                ]
            )

        # an attention net that smooths the condition feature sequence
        self.with_att = hparams['with_att']
        if self.with_att:
            self.cond_att_net = AudioAttNet(self.cond_out_dim, seq_len=self.smo_win_size)

        # an ambient network that predicts the 2D ambient coordinate;
        # the ambient grid models the dynamics of the canonical face.
        # By predicting ambient coords from cond_feat, the face can be driven by either audio or landmarks!
        self.grid_type = hparams['grid_type']  # tiledgrid or hashgrid
        self.grid_interpolation_type = hparams['grid_interpolation_type']  # smoothstep or linear
        self.position_embedder, self.position_embedding_dim = get_encoder(self.grid_type, input_dim=3, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=hparams['log2_hashmap_size'], desired_resolution=hparams['desired_resolution'] * self.bound, interpolation=self.grid_interpolation_type)
        self.num_layers_ambient = hparams['num_layers_ambient']
        self.hidden_dim_ambient = hparams['hidden_dim_ambient']
        self.ambient_coord_dim = hparams['ambient_coord_dim']
        self.ambient_net = MLP(self.position_embedding_dim + self.cond_out_dim, self.ambient_coord_dim, self.hidden_dim_ambient, self.num_layers_ambient)

        # the learnable ambient grid
        self.ambient_embedder, self.ambient_embedding_dim = get_encoder(self.grid_type, input_dim=hparams['ambient_coord_dim'], num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=hparams['log2_hashmap_size'], desired_resolution=hparams['desired_resolution'], interpolation=self.grid_interpolation_type)

        # sigma (density) network
        self.num_layers_sigma = hparams['num_layers_sigma']
        self.hidden_dim_sigma = hparams['hidden_dim_sigma']
        self.geo_feat_dim = hparams['geo_feat_dim']
        self.sigma_net = MLP(self.position_embedding_dim + self.ambient_embedding_dim, 1 + self.geo_feat_dim, self.hidden_dim_sigma, self.num_layers_sigma)

        # color network
        self.num_layers_color = hparams['num_layers_color']
        self.hidden_dim_color = hparams['hidden_dim_color']
        self.direction_embedder, self.direction_embedding_dim = get_encoder('spherical_harmonics')
        self.color_net = MLP(self.direction_embedding_dim + self.geo_feat_dim + self.individual_embedding_dim, 3, self.hidden_dim_color, self.num_layers_color)

        self.dropout = nn.Dropout(p=hparams['cond_dropout_rate'], inplace=False)
        self.sr_net = Superresolution(channels=3)
        self.lambda_ambient = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def on_train_nerf(self):
        # train the NeRF branch, keep the SR head frozen
        self.requires_grad_(True)
        self.sr_net.requires_grad_(False)

    def on_train_superresolution(self):
        # freeze the NeRF branch, train only the SR head
        # (a sketch of how these two toggles fit into a training loop is given at the end of this file)
        self.requires_grad_(False)
        self.sr_net.requires_grad_(True)
    def cal_cond_feat(self, cond, eye_area_percent=None):
        """
        cond: [B, T, C]
        if deepspeech, [1/8, T=16, 29]
        if esperanto, [1/8, T=16, 44]
        if idexp_lm3d_normalized, [1/5, T=1, 204]
        """
        cond_feat = self.cond_prenet(cond)
        if self.hparams.get("add_eye_blink_cond", False):
            if eye_area_percent is None:
                eye_area_percent = torch.zeros([1, 1], dtype=cond_feat.dtype)
            blink_feat = self.blink_embedding(torch.tensor(0, device=cond_feat.device)).reshape([1, -1])
            blink_feat = blink_feat * eye_area_percent.reshape([1, 1]).to(cond_feat.device)
            blink_feat = self.blink_encoder(blink_feat)
            cond_feat[..., :self.hparams['eye_blink_dim']] = cond_feat[..., :self.hparams['eye_blink_dim']] + blink_feat.expand(cond_feat[..., :self.hparams['eye_blink_dim']].shape)
        if self.with_att:
            cond_feat = self.cond_att_net(cond_feat)  # [1, 64]
        return cond_feat
    def forward(self, position, direction, cond_feat, individual_code, cond_mask=None):
        """
        position: [N, 3], position, in [-bound, bound]
        direction: [N, 3], direction, normalized in [-1, 1]
        cond_feat: [1, cond_dim], condition encoding, generated by self.cal_cond_feat
        individual_code: [1, ind_dim], individual code for each timestep
        """
        cond_feat = cond_feat.repeat([position.shape[0], 1])  # [1, cond_dim] ==> [N, cond_dim]
        pos_feat = self.position_embedder(position, bound=self.bound)  # spatial feat f after E^3_{spatial}, the 3D grid in the paper

        # ambient
        ambient_inp = torch.cat([pos_feat, cond_feat], dim=1)  # audio feat and spatial feat
        ambient_logit = self.ambient_net(ambient_inp).float()  # the MLP after AFE in the paper; float() prevents a performance drop due to amp
        ambient_pos = torch.tanh(ambient_logit)  # normalized to [-1, 1], acts as the coordinate in the 2D ambient tilegrid
        ambient_feat = self.ambient_embedder(ambient_pos, bound=1)  # E^2_{audio} in the paper, 2D grid

        # sigma
        h = torch.cat([pos_feat, ambient_feat], dim=-1)
        h = self.sigma_net(h)
        sigma = trunc_exp(h[..., 0])
        geo_feat = h[..., 1:]

        # color
        direction_feat = self.direction_embedder(direction)
        if individual_code is not None:
            color_inp = torch.cat([direction_feat, geo_feat, individual_code.repeat(position.shape[0], 1)], dim=-1)
        else:
            color_inp = torch.cat([direction_feat, geo_feat], dim=-1)
        color_logit = self.color_net(color_inp)
        # sigmoid activation for rgb
        color = torch.sigmoid(color_logit)
        return sigma, color, ambient_pos
    def density(self, position, cond_feat, e=None, cond_mask=None):
        """
        Calculate density; this is a sub-process of self.forward
        """
        assert self.hparams.get("to_heatmap", False) is False
        cond_feat = cond_feat.repeat([position.shape[0], 1])  # [1, cond_dim] ==> [N, cond_dim]
        pos_feat = self.position_embedder(position, bound=self.bound)  # spatial feat f after E^3_{spatial}, the 3D grid in the paper

        # ambient
        ambient_inp = torch.cat([pos_feat, cond_feat], dim=1)  # audio feat and spatial feat
        ambient_logit = self.ambient_net(ambient_inp).float()  # the MLP after AFE in the paper; float() prevents a performance drop due to amp
        ambient_pos = torch.tanh(ambient_logit)  # normalized to [-1, 1], acts as the coordinate in the 2D ambient tilegrid
        ambient_feat = self.ambient_embedder(ambient_pos, bound=1)  # E^2_{audio} in the paper, 2D grid

        # sigma
        h = torch.cat([pos_feat, ambient_feat], dim=-1)
        h = self.sigma_net(h)
        sigma = trunc_exp(h[..., 0])
        geo_feat = h[..., 1:]

        return {
            'sigma': sigma,
            'geo_feat': geo_feat,
        }
    def render(self, rays_o, rays_d, cond, bg_coords, poses, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, cond_mask=None, eye_area_percent=None, **kwargs):
        # volume-render a 256x256 image, then super-resolve it to 512x512
        # (see the end-to-end usage sketch at the end of this file)
        results = super().render(rays_o, rays_d, cond, bg_coords, poses, index, dt_gamma, bg_color, perturb, force_all_rays, max_steps, T_thresh, cond_mask, eye_area_percent=eye_area_percent, **kwargs)
        rgb_image = results['rgb_map'].reshape([1, 256, 256, 3]).permute(0, 3, 1, 2)  # [1, H*W, 3] ==> [1, 3, 256, 256]
        sr_rgb_image = self.sr_net(rgb_image.clone())  # [1, 3, 256, 256] ==> [1, 3, 512, 512]
        sr_rgb_image = sr_rgb_image.clamp(0, 1)
        results['rgb_map'] = rgb_image
        results['sr_rgb_map'] = sr_rgb_image
        return results
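

# ----------------------------------------------------------------------------
# The sketches below are illustrative only: they assume a constructed (and, for
# rendering, trained) RADNeRFwithSR `model`, and the helper names
# (`compute_loss`, `sample`, the batch keys, the step counts) are hypothetical
# rather than part of this module or its dataset code.

def _two_stage_training_sketch(model, optimizer, compute_loss, batches, nerf_steps=200_000):
    # how the on_train_nerf / on_train_superresolution toggles are meant to be
    # used: first optimize the NeRF with the SR head frozen, then freeze the
    # NeRF and train only sr_net; the switch point here is an arbitrary example.
    for step, batch in enumerate(batches):
        if step < nerf_steps:
            model.on_train_nerf()
        else:
            model.on_train_superresolution()
        loss = compute_loss(model, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def _field_query_sketch(model, cond_feat, individual_code):
    # a raw query of the dynamic radiance field: `cond_feat` should come from
    # model.cal_cond_feat, and `individual_code` must match the
    # individual_embedding_dim expected by the color net.
    N = 4096  # arbitrary number of sample points
    position = torch.rand([N, 3]) * 2 - 1                         # points in [-1, 1] (inside [-bound, bound] when bound >= 1)
    direction = F.normalize(torch.rand([N, 3]) * 2 - 1, dim=-1)   # unit viewing directions
    sigma, color, ambient_pos = model(position, direction, cond_feat, individual_code)
    # expected: sigma [N], color [N, 3] in [0, 1], ambient_pos [N, ambient_coord_dim] in [-1, 1]
    return sigma, color, ambient_pos


def _render_sketch(model, sample):
    # end-to-end inference for one frame: volume rendering at 256x256 followed
    # by super-resolution to 512x512; `sample` stands in for one item from the
    # dataset loader that feeds NeRFRenderer.render.
    with torch.no_grad():
        results = model.render(sample['rays_o'], sample['rays_d'], sample['cond'],
                               sample['bg_coords'], sample['poses'], index=sample['idx'],
                               bg_color=sample.get('bg_color'),
                               eye_area_percent=sample.get('eye_area_percent'))
    return results['rgb_map'], results['sr_rgb_map']  # [1, 3, 256, 256] and [1, 3, 512, 512]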