"""Utility of downloading"""
import bz2
import gzip
import hashlib
import logging
import os
import ssl
import tarfile
import urllib
import urllib.error
import urllib.request
import zipfile
from copy import deepcopy
from typing import Optional

from tqdm import tqdm

from .path import detect_file_type

_logger = logging.getLogger(__name__)
# The default root directory where we save downloaded files.
# Use the get/set accessors below to read/write this variable.
_DEFAULT_DOWNLOAD_ROOT = os.path.join(os.path.expanduser("~"), ".mindspore")


def get_default_download_root() -> str:
    """Return the current default download root directory.

    The root is a plain (immutable) string, so it is returned as-is;
    the previous ``deepcopy`` was a no-op for ``str`` values.
    """
    return _DEFAULT_DOWNLOAD_ROOT


def set_default_download_root(path: str) -> None:
    """Set the default download root directory to ``path``."""
    global _DEFAULT_DOWNLOAD_ROOT
    _DEFAULT_DOWNLOAD_ROOT = path
class DownLoad:
    """Utility class for downloading files over HTTP(S) and extracting
    the common archive formats (tar/zip, optionally bz2/gzip compressed).
    """

    # Browser-like User-Agent; some servers reject urllib's default agent.
    USER_AGENT: str = (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/92.0.4515.131 Safari/537.36"
    )

    @staticmethod
    def calculate_md5(file_path: str, chunk_size: int = 1024 * 1024) -> str:
        """Return the hex MD5 digest of the file at ``file_path``.

        The file is read in ``chunk_size`` blocks so arbitrarily large
        files can be hashed with constant memory.
        """
        md5 = hashlib.md5()
        with open(file_path, "rb") as fp:
            for chunk in iter(lambda: fp.read(chunk_size), b""):
                md5.update(chunk)
        return md5.hexdigest()

    def check_md5(self, file_path: str, md5: Optional[str] = None) -> bool:
        """Return True if ``md5`` equals the MD5 digest of ``file_path``.

        A ``md5`` of None never matches (returns False).
        """
        return md5 == self.calculate_md5(file_path)

    @staticmethod
    def extract_tar(from_path: str, to_path: Optional[str] = None, compression: Optional[str] = None) -> None:
        """Extract a tar archive, optionally compressed (e.g. ``".gz"``).

        ``compression`` is a dot-prefixed suffix such as ``".gz"``; its
        leading dot is stripped to build the tarfile mode (``"r:gz"``).
        """
        # NOTE(review): extractall performs no member-path sanitization;
        # only extract archives that come from trusted sources.
        with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
            tar.extractall(to_path)

    @staticmethod
    def extract_zip(from_path: str, to_path: Optional[str] = None, compression: Optional[str] = None) -> None:
        """Extract a zip archive (BZIP2-compressed when ``compression`` is set)."""
        compression_mode = zipfile.ZIP_BZIP2 if compression else zipfile.ZIP_STORED
        with zipfile.ZipFile(from_path, "r", compression=compression_mode) as zip_file:
            zip_file.extractall(to_path)

    def extract_archive(self, from_path: str, to_path: Optional[str] = None) -> str:
        """Extract an archive (or decompress a plain compressed file).

        When ``from_path`` is not an archive (e.g. ``"foo.txt.gz"``), it is
        decompressed to the same name without the compression suffix.
        Returns the path the content was extracted/decompressed to.
        """
        archive_extractors = {
            ".tar": self.extract_tar,
            ".zip": self.extract_zip,
        }
        compress_file_open = {
            ".bz2": bz2.open,
            ".gz": gzip.open,
        }

        if not to_path:
            to_path = os.path.dirname(from_path)

        suffix, archive_type, compression = detect_file_type(from_path)  # pylint: disable=unused-variable
        if not archive_type:
            # Strip the suffix from the END of the path; ``str.replace``
            # would wrongly alter an identical substring occurring earlier
            # in the path (e.g. "/data.gz/file.gz").
            to_path = from_path[: -len(suffix)]
            compress = compress_file_open[compression]
            with compress(from_path, "rb") as rf, open(to_path, "wb") as wf:
                wf.write(rf.read())
            return to_path

        extractor = archive_extractors[archive_type]
        extractor(from_path, to_path, compression)
        return to_path

    def download_file(self, url: str, file_path: str, chunk_size: int = 1024):
        """Download ``url`` to ``file_path``, streaming ``chunk_size`` blocks."""
        headers = {"User-Agent": self.USER_AGENT}
        # Lazy %-formatting: the message is only built if INFO is enabled.
        _logger.info("Downloading from %s to %s ...", url, file_path)
        with open(file_path, "wb") as f:
            request = urllib.request.Request(url, headers=headers)
            with urllib.request.urlopen(request) as response:
                with tqdm(total=response.length, unit="B") as pbar:
                    # iter(..., b"") already stops on EOF, so no explicit
                    # empty-chunk break is needed.
                    for chunk in iter(lambda: response.read(chunk_size), b""):
                        # Advance by the bytes actually read so the bar does
                        # not overshoot on the final (possibly short) chunk.
                        pbar.update(len(chunk))
                        f.write(chunk)

    def download_url(
        self,
        url: str,
        path: Optional[str] = None,
        filename: Optional[str] = None,
        md5: Optional[str] = None,
    ) -> None:
        """Download a file from ``url`` into directory ``path``.

        The download is skipped when the target file already exists and,
        if ``md5`` is given, matches the checksum. On https failure it
        retries over plain http, then with certificate verification
        disabled as a last resort.
        """
        if path is None:
            path = get_default_download_root()
        path = os.path.expanduser(path)
        os.makedirs(path, exist_ok=True)

        if not filename:
            filename = os.path.basename(url)
        file_path = os.path.join(path, filename)

        # Skip the download if the file already exists (and the checksum,
        # when provided, matches).
        if os.path.isfile(file_path):
            if not md5 or self.check_md5(file_path, md5):
                return

        try:
            self.download_file(url, file_path)
        except (urllib.error.URLError, IOError):
            if not url.startswith("https"):
                raise
            # Retry over plain http, then with certificate verification
            # disabled (last resort for hosts with broken certificates).
            url = url.replace("https", "http")
            try:
                self.download_file(url, file_path)
            except (urllib.error.URLError, IOError):
                # pylint: disable=protected-access
                ssl._create_default_https_context = ssl._create_unverified_context
                try:
                    self.download_file(url, file_path)
                finally:
                    # Restore certificate verification globally even when the
                    # final attempt fails (the original leaked the patched
                    # context on failure).
                    ssl._create_default_https_context = ssl.create_default_context

    def download_and_extract_archive(
        self,
        url: str,
        download_path: Optional[str] = None,
        extract_path: Optional[str] = None,
        filename: Optional[str] = None,
        md5: Optional[str] = None,
        remove_finished: bool = False,
    ) -> None:
        """Download an archive from ``url`` and extract it.

        Args:
            url: Source URL of the archive.
            download_path: Directory to store the archive (defaults to the
                module download root).
            extract_path: Directory to extract into (defaults to the
                archive's own directory).
            filename: Local file name (defaults to the URL basename).
            md5: Optional checksum to validate an already-downloaded file.
            remove_finished: Delete the archive after extraction.
        """
        if download_path is None:
            download_path = get_default_download_root()
        download_path = os.path.expanduser(download_path)

        if not filename:
            filename = os.path.basename(url)

        self.download_url(url, download_path, filename, md5)

        archive = os.path.join(download_path, filename)
        self.extract_archive(archive, extract_path)

        if remove_finished:
            os.remove(archive)