123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430 |
- import os
- import glob
- import tqdm
- import math
- import random
- import warnings
- import tensorboardX
- import numpy as np
- import pandas as pd
- import time
- from datetime import datetime
- import cv2
- import matplotlib.pyplot as plt
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import torch.nn.functional as F
- from torch.autograd import Function
- from torch.cuda.amp import custom_bwd, custom_fwd
- import torch.distributed as dist
- from torch.utils.data import Dataset, DataLoader
- import trimesh
- import mcubes
- from utils.commons.hparams import hparams
- from packaging import version as pver
- import imageio
- import lpips
class _trunc_exp(Function):
    """Autograd op computing exp(x) with a clamped backward pass.

    The forward value is the plain, unclamped exp(x). The gradient is
    ``grad * exp(clamp(x, -15, 15))``, so very large or very small inputs cannot
    produce overflowing (inf) or vanishing gradients.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # run the forward in float32 even under autocast
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        (saved_x,) = ctx.saved_tensors
        # Clamp before exponentiating so the local derivative stays finite.
        return grad_out * saved_x.clamp(-15, 15).exp()


# Functional entry point: trunc_exp(x) behaves like torch.exp(x) in the forward pass.
trunc_exp = _trunc_exp.apply
# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
def nerf_matrix_to_ngp(pose, scale=4, offset=(0, 0, 0)):
    """Re-express a 4x4 pose matrix in the axis convention of the instant-ngp loader (see ref above).

    Output rows 0, 1, 2 are taken from input rows 1, 2, 0. Within each of those rows,
    columns 1 and 2 are negated. The translation column is multiplied by `scale`
    and then shifted by `offset`.

    Args:
        pose: indexable 4x4 matrix (e.g. a numpy array); only rows 0-2 are read.
        scale: multiplier applied to the translation column.
        offset: length-3 shift added to the (re-ordered) translation. Defaults to
            an immutable tuple, so the default can never be mutated across calls.

    Returns:
        np.ndarray of shape (4, 4), dtype float32, with last row [0, 0, 0, 1].
    """
    new_pose = np.array([
        [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]],
        [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]],
        [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]],
        [0, 0, 0, 1],
    ], dtype=np.float32)
    return new_pose
def custom_meshgrid(*args):
    """Call torch.meshgrid with 'ij' (matrix) indexing regardless of the torch version.

    torch < 1.10 does not accept the `indexing` keyword (and always uses 'ij'),
    so the keyword is passed only on versions that support it.
    ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    """
    supports_indexing_kw = pver.parse(torch.__version__) >= pver.parse('1.10')
    if not supports_indexing_kw:
        return torch.meshgrid(*args)
    return torch.meshgrid(*args, indexing='ij')
def _zero_pad_frames(auds, pad_left, pad_right):
    """Return `auds` with `pad_left` zero rows prepended and `pad_right` zero rows appended along dim 0."""
    # Build padding from explicit sizes instead of zeros_like(auds[:pad]): that
    # slice is silently shorter than `pad` whenever pad > len(auds).
    pieces = []
    if pad_left > 0:
        pieces.append(auds.new_zeros((pad_left, *auds.shape[1:])))
    pieces.append(auds)
    if pad_right > 0:
        pieces.append(auds.new_zeros((pad_right, *auds.shape[1:])))
    if len(pieces) == 1:
        return auds
    return torch.cat(pieces, dim=0)


def get_audio_features(features, att_mode, index, smo_win_size=None):
    """Select the window of audio features associated with frame `index`.

    Args:
        features: tensor whose dim 0 indexes frames.
        att_mode: 0 -> just frame `index` (kept 2-D via ``features[[index]]``);
            1 -> the `smo_win_size` frames before `index`, i.e. [index - w, index),
                 zero-padded on the left when the window starts before frame 0;
            2 -> the window [index - w // 2, index + w - w // 2), zero-padded on
                 whichever side falls outside the available frames.
        index: frame index.
        smo_win_size: window length w. Defaults to ``hparams['smo_win_size']``
            (read only for modes 1 and 2).

    Returns:
        Tensor of selected frames. In mode 2 it always has exactly w rows. In
        mode 1 it has w rows unless `index` lies past the last frame, in which
        case the slice is truncated (no right padding).

    Raises:
        NotImplementedError: for any other `att_mode`.
    """
    if att_mode == 0:
        return features[[index]]
    if att_mode not in (1, 2):
        raise NotImplementedError(f'wrong att_mode: {att_mode}')

    win = hparams['smo_win_size'] if smo_win_size is None else smo_win_size
    n_frames = features.shape[0]

    if att_mode == 1:
        # Causal window ending just before `index`; only the left edge is padded.
        left, right = index - win, index
        pad_right = 0
    else:
        # Centered window of exactly `win` frames.
        left = index - win // 2
        right = index + (win - win // 2)
        pad_right = max(right - n_frames, 0)
        right = min(right, n_frames)

    pad_left = max(-left, 0)
    left = max(left, 0)
    return _zero_pad_frames(features[left:right], pad_left, pad_right)
@torch.jit.script
def linear_to_srgb(x):
    """Apply the sRGB encoding curve to linear values, element-wise.

    Values below 0.0031308 are scaled linearly by 12.92. All other values use
    1.055 * x ** 0.41666 - 0.055, where 0.41666 approximates 1 / 2.4.
    """
    # torch.where differentiates through BOTH branches. The derivative of
    # x ** 0.41666 is inf at x == 0 and NaN for x < 0, and multiplying it by
    # the masked-out zero gradient yields NaN. Clamping the base to the
    # threshold keeps that branch finite. Wherever the power branch is actually
    # selected, x >= 0.0031308 already, so forward values are unchanged.
    return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x.clamp(min=0.0031308) ** 0.41666 - 0.055)
@torch.jit.script
def srgb_to_linear(x):
    """Invert the sRGB encoding curve, element-wise.

    Values below 0.04045 are divided by 12.92. All other values use
    ((x + 0.055) / 1.055) ** 2.4.
    """
    # torch.where differentiates through BOTH branches. For x < -0.055 the
    # power base is negative, so the unused branch has a NaN derivative that
    # poisons the gradient. Clamping the base to the threshold keeps it finite
    # without changing any selected forward value (selected values are >= 0.04045).
    return torch.where(x < 0.04045, x / 12.92, ((x.clamp(min=0.04045) + 0.055) / 1.055) ** 2.4)
- # copied from pytorch3d
- def _angle_from_tan(
- axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
- ) -> torch.Tensor:
- """
- Extract the first or third Euler angle from the two members of
- the matrix which are positive constant times its sine and cosine.
- Args:
- axis: Axis label "X" or "Y or "Z" for the angle we are finding.
- other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
- convention.
- data: Rotation matrices as tensor of shape (..., 3, 3).
- horizontal: Whether we are looking for the angle for the third axis,
- which means the relevant entries are in the same row of the
- rotation matrix. If not, they are in the same column.
- tait_bryan: Whether the first and third axes in the convention differ.
- Returns:
- Euler Angles in radians for each matrix in data as a tensor
- of shape (...).
- """
- i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
- if horizontal:
- i2, i1 = i1, i2
- even = (axis + other_axis) in ["XY", "YZ", "ZX"]
- if horizontal == even:
- return torch.atan2(data[..., i1], data[..., i2])
- if tait_bryan:
- return torch.atan2(-data[..., i2], data[..., i1])
- return torch.atan2(data[..., i2], -data[..., i1])
- def _index_from_letter(letter: str) -> int:
- if letter == "X":
- return 0
- if letter == "Y":
- return 1
- if letter == "Z":
- return 2
- raise ValueError("letter must be either X, Y or Z.")
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.
    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", "Z"}; the middle letter must differ from both neighbours.
    Returns:
        Euler angles in radians as tensor of shape (..., 3).
    Raises:
        ValueError: if `convention` is malformed or `matrix` does not end in a
            3x3 block.
    """
    # Validate inputs the same way euler_angles_to_matrix does; without these
    # checks a malformed convention (e.g. "XXY") silently yields meaningless angles.
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    tait_bryan = i0 != i2
    if tait_bryan:
        # Middle angle from entry (i0, i2). The sign is flipped when (i0, i2) is
        # one of the cyclic pairs (0, 1), (1, 2), (2, 0).
        central_angle = torch.asin(
            matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
        )
    else:
        # Proper Euler convention (first axis == third axis).
        central_angle = torch.acos(matrix[..., i0, i0])

    o = (
        # First angle from column i2, third angle from row i0.
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    return torch.stack(o, -1)
- def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
- """
- Return the rotation matrices for one of the rotations about an axis
- of which Euler angles describe, for each value of the angle given.
- Args:
- axis: Axis label "X" or "Y or "Z".
- angle: any shape tensor of Euler angles in radians
- Returns:
- Rotation matrices as tensor of shape (..., 3, 3).
- """
- cos = torch.cos(angle)
- sin = torch.sin(angle)
- one = torch.ones_like(angle)
- zero = torch.zeros_like(angle)
- if axis == "X":
- R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
- elif axis == "Y":
- R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
- elif axis == "Z":
- R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
- else:
- raise ValueError("letter must be either X, Y or Z.")
- return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str='XYZ') -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.
    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}; the middle letter must differ from both neighbours.
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    Raises:
        ValueError: if the angle tensor or the convention string is malformed.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    bad_letter = next((c for c in convention if c not in ("X", "Y", "Z")), None)
    if bad_letter is not None:
        raise ValueError(f"Invalid letter {bad_letter} in convention string.")

    # One elementary rotation per (axis letter, angle) pair.
    r0, r1, r2 = (
        _axis_angle_rotation(axis, angle)
        for axis, angle in zip(convention, torch.unbind(euler_angles, -1))
    )
    # Compose left to right: R = R0 @ R1 @ R2.
    return r0 @ r1 @ r2
