You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

download.py 5.9 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. """Utility of downloading"""
  2. import bz2
  3. import gzip
  4. import hashlib
  5. import logging
  6. import os
  7. import ssl
  8. import tarfile
  9. import urllib
  10. import urllib.error
  11. import urllib.request
  12. import zipfile
  13. from copy import deepcopy
  14. from typing import Optional
  15. from tqdm import tqdm
  16. from .path import detect_file_type
  17. _logger = logging.getLogger(__name__)
  18. # The default root directory where we save downloaded files.
  19. # Use Get/Set to R/W this variable.
  20. _DEFAULT_DOWNLOAD_ROOT = os.path.join(os.path.expanduser("~"), ".mindspore")
  21. def get_default_download_root():
  22. return deepcopy(_DEFAULT_DOWNLOAD_ROOT)
  23. def set_default_download_root(path):
  24. global _DEFAULT_DOWNLOAD_ROOT
  25. _DEFAULT_DOWNLOAD_ROOT = path
  26. class DownLoad:
  27. """Base utility class for downloading."""
  28. USER_AGENT: str = (
  29. "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) "
  30. "Chrome/92.0.4515.131 Safari/537.36"
  31. )
  32. @staticmethod
  33. def calculate_md5(file_path: str, chunk_size: int = 1024 * 1024) -> str:
  34. """Calculate md5 value."""
  35. md5 = hashlib.md5()
  36. with open(file_path, "rb") as fp:
  37. for chunk in iter(lambda: fp.read(chunk_size), b""):
  38. md5.update(chunk)
  39. return md5.hexdigest()
  40. def check_md5(self, file_path: str, md5: Optional[str] = None) -> bool:
  41. """Check md5 value."""
  42. return md5 == self.calculate_md5(file_path)
  43. @staticmethod
  44. def extract_tar(from_path: str, to_path: Optional[str] = None, compression: Optional[str] = None) -> None:
  45. """Extract tar format file."""
  46. with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
  47. tar.extractall(to_path)
  48. @staticmethod
  49. def extract_zip(from_path: str, to_path: Optional[str] = None, compression: Optional[str] = None) -> None:
  50. """Extract zip format file."""
  51. compression_mode = zipfile.ZIP_BZIP2 if compression else zipfile.ZIP_STORED
  52. with zipfile.ZipFile(from_path, "r", compression=compression_mode) as zip_file:
  53. zip_file.extractall(to_path)
  54. def extract_archive(self, from_path: str, to_path: str = None) -> str:
  55. """Extract and archive from path to path."""
  56. archive_extractors = {
  57. ".tar": self.extract_tar,
  58. ".zip": self.extract_zip,
  59. }
  60. compress_file_open = {
  61. ".bz2": bz2.open,
  62. ".gz": gzip.open,
  63. }
  64. if not to_path:
  65. to_path = os.path.dirname(from_path)
  66. suffix, archive_type, compression = detect_file_type(from_path) # pylint: disable=unused-variable
  67. if not archive_type:
  68. to_path = from_path.replace(suffix, "")
  69. compress = compress_file_open[compression]
  70. with compress(from_path, "rb") as rf, open(to_path, "wb") as wf:
  71. wf.write(rf.read())
  72. return to_path
  73. extractor = archive_extractors[archive_type]
  74. extractor(from_path, to_path, compression)
  75. return to_path
  76. def download_file(self, url: str, file_path: str, chunk_size: int = 1024):
  77. """Download a file."""
  78. # Define request headers.
  79. headers = {"User-Agent": self.USER_AGENT}
  80. _logger.info(f"Downloading from {url} to {file_path} ...")
  81. with open(file_path, "wb") as f:
  82. request = urllib.request.Request(url, headers=headers)
  83. with urllib.request.urlopen(request) as response:
  84. with tqdm(total=response.length, unit="B") as pbar:
  85. for chunk in iter(lambda: response.read(chunk_size), b""):
  86. if not chunk:
  87. break
  88. pbar.update(chunk_size)
  89. f.write(chunk)
  90. def download_url(
  91. self,
  92. url: str,
  93. path: Optional[str] = None,
  94. filename: Optional[str] = None,
  95. md5: Optional[str] = None,
  96. ) -> None:
  97. """Download a file from a url and place it in root."""
  98. if path is None:
  99. path = get_default_download_root()
  100. path = os.path.expanduser(path)
  101. os.makedirs(path, exist_ok=True)
  102. if not filename:
  103. filename = os.path.basename(url)
  104. file_path = os.path.join(path, filename)
  105. # Check if the file is exists.
  106. if os.path.isfile(file_path):
  107. if not md5 or self.check_md5(file_path, md5):
  108. return
  109. # Download the file.
  110. try:
  111. self.download_file(url, file_path)
  112. except (urllib.error.URLError, IOError) as e:
  113. if url.startswith("https"):
  114. url = url.replace("https", "http")
  115. try:
  116. self.download_file(url, file_path)
  117. except (urllib.error.URLError, IOError):
  118. # pylint: disable=protected-access
  119. ssl._create_default_https_context = ssl._create_unverified_context
  120. self.download_file(url, file_path)
  121. ssl._create_default_https_context = ssl.create_default_context
  122. else:
  123. raise e
  124. def download_and_extract_archive(
  125. self,
  126. url: str,
  127. download_path: Optional[str] = None,
  128. extract_path: Optional[str] = None,
  129. filename: Optional[str] = None,
  130. md5: Optional[str] = None,
  131. remove_finished: bool = False,
  132. ) -> None:
  133. """Download and extract archive."""
  134. if download_path is None:
  135. download_path = get_default_download_root()
  136. download_path = os.path.expanduser(download_path)
  137. if not filename:
  138. filename = os.path.basename(url)
  139. self.download_url(url, download_path, filename, md5)
  140. archive = os.path.join(download_path, filename)
  141. self.extract_archive(archive, extract_path)
  142. if remove_finished:
  143. os.remove(archive)

基于MindSpore的多模态股票价格预测系统研究 Informer,LSTM,RNN