#!/usr/bin/env python3

# Copyright 2021 The KubeEdge Authors.
# Copyright 2020 kubeflow.org.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Modified from https://github.com/kubeflow/kfserving/blob/master/python/kfserving/kfserving/storage.py  # noqa

import glob
import gzip
import json
import logging
import mimetypes
import os
import re
import shutil
import sys
import tarfile
import tempfile
import zipfile
from urllib.parse import urlparse

import minio
import requests

_S3_PREFIX = "s3://"
_OBS_PREFIX = "obs://"
_LOCAL_PREFIX = "file://"
_URI_RE = "https?://(.+)/(.+)"
# Display-only; HTTP(S) URIs are matched with _URI_RE, not this prefix.
_HTTP_PREFIX = "http(s)://"
_HEADERS_SUFFIX = "-headers"

SUPPORT_PROTOCOLS = (_OBS_PREFIX, _S3_PREFIX, _LOCAL_PREFIX, _HTTP_PREFIX)


def _normalize_uri(uri: str) -> str:
    """Rewrite bare paths and obs:// URIs into schemes download() handles."""
    for src, dst in [
            # Use "file:///" (not "file://") so the path keeps its
            # leading slash and stays absolute once the scheme is stripped.
            ("/", _LOCAL_PREFIX + "/"),
            # OBS is S3-compatible, so route it through the S3 client.
            (_OBS_PREFIX, _S3_PREFIX)
    ]:
        if uri.startswith(src):
            return uri.replace(src, dst, 1)
    return uri
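
# For illustration (hypothetical inputs):
#   _normalize_uri("obs://bucket/model") -> "s3://bucket/model"
#   _normalize_uri("/models/model.pb")   -> "file:///models/model.pb"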


def download(uri: str, out_dir: str = None) -> str:
    """Download the uri to a local directory.

    Supported protocols: http(s), s3, obs and local files.
    Note: when uri ends with .tar.gz/.tar/.zip, it is extracted as well.
    """
    logging.info("Copying contents of %s to local %s", uri, out_dir)

    uri = _normalize_uri(uri)

    if out_dir is None:
        out_dir = tempfile.mkdtemp()
    elif not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if uri.startswith(_S3_PREFIX):
        _download_s3(uri, out_dir)
    elif uri.startswith(_LOCAL_PREFIX):
        _download_local(uri, out_dir)
    elif re.search(_URI_RE, uri):
        _download_from_uri(uri, out_dir)
    else:
        raise ValueError("Cannot recognize storage type for %s.\n"
                         "%r are the currently supported storage types." %
                         (uri, SUPPORT_PROTOCOLS))

    logging.info("Successfully copied %s to %s", uri, out_dir)
    return out_dir
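
# Example usage (bucket and paths are hypothetical):
#   download("s3://models/yolo.tar.gz", "/tmp/model")  # fetched + extracted
#   download("/models/yolo.pb", "/tmp/model")          # symlinked locally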


def _download_s3(uri, out_dir: str):
    client = _create_minio_client()
    bucket_args = uri.replace(_S3_PREFIX, "", 1).split("/", 1)
    bucket_name = bucket_args[0]
    bucket_path = bucket_args[1] if len(bucket_args) > 1 else ""
    objects = client.list_objects(bucket_name,
                                  prefix=bucket_path,
                                  recursive=True)
    count = 0
    for obj in objects:
        # Replace any prefix from the object key with out_dir
        subdir_object_key = obj.object_name[len(bucket_path):].strip("/")
        # fget_object handles directory creation if it does not exist
        if not obj.is_dir:
            local_file = os.path.join(
                out_dir,
                subdir_object_key or os.path.basename(obj.object_name)
            )
            client.fget_object(bucket_name, obj.object_name, local_file)
            _extract_compress(local_file, out_dir)
            count += 1
    if count == 0:
        raise RuntimeError("Failed to fetch model. "
                           "The path or model %s does not exist." % uri)
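
# For example (hypothetical bucket), "s3://models/v1/" downloads every
# object under prefix "v1/" of bucket "models" into out_dir, extracting
# any archives it finds along the way.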


def _download_local(uri, out_dir=None):
    local_path = uri.replace(_LOCAL_PREFIX, "", 1)
    if not os.path.exists(local_path):
        raise RuntimeError("Local path %s does not exist." % uri)

    if out_dir is None:
        return local_path
    elif not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    if os.path.isdir(local_path):
        local_path = os.path.join(local_path, "*")

    # Symlink instead of copy: the sources must stay available for as
    # long as out_dir is in use.
    for src in glob.glob(local_path):
        _, tail = os.path.split(src)
        dest_path = os.path.join(out_dir, tail)
        logging.info("Linking: %s to %s", src, dest_path)
        os.symlink(src, dest_path)
    return out_dir


def _download_from_uri(uri, out_dir=None):
    url = urlparse(uri)
    filename = os.path.basename(url.path)
    _, encoding = mimetypes.guess_type(url.path)
    local_path = os.path.join(out_dir, filename)

    if filename == '':
        raise ValueError('No filename contained in URI: %s' % uri)

    # Extra request headers for this host come from an environment
    # variable named "<hostname>-headers" holding a JSON object.
    host_uri = url.hostname
    headers_json = os.getenv(host_uri + _HEADERS_SUFFIX, "{}")
    headers = json.loads(headers_json)

    with requests.get(uri, stream=True, headers=headers) as response:
        if response.status_code != 200:
            raise RuntimeError("URI: %s returned a %s response code." %
                               (uri, response.status_code))

        if encoding == 'gzip':
            # The filename suggests gzip (e.g. .tar.gz): decompress while
            # saving, so "model.tar.gz" lands as "model.tar.gz.tar".
            stream = gzip.GzipFile(fileobj=response.raw)
            local_path = os.path.join(out_dir, f'{filename}.tar')
        else:
            stream = response.raw
        with open(local_path, 'wb') as out:
            shutil.copyfileobj(stream, out)
    return _extract_compress(local_path, out_dir)
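
# Illustrative header configuration (hostname and token are hypothetical):
#   export example.com-headers='{"Authorization": "Bearer <token>"}'
# makes _download_from_uri send that header when fetching from example.com.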


def _extract_compress(local_path, out_dir):
    mimetype, _ = mimetypes.guess_type(local_path)
    if mimetype in ["application/x-tar", "application/zip"]:
        if mimetype == "application/x-tar":
            # Mode 'r' auto-detects compression, so .tar.gz works too.
            archive = tarfile.open(local_path, 'r', encoding='utf-8')
        else:
            archive = zipfile.ZipFile(local_path, 'r')
        archive.extractall(out_dir)
        archive.close()
        os.remove(local_path)

    return out_dir


def _create_minio_client():
    # urlparse only fills netloc when the URL carries a scheme, hence
    # the http(s):// prefix expected in AWS_ENDPOINT_URL.
    url = urlparse(os.getenv("AWS_ENDPOINT_URL", "http://s3.amazonaws.com"))
    use_ssl = (url.scheme == 'https' if url.scheme
               else os.getenv("S3_USE_HTTPS", "true") == "true")
    return minio.Minio(
        url.netloc,
        access_key=os.getenv("AWS_ACCESS_KEY_ID", ""),
        secret_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
        region=os.getenv("AWS_REGION", ""),
        secure=use_ssl
    )
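
# _create_minio_client reads its settings from the environment; for
# example (endpoint and credentials are hypothetical):
#   export AWS_ENDPOINT_URL="https://minio.example.com:9000"
#   export AWS_ACCESS_KEY_ID="minio-user"
#   export AWS_SECRET_ACCESS_KEY="minio-secret"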


def main():
    if len(sys.argv) < 2 or len(sys.argv) % 2 == 0:
        logging.error("Usage: initializer-entrypoint "
                      "src_uri dest_path [src_uri dest_path]")
        sys.exit(1)

    for i in range(1, len(sys.argv) - 1, 2):
        src_uri = sys.argv[i]
        dest_path = sys.argv[i + 1]

        logging.info("Initializing, args: src_uri [%s] dest_path [%s]",
                     src_uri, dest_path)
        download(src_uri, dest_path)


if __name__ == '__main__':
    # Without basicConfig the INFO-level progress messages above would
    # be dropped by the root logger's default WARNING threshold.
    logging.basicConfig(level=logging.INFO)
    main()
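
# Typical invocation (URIs and destination paths are illustrative):
#   initializer-entrypoint s3://models/net.zip /models/net \
#       https://example.com/extra.tar.gz /models/extra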