#!/usr/bin/env python3 # Copyright 2021 The KubeEdge Authors. # Copyright 2020 kubeflow.org. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # modify from https://github.com/kubeflow/kfserving/blob/master/python/kfserving/kfserving/storage.py # noqa import concurrent.futures import glob import gzip import json import logging import mimetypes import os import re import sys import shutil import tempfile import tarfile import zipfile import minio import requests from urllib.parse import urlparse _S3_PREFIX = "s3://" _OBS_PREFIX = "obs://" _LOCAL_PREFIX = "file://" _URI_RE = "https?://(.+)/(.+)" _HTTP_PREFIX = "http(s)://" _HEADERS_SUFFIX = "-headers" SUPPORT_PROTOCOLS = (_OBS_PREFIX, _S3_PREFIX, _LOCAL_PREFIX, _HTTP_PREFIX) LOG = logging.getLogger(__name__) def setup_logger(): format = '%(asctime)s %(levelname)s %(funcName)s:%(lineno)s] %(message)s' logging.basicConfig(format=format) LOG.setLevel(os.getenv('LOG_LEVEL', 'INFO')) def _normalize_uri(uri: str) -> str: for src, dst in [ ("/", _LOCAL_PREFIX), (_OBS_PREFIX, _S3_PREFIX) ]: if uri.startswith(src): return uri.replace(src, dst, 1) return uri def download(uri: str, out_dir: str = None) -> str: """ Download the uri to local directory. Support procotols: http, s3. Note when uri ends with .tar.gz/.tar/.zip, this will extract it """ LOG.info("Copying contents of %s to local %s", uri, out_dir) uri = _normalize_uri(uri) if not os.path.exists(out_dir): os.makedirs(out_dir) if uri.startswith(_S3_PREFIX): download_s3(uri, out_dir) elif uri.startswith(_LOCAL_PREFIX): download_local(uri, out_dir) elif re.search(_URI_RE, uri): download_from_uri(uri, out_dir) else: raise Exception("Cannot recognize storage type for %s.\n" "%r are the current available storage type." % (uri, SUPPORT_PROTOCOLS)) LOG.info("Successfully copied %s to %s", uri, out_dir) return out_dir def indirect_download(indirect_uri: str, out_dir: str = None) -> str: """ Download the uri to local directory. Support procotols: http, s3. Note when uri ends with .tar.gz/.tar/.zip, this will extract it """ tmpdir = tempfile.mkdtemp() download(indirect_uri, tmpdir) files = os.listdir(tmpdir) if len(files) != 1: raise Exception("indirect url %s should be file, not directory" % indirect_uri) download_files = set() with open(os.path.join(tmpdir, files[0])) as f: base_uri = None for line_no, line in enumerate(f): line = line.strip() if line.startswith('#'): continue if line: if base_uri is None: base_uri = line else: file_name = line download_files.add(file_name) if not download_files: LOG.info("no files to download for indirect url %s", indirect_uri) return if not os.path.exists(out_dir): os.makedirs(out_dir) LOG.info("To download %s files IN-DIRECT %s to %s", len(download_files), indirect_uri, out_dir) uri = _normalize_uri(base_uri) # only support s3 for indirect download if uri.startswith(_S3_PREFIX): download_s3_with_multi_files(download_files, uri, out_dir) else: LOG.warning("unsupported %s for indirect url %s, skipped", uri, indirect_uri) return LOG.info("Successfully download files IN-DIRECT %s to %s", indirect_uri, out_dir) return def download_s3(uri, out_dir: str): client = _create_minio_client() count = _download_s3(client, uri, out_dir) if count == 0: raise RuntimeError("Failed to fetch files." "The path %s does not exist." % (uri)) LOG.info("downloaded %d files for %s.", count, uri) def download_s3_with_multi_files(download_files, base_uri, base_out_dir): client = _create_minio_client() total_count = 0 with concurrent.futures.ThreadPoolExecutor() as executor: todos = [] for dfile in set(download_files): dir_ = os.path.dirname(dfile) uri = base_uri.rstrip("/") + "/" + dfile out_dir = os.path.join(base_out_dir, dir_) todos.append(executor.submit(_download_s3, client, uri, out_dir)) for done in concurrent.futures.as_completed(todos): count = done.result() if count == 0: LOG.warning("failed to download %s in base uri(%s)", dfile, base_uri) continue total_count += count LOG.info("downloaded %d files for base_uri %s to local dir %s.", total_count, base_uri, base_out_dir) def _download_s3(client, uri, out_dir): """ The function downloads specified file or folder to local directory address. this function supports: 1. when downloading the specified file, keep the name of the file itself. 2. when downloading the specified folder, keep the name of the folder itself. Parameters: client: s3 client s3_url(string): url in s3, e.g. file url: s3://dev/data/data.txt, directory url: s3://dev/data out_dir(string): local directory address, e.g. /tmp/data/ Returns: int: files of number in s3_url """ bucket_args = uri.replace(_S3_PREFIX, "", 1).split("/", 1) bucket_name = bucket_args[0] bucket_path = len(bucket_args) > 1 and bucket_args[1] or "" objects = client.list_objects(bucket_name, prefix=bucket_path, recursive=True, use_api_v1=True) count = 0 root_path = os.path.split(os.path.normpath(bucket_path))[0] for obj in objects: # Replace any prefix from the object key with out_dir subdir_object_key = obj.object_name[len(root_path):].strip("/") # fget_object handles directory creation if does not exist if not obj.is_dir: local_file = os.path.join( out_dir, subdir_object_key or os.path.basename(obj.object_name) ) LOG.debug("downloading count:%d, file:%s", count, subdir_object_key) client.fget_object(bucket_name, obj.object_name, local_file) _extract_compress(local_file, out_dir) count += 1 return count def download_local(uri, out_dir=None): local_path = uri.replace(_LOCAL_PREFIX, "/", 1) if not os.path.exists(local_path): raise RuntimeError("Local path %s does not exist." % (uri)) if out_dir is None: return local_path elif not os.path.isdir(out_dir): os.makedirs(out_dir) if os.path.isdir(local_path): local_path = os.path.join(local_path, "*") for src in glob.glob(local_path): _, tail = os.path.split(src) dest_path = os.path.join(out_dir, tail) LOG.info("Linking: %s to %s", src, dest_path) os.symlink(src, dest_path) return out_dir def download_from_uri(uri, out_dir=None): url = urlparse(uri) filename = os.path.basename(url.path) mimetype, encoding = mimetypes.guess_type(url.path) local_path = os.path.join(out_dir, filename) if filename == '': raise ValueError('No filename contained in URI: %s' % (uri)) # Get header information from host url headers = {} host_uri = url.hostname headers_json = os.getenv(host_uri + _HEADERS_SUFFIX, "{}") headers = json.loads(headers_json) with requests.get(uri, stream=True, headers=headers) as response: if response.status_code != 200: raise RuntimeError("URI: %s returned a %s response code." % (uri, response.status_code)) if encoding == 'gzip': stream = gzip.GzipFile(fileobj=response.raw) local_path = os.path.join(out_dir, f'{filename}.tar') else: stream = response.raw with open(local_path, 'wb') as out: shutil.copyfileobj(stream, out) return _extract_compress(local_path, out_dir) def _extract_compress(local_path, out_dir): mimetype, encoding = mimetypes.guess_type(local_path) if mimetype in ["application/x-tar", "application/zip"]: if mimetype == "application/x-tar": archive = tarfile.open(local_path, 'r', encoding='utf-8') else: archive = zipfile.ZipFile(local_path, 'r') archive.extractall(out_dir) archive.close() os.remove(local_path) return out_dir def _create_minio_client(): url = urlparse(os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com")) use_ssl = url.scheme == 'https' if url.scheme else True return minio.Minio( url.netloc, access_key=os.getenv("ACCESS_KEY_ID", ""), secret_key=os.getenv("SECRET_ACCESS_KEY", ""), secure=use_ssl ) def main(): setup_logger() if len(sys.argv) < 2 or len(sys.argv) % 2 == 0: LOG.error("Usage: download.py " "src_uri dest_path [src_uri dest_path]") sys.exit(1) indirect_mark = os.getenv("INDIRECT_URL_MARK", "@") for i in range(1, len(sys.argv)-1, 2): src_uri = sys.argv[i] dest_path = sys.argv[i+1] LOG.info("Initializing, args: src_uri [%s] dest_path [%s]" % (src_uri, dest_path)) if dest_path.startswith(indirect_mark): indirect_download(src_uri, dest_path[len(indirect_mark):]) else: download(src_uri, dest_path) main()