mesh_utils.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
  3. #
  4. # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
  5. # property and proprietary rights in and to this material, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this material and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION or
  9. # its affiliates is strictly prohibited.
  10. """
Utilities for extracting 3D meshes from volumes using marching cubes. Based on code from DeepSDF (Park et al.).
Takes as input a single .mrc file, or a directory of .mrc files, and writes a .ply mesh next to each input.
Examples:
python mesh_utils.py my_shape.mrc
python mesh_utils.py myshapes_directory --level=12
  17. """
  18. import time
  19. import plyfile
  20. import glob
  21. import logging
  22. import numpy as np
  23. import os
  24. import random
  25. import torch
  26. import torch.utils.data
  27. import trimesh
  28. import skimage.measure
  29. import argparse
  30. import mrcfile
  31. from tqdm import tqdm
  32. def convert_sdf_samples_to_ply(
  33. numpy_3d_sdf_tensor,
  34. voxel_grid_origin,
  35. voxel_size,
  36. ply_filename_out,
  37. offset=None,
  38. scale=None,
  39. level=0.0
  40. ):
  41. """
  42. Convert sdf samples to .ply
  43. :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
  44. :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
  45. :voxel_size: float, the size of the voxels
  46. :ply_filename_out: string, path of the filename to save to
  47. This function adapted from: https://github.com/RobotLocomotion/spartan
  48. """
  49. start_time = time.time()
  50. verts, faces, normals, values = np.zeros((0, 3)), np.zeros((0, 3)), np.zeros((0, 3)), np.zeros(0)
  51. # try:
  52. verts, faces, normals, values = skimage.measure.marching_cubes(
  53. numpy_3d_sdf_tensor, level=level, spacing=[voxel_size] * 3
  54. )
  55. # except:
  56. # pass
  57. # transform from voxel coordinates to camera coordinates
  58. # note x and y are flipped in the output of marching_cubes
  59. mesh_points = np.zeros_like(verts)
  60. mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
  61. mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
  62. mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]
  63. # apply additional offset and scale
  64. if scale is not None:
  65. mesh_points = mesh_points / scale
  66. if offset is not None:
  67. mesh_points = mesh_points - offset
  68. # try writing to the ply file
  69. num_verts = verts.shape[0]
  70. num_faces = faces.shape[0]
  71. verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
  72. for i in range(0, num_verts):
  73. verts_tuple[i] = tuple(mesh_points[i, :])
  74. faces_building = []
  75. for i in range(0, num_faces):
  76. faces_building.append(((faces[i, :].tolist(),)))
  77. faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])
  78. el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
  79. el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
  80. ply_data = plyfile.PlyData([el_verts, el_faces])
  81. ply_data.write(ply_filename_out)
  82. print(f"wrote to {ply_filename_out}")
  83. def convert_mrc(input_filename, output_filename, isosurface_level=1):
  84. with mrcfile.open(input_filename) as mrc:
  85. convert_sdf_samples_to_ply(np.transpose(mrc.data, (2, 1, 0)), [0, 0, 0], 1, output_filename, level=isosurface_level)
  86. if __name__ == '__main__':
  87. start_time = time.time()
  88. parser = argparse.ArgumentParser()
  89. parser.add_argument('input_mrc_path')
  90. parser.add_argument('--level', type=float, default=10, help="The isosurface level for marching cubes")
  91. args = parser.parse_args()
  92. if os.path.isfile(args.input_mrc_path) and args.input_mrc_path.split('.')[-1] == 'ply':
  93. output_obj_path = args.input_mrc_path.split('.mrc')[0] + '.ply'
  94. convert_mrc(args.input_mrc_path, output_obj_path, isosurface_level=1)
  95. print(f"{time.time() - start_time:02f} s")
  96. else:
  97. assert os.path.isdir(args.input_mrc_path)
  98. for mrc_path in tqdm(glob.glob(os.path.join(args.input_mrc_path, '*.mrc'))):
  99. output_obj_path = mrc_path.split('.mrc')[0] + '.ply'
  100. convert_mrc(mrc_path, output_obj_path, isosurface_level=args.level)