# file_ops.py (2.7 KB)
  1. from __future__ import annotations
  2. import os
  3. import numpy as np
  4. import validators
  5. from typing import TYPE_CHECKING
  6. if TYPE_CHECKING:
  7. from typing import Iterable
  8. def add_extension_if_not_present(file_name: str, extension: str) -> str:
  9. # This could conceivably be smarter about handling existing differing extensions
  10. if(file_name[-len(extension):] != extension):
  11. return file_name + extension
  12. else:
  13. return file_name
  14. def guarantee_existence(path: str) -> str:
  15. if not os.path.exists(path):
  16. os.makedirs(path)
  17. return os.path.abspath(path)
  18. def find_file(
  19. file_name: str,
  20. directories: Iterable[str] | None = None,
  21. extensions: Iterable[str] | None = None
  22. ) -> str:
  23. # Check if this is a file online first, and if so, download
  24. # it to a temporary directory
  25. if validators.url(file_name):
  26. import urllib.request
  27. from manimlib.utils.directories import get_downloads_dir
  28. stem, name = os.path.split(file_name)
  29. folder = get_downloads_dir()
  30. path = os.path.join(folder, name)
  31. urllib.request.urlretrieve(file_name, path)
  32. return path
  33. # Check if what was passed in is already a valid path to a file
  34. if os.path.exists(file_name):
  35. return file_name
  36. # Otherwise look in local file system
  37. directories = directories or [""]
  38. extensions = extensions or [""]
  39. possible_paths = (
  40. os.path.join(directory, file_name + extension)
  41. for directory in directories
  42. for extension in extensions
  43. )
  44. for path in possible_paths:
  45. if os.path.exists(path):
  46. return path
  47. raise IOError(f"{file_name} not Found")
  48. def get_sorted_integer_files(
  49. directory: str,
  50. min_index: float = 0,
  51. max_index: float = np.inf,
  52. remove_non_integer_files: bool = False,
  53. remove_indices_greater_than: float | None = None,
  54. extension: str | None = None,
  55. ) -> list[str]:
  56. indexed_files = []
  57. for file in os.listdir(directory):
  58. if '.' in file:
  59. index_str = file[:file.index('.')]
  60. else:
  61. index_str = file
  62. full_path = os.path.join(directory, file)
  63. if index_str.isdigit():
  64. index = int(index_str)
  65. if remove_indices_greater_than is not None:
  66. if index > remove_indices_greater_than:
  67. os.remove(full_path)
  68. continue
  69. if extension is not None and not file.endswith(extension):
  70. continue
  71. if index >= min_index and index < max_index:
  72. indexed_files.append((index, file))
  73. elif remove_non_integer_files:
  74. os.remove(full_path)
  75. indexed_files.sort(key=lambda p: p[0])
  76. return list(map(lambda p: os.path.join(directory, p[1]), indexed_files))