# utils.py (14 KB)
  1. import os
  2. import glob
  3. import tqdm
  4. import math
  5. import random
  6. import warnings
  7. import tensorboardX
  8. import numpy as np
  9. import pandas as pd
  10. import time
  11. from datetime import datetime
  12. import cv2
  13. import matplotlib.pyplot as plt
  14. import torch
  15. import torch.nn as nn
  16. import torch.optim as optim
  17. import torch.nn.functional as F
  18. from torch.autograd import Function
  19. from torch.cuda.amp import custom_bwd, custom_fwd
  20. import torch.distributed as dist
  21. from torch.utils.data import Dataset, DataLoader
  22. import trimesh
  23. import mcubes
  24. from utils.commons.hparams import hparams
  25. from packaging import version as pver
  26. import imageio
  27. import lpips
  28. class _trunc_exp(Function):
  29. @staticmethod
  30. @custom_fwd(cast_inputs=torch.float32) # cast to float32
  31. def forward(ctx, x):
  32. ctx.save_for_backward(x)
  33. return torch.exp(x)
  34. @staticmethod
  35. @custom_bwd
  36. def backward(ctx, g):
  37. x = ctx.saved_tensors[0]
  38. return g * torch.exp(x.clamp(-15, 15))
  39. trunc_exp = _trunc_exp.apply
  40. # ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
  41. def nerf_matrix_to_ngp(pose, scale=4, offset=[0, 0, 0]):
  42. new_pose = np.array([
  43. [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]],
  44. [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]],
  45. [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]],
  46. [0, 0, 0, 1],
  47. ], dtype=np.float32)
  48. return new_pose
  49. def custom_meshgrid(*args):
  50. # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
  51. if pver.parse(torch.__version__) < pver.parse('1.10'):
  52. return torch.meshgrid(*args)
  53. else:
  54. return torch.meshgrid(*args, indexing='ij')
  55. def get_audio_features(features, att_mode, index):
  56. if att_mode == 0:
  57. return features[[index]]
  58. elif att_mode == 1:
  59. print(hparams['smo_win_size'])
  60. left = index - hparams['smo_win_size']
  61. pad_left = 0
  62. if left < 0:
  63. pad_left = -left
  64. left = 0
  65. auds = features[left:index]
  66. if pad_left > 0:
  67. # pad may be longer than auds, so do not use zeros_like
  68. auds = torch.cat([torch.zeros(pad_left, *auds.shape[1:], device=auds.device, dtype=auds.dtype), auds], dim=0)
  69. return auds
  70. elif att_mode == 2:
  71. left = index - hparams['smo_win_size']//2
  72. right = index + (hparams['smo_win_size']-hparams['smo_win_size']//2)
  73. pad_left = 0
  74. pad_right = 0
  75. if left < 0:
  76. pad_left = -left
  77. left = 0
  78. if right > features.shape[0]:
  79. pad_right = right - features.shape[0]
  80. right = features.shape[0]
  81. auds = features[left:right]
  82. if pad_left > 0:
  83. auds = torch.cat([torch.zeros_like(auds[:pad_left]), auds], dim=0)
  84. if pad_right > 0:
  85. auds = torch.cat([auds, torch.zeros_like(auds[:pad_right])], dim=0) # [8, 16]
  86. return auds
  87. else:
  88. raise NotImplementedError(f'wrong att_mode: {att_mode}')
  89. @torch.jit.script
  90. def linear_to_srgb(x):
  91. return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)
  92. @torch.jit.script
  93. def srgb_to_linear(x):
  94. return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
  95. # copied from pytorch3d
  96. def _angle_from_tan(
  97. axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
  98. ) -> torch.Tensor:
  99. """
  100. Extract the first or third Euler angle from the two members of
  101. the matrix which are positive constant times its sine and cosine.
  102. Args:
  103. axis: Axis label "X" or "Y or "Z" for the angle we are finding.
  104. other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
  105. convention.
  106. data: Rotation matrices as tensor of shape (..., 3, 3).
  107. horizontal: Whether we are looking for the angle for the third axis,
  108. which means the relevant entries are in the same row of the
  109. rotation matrix. If not, they are in the same column.
  110. tait_bryan: Whether the first and third axes in the convention differ.
  111. Returns:
  112. Euler Angles in radians for each matrix in data as a tensor
  113. of shape (...).
  114. """
  115. i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
  116. if horizontal:
  117. i2, i1 = i1, i2
  118. even = (axis + other_axis) in ["XY", "YZ", "ZX"]
  119. if horizontal == even:
  120. return torch.atan2(data[..., i1], data[..., i2])
  121. if tait_bryan:
  122. return torch.atan2(-data[..., i2], data[..., i1])
  123. return torch.atan2(data[..., i2], -data[..., i1])
  124. def _index_from_letter(letter: str) -> int:
  125. if letter == "X":
  126. return 0
  127. if letter == "Y":
  128. return 1
  129. if letter == "Z":
  130. return 2
  131. raise ValueError("letter must be either X, Y or Z.")
  132. def matrix_to_euler_angles(matrix: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor:
  133. """
  134. Convert rotations given as rotation matrices to Euler angles in radians.
  135. Args:
  136. matrix: Rotation matrices as tensor of shape (..., 3, 3).
  137. convention: Convention string of three uppercase letters.
  138. Returns:
  139. Euler angles in radians as tensor of shape (..., 3).
  140. """
  141. # if len(convention) != 3:
  142. # raise ValueError("Convention must have 3 letters.")
  143. # if convention[1] in (convention[0], convention[2]):
  144. # raise ValueError(f"Invalid convention {convention}.")
  145. # for letter in convention:
  146. # if letter not in ("X", "Y", "Z"):
  147. # raise ValueError(f"Invalid letter {letter} in convention string.")
  148. # if matrix.size(-1) != 3 or matrix.size(-2) != 3:
  149. # raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
  150. i0 = _index_from_letter(convention[0])
  151. i2 = _index_from_letter(convention[2])
  152. tait_bryan = i0 != i2
  153. if tait_bryan:
  154. central_angle = torch.asin(
  155. matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
  156. )
  157. else:
  158. central_angle = torch.acos(matrix[..., i0, i0])
  159. o = (
  160. _angle_from_tan(
  161. convention[0], convention[1], matrix[..., i2], False, tait_bryan
  162. ),
  163. central_angle,
  164. _angle_from_tan(
  165. convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
  166. ),
  167. )
  168. return torch.stack(o, -1)
  169. def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
  170. """
  171. Return the rotation matrices for one of the rotations about an axis
  172. of which Euler angles describe, for each value of the angle given.
  173. Args:
  174. axis: Axis label "X" or "Y or "Z".
  175. angle: any shape tensor of Euler angles in radians
  176. Returns:
  177. Rotation matrices as tensor of shape (..., 3, 3).
  178. """
  179. cos = torch.cos(angle)
  180. sin = torch.sin(angle)
  181. one = torch.ones_like(angle)
  182. zero = torch.zeros_like(angle)
  183. if axis == "X":
  184. R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
  185. elif axis == "Y":
  186. R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
  187. elif axis == "Z":
  188. R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
  189. else:
  190. raise ValueError("letter must be either X, Y or Z.")
  191. return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
  192. def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str='XYZ') -> torch.Tensor:
  193. """
  194. Convert rotations given as Euler angles in radians to rotation matrices.
  195. Args:
  196. euler_angles: Euler angles in radians as tensor of shape (..., 3).
  197. convention: Convention string of three uppercase letters from
  198. {"X", "Y", and "Z"}.
  199. Returns:
  200. Rotation matrices as tensor of shape (..., 3, 3).
  201. """
  202. # print(euler_angles, euler_angles.dtype)
  203. if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
  204. raise ValueError("Invalid input euler angles.")
  205. if len(convention) != 3:
  206. raise ValueError("Convention must have 3 letters.")
  207. if convention[1] in (convention[0], convention[2]):
  208. raise ValueError(f"Invalid convention {convention}.")
  209. for letter in convention:
  210. if letter not in ("X", "Y", "Z"):
  211. raise ValueError(f"Invalid letter {letter} in convention string.")
  212. matrices = [
  213. _axis_angle_rotation(c, e)
  214. for c, e in zip(convention, torch.unbind(euler_angles, -1))
  215. ]
  216. return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
  217. def convert_poses(poses):
  218. # poses: [B, 4, 4]
  219. # return [B, 3], 4 rot, 3 trans
  220. out = torch.empty(poses.shape[0], 6, dtype=torch.float32, device=poses.device)
  221. out[:, :3] = matrix_to_euler_angles(poses[:, :3, :3])
  222. out[:, 3:] = poses[:, :3, 3]
  223. return out
  224. def get_bg_coords(H, W, device):
  225. X = torch.arange(H, device=device) / (H - 1) * 2 - 1 # in [-1, 1]
  226. Y = torch.arange(W, device=device) / (W - 1) * 2 - 1 # in [-1, 1]
  227. xs, ys = custom_meshgrid(X, Y)
  228. bg_coords = torch.cat([xs.reshape(-1, 1), ys.reshape(-1, 1)], dim=-1).unsqueeze(0) # [1, H*W, 2], in [-1, 1]
  229. return bg_coords
  230. def get_rays(poses, intrinsics, H, W, N=-1, patch_size=1, rect=None):
  231. ''' get rays
  232. Args:
  233. poses: [B, 4, 4], cam2world
  234. intrinsics: [4]
  235. H, W, N: int
  236. Returns:
  237. rays_o, rays_d: [B, N, 3]
  238. inds: [B, N]
  239. '''
  240. device = poses.device
  241. B = poses.shape[0]
  242. fx, fy, cx, cy = intrinsics
  243. if rect is not None:
  244. xmin, xmax, ymin, ymax = rect
  245. N = (xmax - xmin) * (ymax - ymin)
  246. i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device)) # float
  247. i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
  248. j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
  249. results = {}
  250. if N > 0:
  251. N = min(N, H*W)
  252. if patch_size > 1:
  253. # random sample left-top cores.
  254. # NOTE: this impl will lead to less sampling on the image corner pixels... but I don't have other ideas.
  255. num_patch = N // (patch_size ** 2)
  256. inds_x = torch.randint(0, H - patch_size, size=[num_patch], device=device)
  257. inds_y = torch.randint(0, W - patch_size, size=[num_patch], device=device)
  258. inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2]
  259. # create meshgrid for each patch
  260. pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device))
  261. offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2]
  262. inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2]
  263. inds = inds.view(-1, 2) # [N, 2]
  264. inds = inds[:, 0] * W + inds[:, 1] # [N], flatten
  265. inds = inds.expand([B, N])
  266. # only get rays in the specified rect
  267. elif rect is not None:
  268. # assert B == 1
  269. mask = torch.zeros(H, W, dtype=torch.bool, device=device)
  270. xmin, xmax, ymin, ymax = rect
  271. mask[xmin:xmax, ymin:ymax] = 1
  272. inds = torch.where(mask.view(-1))[0] # [nzn]
  273. inds = inds.unsqueeze(0) # [1, N]
  274. else:
  275. inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate
  276. inds = inds.expand([B, N])
  277. i = torch.gather(i, -1, inds)
  278. j = torch.gather(j, -1, inds)
  279. else:
  280. inds = torch.arange(H*W, device=device).expand([B, H*W])
  281. results['i'] = i
  282. results['j'] = j
  283. results['inds'] = inds
  284. zs = torch.ones_like(i)
  285. xs = (i - cx) / fx * zs
  286. ys = (j - cy) / fy * zs
  287. directions = torch.stack((xs, ys, zs), dim=-1)
  288. directions = directions / torch.norm(directions, dim=-1, keepdim=True)
  289. rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)
  290. rays_o = poses[..., :3, 3] # [B, 3]
  291. rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]
  292. results['rays_o'] = rays_o #.clone()
  293. results['rays_d'] = rays_d
  294. return results
  295. def seed_everything(seed):
  296. random.seed(seed)
  297. os.environ['PYTHONHASHSEED'] = str(seed)
  298. np.random.seed(seed)
  299. torch.manual_seed(seed)
  300. torch.cuda.manual_seed(seed)
  301. #torch.backends.cudnn.deterministic = True
  302. #torch.backends.cudnn.benchmark = True
  303. def torch_vis_2d(x, renormalize=False):
  304. # x: [3, H, W] or [1, H, W] or [H, W]
  305. import matplotlib.pyplot as plt
  306. import numpy as np
  307. import torch
  308. if isinstance(x, torch.Tensor):
  309. if len(x.shape) == 3:
  310. x = x.permute(1,2,0).squeeze()
  311. x = x.detach().cpu().numpy()
  312. print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')
  313. x = x.astype(np.float32)
  314. # renormalize
  315. if renormalize:
  316. x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8)
  317. plt.imshow(x)
  318. plt.show()
  319. def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
  320. X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S)
  321. Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S)
  322. Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S)
  323. u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
  324. with torch.no_grad():
  325. for xi, xs in enumerate(X):
  326. for yi, ys in enumerate(Y):
  327. for zi, zs in enumerate(Z):
  328. xx, yy, zz = custom_meshgrid(xs, ys, zs)
  329. pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
  330. val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
  331. u[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val
  332. return u
  333. def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
  334. #print('threshold: {}'.format(threshold))
  335. u = extract_fields(bound_min, bound_max, resolution, query_func)
  336. #print(u.shape, u.max(), u.min(), np.percentile(u, 50))
  337. vertices, triangles = mcubes.marching_cubes(u, threshold)
  338. b_max_np = bound_max.detach().cpu().numpy()
  339. b_min_np = bound_min.detach().cpu().numpy()
  340. vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
  341. return vertices, triangles