def convert_poses(poses):
    """Flatten batched pose matrices into 6-D vectors.

    Args:
        poses: [B, 4, 4] pose matrices (only the top 3 rows are read).

    Returns:
        [B, 6] float32 tensor on the same device as `poses`. Columns 0-2 hold the
        XYZ Euler angles (radians) of the 3x3 rotation block. Columns 3-5 hold the
        translation column poses[:, :3, 3].
    """
    euler = matrix_to_euler_angles(poses[:, :3, :3])
    translation = poses[:, :3, 3]
    return torch.cat((euler, translation), dim=-1).to(torch.float32)
def get_bg_coords(H, W, device):
    """Return normalized coordinates of every pixel of an H x W grid.

    Returns:
        Tensor of shape [1, H*W, 2]. Column 0 is the H-axis coordinate and
        column 1 the W-axis coordinate, each linearly mapped from index 0..n-1
        onto [-1, 1]. Pixels are ordered row-major (H outer, W inner).
        NOTE: H or W equal to 1 divides by zero (n - 1 == 0).
    """
    def _unit_range(n):
        return torch.arange(n, device=device) / (n - 1) * 2 - 1  # in [-1, 1]

    grid_h, grid_w = custom_meshgrid(_unit_range(H), _unit_range(W))
    coords = torch.stack([grid_h.reshape(-1), grid_w.reshape(-1)], dim=-1)  # [H*W, 2]
    return coords[None]  # [1, H*W, 2]
