
download.py 6.6 kB

#!/usr/bin/env python3
# Copyright 2021 The KubeEdge Authors.
# Copyright 2020 kubeflow.org.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/kubeflow/kfserving/blob/master/python/kfserving/kfserving/storage.py  # noqa

import glob
import gzip
import json
import logging
import mimetypes
import os
import re
import sys
import shutil
import tempfile
import tarfile
import zipfile

import minio
import requests

from urllib.parse import urlparse

_S3_PREFIX = "s3://"
_OBS_PREFIX = "obs://"
_LOCAL_PREFIX = "file://"
_URI_RE = "https?://(.+)/(.+)"
_HTTP_PREFIX = "http(s)://"
_HEADERS_SUFFIX = "-headers"

SUPPORT_PROTOCOLS = (_OBS_PREFIX, _S3_PREFIX, _LOCAL_PREFIX, _HTTP_PREFIX)

def _normalize_uri(uri: str) -> str:
    # A bare absolute path becomes a file:// uri, and obs:// is rewritten
    # to s3:// so both are served by the same minio client. str.replace
    # consumes the matched prefix, so the leading "/" is re-added to
    # _LOCAL_PREFIX to keep the local path absolute.
    for src, dst in [
        ("/", _LOCAL_PREFIX + "/"),
        (_OBS_PREFIX, _S3_PREFIX)
    ]:
        if uri.startswith(src):
            return uri.replace(src, dst, 1)
    return uri
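
# A quick sketch of the normalization (illustrative doctest; the paths are
# made up):
#
#   >>> _normalize_uri("/models/model.pb")
#   'file:///models/model.pb'
#   >>> _normalize_uri("obs://bucket/model.pb")
#   's3://bucket/model.pb'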

def download(uri: str, out_dir: str = None) -> str:
    """Download the uri to a local directory.

    Supported protocols: http(s), s3, obs and local file.
    Note: when the uri ends with .tar.gz/.tar/.zip, the archive is
    extracted into out_dir.
    """
    logging.info("Copying contents of %s to local %s", uri, out_dir)
    uri = _normalize_uri(uri)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    if uri.startswith(_S3_PREFIX):
        _download_s3(uri, out_dir)
    elif uri.startswith(_LOCAL_PREFIX):
        _download_local(uri, out_dir)
    elif re.search(_URI_RE, uri):
        _download_from_uri(uri, out_dir)
    else:
        raise Exception("Cannot recognize storage type for %s.\n"
                        "%r are the currently available storage types." %
                        (uri, SUPPORT_PROTOCOLS))
    logging.info("Successfully copied %s to %s", uri, out_dir)
    return out_dir
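
# A minimal usage sketch (illustrative; the uris and destinations are
# hypothetical):
#
#   download("s3://models/mnist/model.tar.gz", "/tmp/model")
#   download("https://example.com/models/model.zip", "/tmp/model")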

def _download_s3(uri, out_dir: str):
    client = _create_minio_client()
    bucket_args = uri.replace(_S3_PREFIX, "", 1).split("/", 1)
    bucket_name = bucket_args[0]
    bucket_path = bucket_args[1] if len(bucket_args) > 1 else ""
    objects = client.list_objects(bucket_name,
                                  prefix=bucket_path,
                                  recursive=True)
    count = 0
    for obj in objects:
        # Replace any prefix from the object key with out_dir
        subdir_object_key = obj.object_name[len(bucket_path):].strip("/")
        # fget_object handles directory creation if it does not exist
        if not obj.is_dir:
            local_file = os.path.join(
                out_dir,
                subdir_object_key or os.path.basename(obj.object_name)
            )
            client.fget_object(bucket_name, obj.object_name, local_file)
            _extract_compress(local_file, out_dir)
            count += 1
    if count == 0:
        raise RuntimeError("Failed to fetch model. "
                           "The path or model %s does not exist." % uri)

def _download_local(uri, out_dir=None):
    local_path = uri.replace(_LOCAL_PREFIX, "", 1)
    if not os.path.exists(local_path):
        raise RuntimeError("Local path %s does not exist." % uri)
    if out_dir is None:
        return local_path
    elif not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    if os.path.isdir(local_path):
        local_path = os.path.join(local_path, "*")
    for src in glob.glob(local_path):
        _, tail = os.path.split(src)
        dest_path = os.path.join(out_dir, tail)
        logging.info("Linking: %s to %s", src, dest_path)
        os.symlink(src, dest_path)
    return out_dir
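
# Note that local "downloads" are symlinks, not copies: entries in out_dir
# point back into the original path. A hypothetical call:
#
#   _download_local("file:///models/mnist", "/tmp/model")
#   # /tmp/model/saved_model.pb -> /models/mnist/saved_model.pb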

def _download_from_uri(uri, out_dir=None):
    url = urlparse(uri)
    filename = os.path.basename(url.path)
    mimetype, encoding = mimetypes.guess_type(url.path)
    local_path = os.path.join(out_dir, filename)

    if filename == '':
        raise ValueError('No filename contained in URI: %s' % uri)

    # Get header information from the host url
    host_uri = url.hostname
    headers_json = os.getenv(host_uri + _HEADERS_SUFFIX, "{}")
    headers = json.loads(headers_json)

    with requests.get(uri, stream=True, headers=headers) as response:
        if response.status_code != 200:
            raise RuntimeError("URI: %s returned a %s response code." %
                               (uri, response.status_code))
        if encoding == 'gzip':
            stream = gzip.GzipFile(fileobj=response.raw)
            local_path = os.path.join(out_dir, f'{filename}.tar')
        else:
            stream = response.raw
        with open(local_path, 'wb') as out:
            shutil.copyfileobj(stream, out)
    return _extract_compress(local_path, out_dir)
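
# Extra request headers are read from an environment variable named after
# the host plus "-headers", holding a JSON object. A hypothetical token
# for example.com:
#
#   os.environ["example.com-headers"] = '{"Authorization": "Bearer <token>"}'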

def _extract_compress(local_path, out_dir):
    mimetype, encoding = mimetypes.guess_type(local_path)
    if mimetype in ["application/x-tar", "application/zip"]:
        if mimetype == "application/x-tar":
            archive = tarfile.open(local_path, 'r', encoding='utf-8')
        else:
            archive = zipfile.ZipFile(local_path, 'r')
        archive.extractall(out_dir)
        archive.close()
        os.remove(local_path)
    return out_dir
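
# Archive detection relies on the standard-library mimetypes table
# (illustrative doctest):
#
#   >>> mimetypes.guess_type("model.zip")
#   ('application/zip', None)
#   >>> mimetypes.guess_type("model.tar.gz")
#   ('application/x-tar', 'gzip')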

def _create_minio_client():
    # The endpoint must include an "http(s)" scheme so that urlparse puts
    # the host into netloc
    url = urlparse(os.getenv("AWS_ENDPOINT_URL", "http://s3.amazonaws.com"))
    use_ssl = (url.scheme == 'https' if url.scheme
               else os.getenv("S3_USE_HTTPS", "true") == "true")
    return minio.Minio(
        url.netloc,
        access_key=os.getenv("AWS_ACCESS_KEY_ID", ""),
        secret_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
        region=os.getenv("AWS_REGION", ""),
        secure=use_ssl
    )
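
# The client is configured entirely through environment variables; a
# hypothetical in-cluster MinIO setup:
#
#   os.environ["AWS_ENDPOINT_URL"] = "http://minio-service.default:9000"
#   os.environ["AWS_ACCESS_KEY_ID"] = "minio"
#   os.environ["AWS_SECRET_ACCESS_KEY"] = "minio123"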

def main():
    if len(sys.argv) < 2 or len(sys.argv) % 2 == 0:
        logging.error("Usage: initializer-entrypoint "
                      "src_uri dest_path [src_uri dest_path]")
        sys.exit(1)
    for i in range(1, len(sys.argv) - 1, 2):
        src_uri = sys.argv[i]
        dest_path = sys.argv[i + 1]
        logging.info("Initializing, args: src_uri [%s] dest_path [%s]",
                     src_uri, dest_path)
        download(src_uri, dest_path)


if __name__ == '__main__':
    main()
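
# The script takes alternating uri/destination pairs on the command line;
# a hypothetical invocation:
#
#   python3 download.py s3://models/mnist /tmp/models/mnist \
#       https://example.com/data.zip /tmp/data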