get_file.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. # -*- coding: utf-8 -*-
  2. """Download file."""
  3. import hashlib
  4. import os
  5. import shutil
  6. import sys
  7. import tarfile
  8. import time
  9. import typing
  10. import zipfile
  11. from pathlib import Path
  12. import numpy as np
  13. import six
  14. from six.moves.urllib.error import HTTPError
  15. from six.moves.urllib.error import URLError
  16. class Progbar(object):
  17. """
  18. Displays a progress bar.
  19. :param target: Total number of steps expected, None if unknown.
  20. :param width: Progress bar width on screen.
  21. :param verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
  22. :param stateful_metrics: Iterable of string names of metrics that
  23. should *not* be averaged over time. Metrics in this list
  24. will be displayed as-is. All others will be averaged
  25. by the progbar before display.
  26. :param interval: Minimum visual progress update interval (in seconds).
  27. """
  28. def __init__(
  29. self,
  30. target,
  31. width=30,
  32. verbose=1,
  33. interval=0.05,
  34. ):
  35. """Init."""
  36. self.target = target
  37. self.width = width
  38. self.verbose = verbose
  39. self.interval = interval
  40. self._dynamic_display = ((hasattr(sys.stdout,
  41. 'isatty') and sys.stdout.isatty()
  42. ) or 'ipykernel' in sys.modules)
  43. self._total_width = 0
  44. self._seen_so_far = 0
  45. self._start = time.time()
  46. self._last_update = 0
  47. def update(self, current):
  48. """Updates the progress bar."""
  49. self._seen_so_far = current
  50. now = time.time()
  51. info = ' - {0:.0f}s'.format(now - self._start)
  52. if self.verbose == 1:
  53. if (now - self._last_update < self.interval and self.target is not
  54. None and current < self.target):
  55. return
  56. prev_total_width = self._total_width
  57. if self._dynamic_display:
  58. sys.stdout.write('\b' * prev_total_width)
  59. sys.stdout.write('\r')
  60. else:
  61. sys.stdout.write('\n')
  62. if self.target is not None:
  63. numdigits = int(np.floor(np.log10(self.target))) + 1
  64. bar = '{2:{0:d}d}/{1} ['.format(
  65. numdigits, self.target, current)
  66. prog = float(current) / self.target
  67. prog_width = int(self.width * prog)
  68. if prog_width > 0:
  69. bar += ('=' * (prog_width - 1))
  70. if current < self.target:
  71. bar += '>'
  72. else:
  73. bar += '='
  74. bar += ('.' * (self.width - prog_width))
  75. bar += ']'
  76. else:
  77. bar = '{0:7d}/Unknown'.format(current)
  78. self._total_width = len(bar)
  79. sys.stdout.write(bar)
  80. if current:
  81. time_per_unit = (now - self._start) / current
  82. else:
  83. time_per_unit = 0
  84. if self.target is not None and current < self.target:
  85. eta = int(time_per_unit * (self.target - current))
  86. if eta > 3600:
  87. eta_format = ('{0:d}:{1:02d}:{2:02d}'.format(
  88. eta // 3600, (eta % 3600) // 60, eta % 60))
  89. elif eta > 60:
  90. eta_format = '{0:d}:{1:02d}'.format(eta // 60, eta % 60)
  91. else:
  92. eta_format = '{0:d}s'.format(eta)
  93. info = ' - ETA: {0}'.format(eta_format)
  94. else:
  95. if time_per_unit >= 1:
  96. info += ' {0:.0f}s/step'.format(time_per_unit)
  97. elif time_per_unit >= 1e-3:
  98. info += ' {0:.0f}ms/step'.format(time_per_unit * 1e3)
  99. else:
  100. info += ' {0:.0f}us/step'.format(time_per_unit * 1e6)
  101. self._total_width += len(info)
  102. if prev_total_width > self._total_width:
  103. info += (' ' * (prev_total_width - self._total_width))
  104. if self.target is not None and current >= self.target:
  105. info += '\n'
  106. sys.stdout.write(info)
  107. sys.stdout.flush()
  108. elif self.verbose == 2:
  109. if self.target is None or current >= self.target:
  110. info += '\n'
  111. sys.stdout.write(info)
  112. sys.stdout.flush()
  113. self._last_update = now
  114. def _extract_archive(file_path, path='.', archive_format='auto'):
  115. """
  116. Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
  117. :param file_path: path to the archive file
  118. :param path: path to extract the archive file
  119. :param archive_format: Archive format to try for extracting the file.
  120. Options are 'auto', 'tar', 'zip', and None.
  121. 'tar' includes tar, tar.gz, and tar.bz files.
  122. The default 'auto' is ['tar', 'zip'].
  123. None or an empty list will return no matches found.
  124. :return: True if a match was found and an archive extraction was completed,
  125. False otherwise.
  126. """
  127. if archive_format is None:
  128. return False
  129. if archive_format == 'auto':
  130. archive_format = ['tar', 'zip']
  131. if isinstance(archive_format, six.string_types):
  132. archive_format = [archive_format]
  133. for archive_type in archive_format:
  134. if archive_type == 'tar':
  135. open_fn = tarfile.open
  136. is_match_fn = tarfile.is_tarfile
  137. if archive_type == 'zip':
  138. open_fn = zipfile.ZipFile
  139. is_match_fn = zipfile.is_zipfile
  140. if is_match_fn(file_path):
  141. with open_fn(file_path) as archive:
  142. try:
  143. archive.extractall(path)
  144. except (tarfile.TarError, RuntimeError,
  145. KeyboardInterrupt):
  146. if os.path.exists(path):
  147. if os.path.isfile(path):
  148. os.remove(path)
  149. else:
  150. shutil.rmtree(path)
  151. raise
  152. return True
  153. return False
  154. def get_file(
  155. fname: str = None,
  156. origin: str = None,
  157. untar: bool = False,
  158. extract: bool = False,
  159. md5_hash: typing.Any = None,
  160. file_hash: typing.Any = None,
  161. hash_algorithm: str = 'auto',
  162. archive_format: str = 'auto',
  163. cache_subdir: typing.Union[Path, str] = 'data',
  164. cache_dir: typing.Union[Path, str] = 'dataset',
  165. verbose: int = 1
  166. ) -> str:
  167. """
  168. Downloads a file from a URL if it not already in the cache.
  169. By default the file at the url `origin` is downloaded to the
  170. cache_dir `~/.project/datasets`, placed in the cache_subdir `data`,
  171. and given the filename `fname`. The final location of a file
  172. `example.txt` would therefore be `~/.project/datasets/data/example.txt`.
  173. Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
  174. Passing a hash will verify the file after download. The command line
  175. programs `shasum` and `sha256sum` can compute the hash.
  176. :param fname: Name of the file. If an absolute path `/path/to/file.txt` is
  177. specified the file will be saved at that location.
  178. :param origin: Original URL of the file.
  179. :param untar: Deprecated in favor of 'extract'. Boolean, whether the file
  180. should be decompressed.
  181. :param md5_hash: Deprecated in favor of 'file_hash'. md5 hash of the file
  182. for verification.
  183. :param file_hash: The expected hash string of the file after download.
  184. The sha256 and md5 hash algorithms are both supported.
  185. :param cache_subdir: Subdirectory under the cache dir where the file is
  186. saved. If an absolute path `/path/to/folder` is specified the file
  187. will be saved at that location.
  188. :param hash_algorithm: Select the hash algorithm to verify the file.
  189. options are 'md5', 'sha256', and 'auto'. The default 'auto' detects
  190. the hash algorithm in use.
  191. :papram extract: True tries extracting the file as an Archive, like tar
  192. or zip.
  193. :param archive_format: Archive format to try for extracting the file.
  194. Options are 'auto', 'tar', 'zip', and None.
  195. 'tar' includes tar, tar.gz, and tar.bz files.
  196. The default 'auto' is ['tar', 'zip'].
  197. None or an empty list will return no matches found.
  198. :param cache_dir: Location to store cached files, when None it defaults to
  199. the [project.USER_DATA_DIR](~/.project/datasets).
  200. :param verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
  201. :return: Path to the downloaded file.
  202. """
  203. if md5_hash is not None and file_hash is None:
  204. file_hash = md5_hash
  205. hash_algorithm = 'md5'
  206. datadir_base = os.path.expanduser(cache_dir)
  207. if not os.access(datadir_base, os.W_OK):
  208. datadir_base = os.path.join('/tmp', '.project')
  209. datadir = os.path.join(datadir_base, cache_subdir)
  210. if not os.path.exists(datadir):
  211. os.makedirs(datadir)
  212. if untar:
  213. untar_fpath = os.path.join(datadir, fname)
  214. fpath = untar_fpath + '.tar.gz'
  215. else:
  216. fpath = os.path.join(datadir, fname)
  217. download = False
  218. if os.path.exists(fpath):
  219. if file_hash is not None:
  220. if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
  221. print('A local file was found, but it seems to be '
  222. 'incomplete or outdated because the file hash '
  223. 'does not match the original value of file_hash.'
  224. ' We will re-download the data.')
  225. download = True
  226. else:
  227. download = True
  228. if download:
  229. print('Downloading data from', origin)
  230. class ProgressTracker(object):
  231. progbar = None
  232. def dl_progress(count, block_size, total_size):
  233. if ProgressTracker.progbar is None:
  234. if total_size == -1:
  235. total_size = None
  236. ProgressTracker.progbar = Progbar(
  237. target=total_size, verbose=verbose)
  238. else:
  239. ProgressTracker.progbar.update(count * block_size)
  240. error_msg = 'URL fetch failure on {} : {} -- {}'
  241. try:
  242. try:
  243. from six.moves.urllib.request import urlretrieve
  244. urlretrieve(origin, fpath, dl_progress)
  245. except HTTPError as e:
  246. raise Exception(error_msg.format(origin, e.code, e.msg))
  247. except URLError as e:
  248. raise Exception(error_msg.format(origin, e.errno, e.reason))
  249. except (Exception, KeyboardInterrupt):
  250. if os.path.exists(fpath):
  251. os.remove(fpath)
  252. raise
  253. ProgressTracker.progbar = None
  254. if untar:
  255. if not os.path.exists(untar_fpath):
  256. _extract_archive(fpath, datadir, archive_format='tar')
  257. return untar_fpath
  258. if extract:
  259. _extract_archive(fpath, datadir, archive_format)
  260. return fpath
  261. def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
  262. """
  263. Validates a file against a sha256 or md5 hash.
  264. :param fpath: path to the file being validated
  265. :param file_hash: The expected hash string of the file.
  266. The sha256 and md5 hash algorithms are both supported.
  267. :param algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
  268. The default 'auto' detects the hash algorithm in use.
  269. :param chunk_size: Bytes to read at a time, important for large files.
  270. :return: Whether the file is valid.
  271. """
  272. if ((algorithm == 'sha256') or (algorithm == 'auto' and len(
  273. file_hash) == 64)):
  274. hasher = 'sha256'
  275. else:
  276. hasher = 'md5'
  277. if str(hash_file(fpath, hasher, chunk_size)) == str(file_hash):
  278. return True
  279. else:
  280. return False
  281. def hash_file(fpath, algorithm='sha256', chunk_size=65535):
  282. """
  283. Calculates a file sha256 or md5 hash.
  284. :param fpath: path to the file being validated
  285. :param algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
  286. The default 'auto' detects the hash algorithm in use.
  287. :param chunk_size: Bytes to read at a time, important for large files.
  288. :return: The file hash.
  289. """
  290. if algorithm == 'sha256':
  291. hasher = hashlib.sha256()
  292. else:
  293. hasher = hashlib.md5()
  294. with open(fpath, 'rb') as fpath_file:
  295. for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
  296. hasher.update(chunk)
  297. return hasher.hexdigest()