def _sample_patch_inds(H, W, N, patch_size, device):
    """Flat pixel indices (row * W + col) of N // patch_size**2 random square patches.

    The result is 1-D with length (N // patch_size**2) * patch_size**2, which is
    smaller than N when N is not a multiple of patch_size**2.
    """
    # NOTE: top-left corners are sampled uniformly, so pixels near the image
    # border/corners are covered by fewer candidate patches than interior pixels.
    num_patch = N // (patch_size ** 2)
    # torch.randint's `high` is exclusive. The last valid top-left corner is
    # (H - patch_size, W - patch_size), hence the + 1.
    rows = torch.randint(0, H - patch_size + 1, size=[num_patch], device=device)
    cols = torch.randint(0, W - patch_size + 1, size=[num_patch], device=device)
    corners = torch.stack([rows, cols], dim=-1)  # [np, 2]
    # create meshgrid for each patch
    pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device))
    offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1)  # [p^2, 2]
    pix = (corners.unsqueeze(1) + offsets.unsqueeze(0)).view(-1, 2)  # [np * p^2, 2]
    return pix[:, 0] * W + pix[:, 1]  # flatten


def get_rays(poses, intrinsics, H, W, N=-1, patch_size=1, rect=None):
    ''' get rays
    Args:
        poses: [B, 4, 4], cam2world
        intrinsics: (fx, fy, cx, cy)
        H, W: image height / width in pixels
        N: number of rays to sample; N <= 0 means every pixel
        patch_size: if > 1, sample N // patch_size**2 random square patches
        rect: optional (xmin, xmax, ymin, ymax) window; x bounds index rows,
            y bounds index columns. Overrides N with the window area. When
            patch_size > 1, patch sampling takes precedence over the window.
    Returns:
        dict with
            'i', 'j': [B, N] pixel-center column / row coordinates,
            'inds': [B, N] flattened pixel indices (row * W + col),
            'rays_o', 'rays_d': [B, N, 3] ray origins and unit directions.
        In the patch branch N is rounded down to a multiple of patch_size**2.
    '''
    device = poses.device
    B = poses.shape[0]
    fx, fy, cx, cy = intrinsics

    if rect is not None:
        xmin, xmax, ymin, ymax = rect
        N = (xmax - xmin) * (ymax - ymin)

    # Pixel centers for all pixels, flattened row-major: i = column (x), j = row (y).
    i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device))  # float
    i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
    j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5

    results = {}

    if N > 0:
        N = min(N, H*W)
        if patch_size > 1:
            inds = _sample_patch_inds(H, W, N, patch_size, device)
            # -1 keeps the real sample count, which may be < N (see helper).
            inds = inds.expand([B, -1])
        elif rect is not None:
            # only get rays in the specified rect
            mask = torch.zeros(H, W, dtype=torch.bool, device=device)
            mask[xmin:xmax, ymin:ymax] = 1
            inds = torch.where(mask.view(-1))[0]  # [nnz]
            inds = inds.unsqueeze(0).expand([B, -1])  # [B, nnz]
        else:
            inds = torch.randint(0, H*W, size=[N], device=device)  # may duplicate
            inds = inds.expand([B, N])

        i = torch.gather(i, -1, inds)
        j = torch.gather(j, -1, inds)
    else:
        inds = torch.arange(H*W, device=device).expand([B, H*W])

    results['i'] = i
    results['j'] = j
    results['inds'] = inds

    # Camera-space directions through each pixel center (z = 1 plane), normalized,
    # then rotated into world space by the pose's 3x3 block.
    zs = torch.ones_like(i)
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    directions = torch.stack((xs, ys, zs), dim=-1)
    directions = directions / torch.norm(directions, dim=-1, keepdim=True)
    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2)  # (B, N, 3)

    rays_o = poses[..., :3, 3]  # [B, 3]
    rays_o = rays_o[..., None, :].expand_as(rays_d)  # [B, N, 3]

    results['rays_o'] = rays_o
    results['rays_d'] = rays_d
    return results
