#!/usr/bin/env python3

# Copyright 2021 The KubeEdge Authors.
# Copyright 2020 kubeflow.org.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Modified from https://github.com/kubeflow/kfserving/blob/master/python/kfserving/kfserving/storage.py  # noqa

import glob
import gzip
import json
import logging
import mimetypes
import os
import re
import shutil
import sys
import tarfile
import tempfile
import zipfile
from urllib.parse import urlparse

import minio
import requests

_S3_PREFIX = "s3://"
_OBS_PREFIX = "obs://"
_LOCAL_PREFIX = "file://"
_URI_RE = "https?://(.+)/(.+)"
_HTTP_PREFIX = "http(s)://"
_HEADERS_SUFFIX = "-headers"

SUPPORT_PROTOCOLS = (_OBS_PREFIX, _S3_PREFIX, _LOCAL_PREFIX, _HTTP_PREFIX)


def _normalize_uri(uri: str) -> str:
    """Map bare absolute paths and obs:// URIs onto supported schemes.

    Bare paths become file:// URIs, and OBS (which is S3-compatible)
    is routed through the s3 code path.
    """
    if uri.startswith("/"):
        # Prepend rather than replace, so the leading slash survives:
        # stripping the file:// prefix later must yield the original
        # absolute path, e.g. file:///models/demo -> /models/demo.
        return _LOCAL_PREFIX + uri
    if uri.startswith(_OBS_PREFIX):
        return uri.replace(_OBS_PREFIX, _S3_PREFIX, 1)
    return uri


def download(uri: str, out_dir: str = None) -> str:
    """Download the uri to a local directory.

    Supported protocols: http(s), s3, obs and local file.
    Note: when the uri ends with .tar.gz/.tar/.zip, the archive is
    extracted into out_dir.
    """
    logging.info("Copying contents of %s to local %s", uri, out_dir)

    uri = _normalize_uri(uri)

    if out_dir is None:
        # Fall back to a scratch directory so os.makedirs below never
        # receives None.
        out_dir = tempfile.mkdtemp()

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if uri.startswith(_S3_PREFIX):
        _download_s3(uri, out_dir)
    elif uri.startswith(_LOCAL_PREFIX):
        _download_local(uri, out_dir)
    elif re.search(_URI_RE, uri):
        _download_from_uri(uri, out_dir)
    else:
        raise ValueError("Cannot recognize storage type for %s.\n"
                         "%r are the currently supported storage types." %
                         (uri, SUPPORT_PROTOCOLS))

    logging.info("Successfully copied %s to %s", uri, out_dir)
    return out_dir
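

# A minimal usage sketch (hypothetical URIs and paths), showing two of the
# code paths above: an S3 archive that is fetched and unpacked, and a plain
# HTTP file:
#
#     download("s3://models/demo.tar.gz", "/models")
#     download("https://example.com/weights.bin", "/weights")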


def _download_s3(uri, out_dir: str):
    client = _create_minio_client()
    # Split "s3://bucket/path/to/key" into bucket name and key prefix.
    bucket_args = uri.replace(_S3_PREFIX, "", 1).split("/", 1)
    bucket_name = bucket_args[0]
    bucket_path = bucket_args[1] if len(bucket_args) > 1 else ""
    objects = client.list_objects(bucket_name,
                                  prefix=bucket_path,
                                  recursive=True)
    count = 0
    for obj in objects:
        # Replace any prefix from the object key with out_dir
        subdir_object_key = obj.object_name[len(bucket_path):].strip("/")
        # fget_object handles directory creation if it does not exist
        if not obj.is_dir:
            local_file = os.path.join(
                out_dir,
                subdir_object_key or os.path.basename(obj.object_name)
            )
            client.fget_object(bucket_name, obj.object_name, local_file)
            _extract_compress(local_file, out_dir)
            count += 1
    if count == 0:
        raise RuntimeError("Failed to fetch model. "
                           "The path or model %s does not exist." % uri)
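

# Key-split example for _download_s3 (hypothetical URI):
# "s3://my-bucket/models/v1" yields bucket_name="my-bucket" and
# bucket_path="models/v1"; every object under that prefix is then
# mirrored into out_dir with the prefix stripped.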


def _download_local(uri, out_dir=None):
    local_path = uri.replace(_LOCAL_PREFIX, "", 1)
    if not os.path.exists(local_path):
        raise RuntimeError("Local path %s does not exist." % uri)

    if out_dir is None:
        return local_path
    elif not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    if os.path.isdir(local_path):
        local_path = os.path.join(local_path, "*")

    # Local "downloads" are symlinks into out_dir rather than copies.
    for src in glob.glob(local_path):
        _, tail = os.path.split(src)
        dest_path = os.path.join(out_dir, tail)
        logging.info("Linking: %s to %s", src, dest_path)
        os.symlink(src, dest_path)
    return out_dir


def _download_from_uri(uri, out_dir=None):
    url = urlparse(uri)
    filename = os.path.basename(url.path)
    _, encoding = mimetypes.guess_type(url.path)
    local_path = os.path.join(out_dir, filename)

    if filename == '':
        raise ValueError('No filename contained in URI: %s' % uri)

    # Optional per-request headers, read from an environment variable
    # named "<hostname>-headers" that holds a JSON object.
    host_uri = url.hostname
    headers_json = os.getenv(host_uri + _HEADERS_SUFFIX, "{}")
    headers = json.loads(headers_json)

    with requests.get(uri, stream=True, headers=headers) as response:
        if response.status_code != 200:
            raise RuntimeError("URI: %s returned a %s response code." %
                               (uri, response.status_code))

        if encoding == 'gzip':
            # Decompress the gzip payload on the fly; a .tar.gz body is
            # written out as a plain .tar so _extract_compress can unpack it.
            stream = gzip.GzipFile(fileobj=response.raw)
            local_path = os.path.join(out_dir, f'{filename}.tar')
        else:
            stream = response.raw
        with open(local_path, 'wb') as out:
            shutil.copyfileobj(stream, out)
    return _extract_compress(local_path, out_dir)
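

# Per-host header example (hypothetical host and token): a download from
# example.com picks up extra request headers from an environment variable
# named after the host:
#
#     example.com-headers='{"Authorization": "Bearer <token>"}'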


def _extract_compress(local_path, out_dir):
    mimetype, _ = mimetypes.guess_type(local_path)
    if mimetype in ["application/x-tar", "application/zip"]:
        if mimetype == "application/x-tar":
            # Mode 'r' lets tarfile auto-detect compression, so this
            # also handles .tar.gz archives.
            archive = tarfile.open(local_path, 'r', encoding='utf-8')
        else:
            archive = zipfile.ZipFile(local_path, 'r')
        archive.extractall(out_dir)
        archive.close()
        # Drop the archive once its contents are extracted.
        os.remove(local_path)

    return out_dir
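

# What mimetypes.guess_type reports for typical archive names (standard
# library behavior), and hence what _extract_compress will unpack:
#
#     "model.tar"    -> ("application/x-tar", None)
#     "model.tar.gz" -> ("application/x-tar", "gzip")
#     "model.zip"    -> ("application/zip", None)
#
# Anything else (e.g. "model.bin") is left on disk untouched.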


def _create_minio_client():
    # urlparse only fills in netloc when the URL carries a scheme,
    # hence the "http://" prefix on the default endpoint.
    url = urlparse(os.getenv("AWS_ENDPOINT_URL", "http://s3.amazonaws.com"))
    use_ssl = (url.scheme == 'https' if url.scheme
               else os.getenv("S3_USE_HTTPS", "true") == "true")
    return minio.Minio(
        url.netloc,
        access_key=os.getenv("AWS_ACCESS_KEY_ID", ""),
        secret_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
        region=os.getenv("AWS_REGION", ""),
        secure=use_ssl
    )
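

# Example client configuration via environment (hypothetical values):
#
#     AWS_ENDPOINT_URL=http://minio-service.sedna:9000
#     AWS_ACCESS_KEY_ID=...
#     AWS_SECRET_ACCESS_KEY=...
#     AWS_REGION=us-east-1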


def main():
    # Emit the INFO-level progress messages this script relies on.
    logging.basicConfig(level=logging.INFO)

    # Arguments come in (src_uri, dest_path) pairs after the program
    # name, so a valid argv length is odd and at least 3.
    if len(sys.argv) < 3 or len(sys.argv) % 2 == 0:
        logging.error("Usage: initializer-entrypoint "
                      "src_uri dest_path [src_uri dest_path ...]")
        sys.exit(1)

    for i in range(1, len(sys.argv) - 1, 2):
        src_uri = sys.argv[i]
        dest_path = sys.argv[i + 1]

        logging.info("Initializing, args: src_uri [%s] dest_path [%s]",
                     src_uri, dest_path)
        download(src_uri, dest_path)


if __name__ == '__main__':
    main()
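

# Example invocation (hypothetical URIs), matching the usage string above:
#
#     initializer-entrypoint s3://models/demo.tar.gz /models \
#         https://example.com/data.zip /data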