def seed_everything(seed):
    """Seed Python's `random`, NumPy, torch and torch.cuda RNGs and export PYTHONHASHSEED.

    NOTE: PYTHONHASHSEED is read when an interpreter starts, so exporting it here
    only affects child processes launched afterwards. cuDNN determinism/benchmark
    flags are not touched.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
def torch_vis_2d(x, renormalize=False):
    """Print summary stats of a 2-D image-like array and display it with matplotlib.

    Args:
        x: torch.Tensor or numpy array shaped [3, H, W], [1, H, W] or [H, W].
            3-D tensors are moved to channel-last and singleton dims are squeezed.
        renormalize: if True, min-max rescale along axis 0 before display.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    if isinstance(x, torch.Tensor):
        if x.dim() == 3:
            x = x.permute(1, 2, 0).squeeze()
        x = x.detach().cpu().numpy()

    print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')

    x = x.astype(np.float32)

    if renormalize:
        lo = x.min(axis=0, keepdims=True)
        hi = x.max(axis=0, keepdims=True)
        x = (x - lo) / (hi - lo + 1e-8)

    plt.imshow(x)
    plt.show()
def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
    """Evaluate `query_func` on a resolution^3 grid spanning [bound_min, bound_max].

    Each axis is split into chunks of at most S samples, and `query_func` is
    called once per (x, y, z) chunk with an [nx * ny * nz, 3] point tensor.
    Gradients are disabled.

    Returns:
        np.ndarray of shape (resolution, resolution, resolution), dtype float32.
    """
    axes = [torch.linspace(bound_min[d], bound_max[d], resolution).split(S) for d in range(3)]
    u = np.zeros([resolution] * 3, dtype=np.float32)
    with torch.no_grad():
        for xi, xs in enumerate(axes[0]):
            x0 = xi * S
            for yi, ys in enumerate(axes[1]):
                y0 = yi * S
                for zi, zs in enumerate(axes[2]):
                    z0 = zi * S
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    pts = torch.stack([xx.reshape(-1), yy.reshape(-1), zz.reshape(-1)], dim=-1)  # [nx*ny*nz, 3]
                    val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
                    u[x0:x0 + len(xs), y0:y0 + len(ys), z0:z0 + len(zs)] = val
    return u
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
    """Extract an iso-surface mesh of `query_func` at level `threshold`.

    Args:
        bound_min, bound_max: length-3 tensors bounding the sampled box.
        resolution: samples per axis.
        threshold: iso-level passed to mcubes.marching_cubes.
        query_func: callable mapping [M, 3] points to M scalar values.

    Returns:
        (vertices, triangles) from marching cubes, with vertices mapped from
        grid-index units (0 .. resolution - 1) into [bound_min, bound_max].
    """
    field = extract_fields(bound_min, bound_max, resolution, query_func)
    vertices, triangles = mcubes.marching_cubes(field, threshold)

    lo = bound_min.detach().cpu().numpy()
    hi = bound_max.detach().cpu().numpy()
    vertices = vertices / (resolution - 1.0) * (hi - lo)[None, :] + lo[None, :]
    return vertices, triangles
|