You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

download.py 9.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. #!/usr/bin/env python3
  2. # Copyright 2021 The KubeEdge Authors.
  3. # Copyright 2020 kubeflow.org.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # modify from https://github.com/kubeflow/kfserving/blob/master/python/kfserving/kfserving/storage.py # noqa
  17. import concurrent.futures
  18. import glob
  19. import gzip
  20. import json
  21. import logging
  22. import mimetypes
  23. import os
  24. import re
  25. import sys
  26. import shutil
  27. import tempfile
  28. import tarfile
  29. import zipfile
  30. import minio
  31. import requests
  32. from urllib.parse import urlparse
# URI scheme prefixes recognised by this module.
_S3_PREFIX = "s3://"
# obs:// URIs are rewritten to s3:// by _normalize_uri.
_OBS_PREFIX = "obs://"
_LOCAL_PREFIX = "file://"
# Pattern used by download() to recognise http/https URLs with a path part.
_URI_RE = "https?://(.+)/(.+)"
# Label for http/https; in this file it only appears in SUPPORT_PROTOCOLS.
_HTTP_PREFIX = "http(s)://"
# Appended to a hostname to build the name of the environment variable
# that holds extra request headers as JSON (see download_from_uri).
_HEADERS_SUFFIX = "-headers"
# Listed in the error raised by download() for an unrecognised URI.
SUPPORT_PROTOCOLS = (_OBS_PREFIX, _S3_PREFIX, _LOCAL_PREFIX, _HTTP_PREFIX)
LOG = logging.getLogger(__name__)
  41. def setup_logger():
  42. format = '%(asctime)s %(levelname)s %(funcName)s:%(lineno)s] %(message)s'
  43. logging.basicConfig(format=format)
  44. LOG.setLevel(os.getenv('LOG_LEVEL', 'INFO'))
  45. def _normalize_uri(uri: str) -> str:
  46. for src, dst in [
  47. ("/", _LOCAL_PREFIX),
  48. (_OBS_PREFIX, _S3_PREFIX)
  49. ]:
  50. if uri.startswith(src):
  51. return uri.replace(src, dst, 1)
  52. return uri
  53. def download(uri: str, out_dir: str = None) -> str:
  54. """ Download the uri to local directory.
  55. Support procotols: http, s3.
  56. Note when uri ends with .tar.gz/.tar/.zip, this will extract it
  57. """
  58. LOG.info("Copying contents of %s to local %s", uri, out_dir)
  59. uri = _normalize_uri(uri)
  60. if not os.path.exists(out_dir):
  61. os.makedirs(out_dir)
  62. if uri.startswith(_S3_PREFIX):
  63. download_s3(uri, out_dir)
  64. elif uri.startswith(_LOCAL_PREFIX):
  65. download_local(uri, out_dir)
  66. elif re.search(_URI_RE, uri):
  67. download_from_uri(uri, out_dir)
  68. else:
  69. raise Exception("Cannot recognize storage type for %s.\n"
  70. "%r are the current available storage type." %
  71. (uri, SUPPORT_PROTOCOLS))
  72. LOG.info("Successfully copied %s to %s", uri, out_dir)
  73. return out_dir
  74. def indirect_download(indirect_uri: str, out_dir: str = None) -> str:
  75. """ Download the uri to local directory.
  76. Support procotols: http, s3.
  77. Note when uri ends with .tar.gz/.tar/.zip, this will extract it
  78. """
  79. tmpdir = tempfile.mkdtemp()
  80. download(indirect_uri, tmpdir)
  81. files = os.listdir(tmpdir)
  82. if len(files) != 1:
  83. raise Exception("indirect url %s should be file, not directory"
  84. % indirect_uri)
  85. download_files = set()
  86. with open(os.path.join(tmpdir, files[0])) as f:
  87. base_uri = None
  88. for line_no, line in enumerate(f):
  89. line = line.strip()
  90. if line.startswith('#'):
  91. continue
  92. if line:
  93. if base_uri is None:
  94. base_uri = line
  95. else:
  96. file_name = line
  97. download_files.add(file_name)
  98. if not download_files:
  99. LOG.info("no files to download for indirect url %s",
  100. indirect_uri)
  101. return
  102. if not os.path.exists(out_dir):
  103. os.makedirs(out_dir)
  104. LOG.info("To download %s files IN-DIRECT %s to %s",
  105. len(download_files), indirect_uri, out_dir)
  106. uri = _normalize_uri(base_uri)
  107. # only support s3 for indirect download
  108. if uri.startswith(_S3_PREFIX):
  109. download_s3_with_multi_files(download_files, uri, out_dir)
  110. else:
  111. LOG.warning("unsupported %s for indirect url %s, skipped",
  112. uri, indirect_uri)
  113. return
  114. LOG.info("Successfully download files IN-DIRECT %s to %s",
  115. indirect_uri, out_dir)
  116. return
  117. def download_s3(uri, out_dir: str):
  118. client = _create_minio_client()
  119. count = _download_s3(client, uri, out_dir)
  120. if count == 0:
  121. raise RuntimeError("Failed to fetch files."
  122. "The path %s does not exist." % (uri))
  123. LOG.info("downloaded %d files for %s.", count, uri)
  124. def download_s3_with_multi_files(download_files,
  125. base_uri, base_out_dir):
  126. client = _create_minio_client()
  127. total_count = 0
  128. with concurrent.futures.ThreadPoolExecutor() as executor:
  129. todos = []
  130. for dfile in set(download_files):
  131. dir_ = os.path.dirname(dfile)
  132. uri = base_uri.rstrip("/") + "/" + dfile
  133. out_dir = os.path.join(base_out_dir, dir_)
  134. todos.append(executor.submit(_download_s3, client, uri, out_dir))
  135. for done in concurrent.futures.as_completed(todos):
  136. count = done.result()
  137. if count == 0:
  138. LOG.warning("failed to download %s in base uri(%s)",
  139. dfile, base_uri)
  140. continue
  141. total_count += count
  142. LOG.info("downloaded %d files for base_uri %s to local dir %s.",
  143. total_count, base_uri, base_out_dir)
  144. def _download_s3(client, uri, out_dir):
  145. bucket_args = uri.replace(_S3_PREFIX, "", 1).split("/", 1)
  146. bucket_name = bucket_args[0]
  147. bucket_path = len(bucket_args) > 1 and bucket_args[1] or ""
  148. objects = client.list_objects(bucket_name,
  149. prefix=bucket_path,
  150. recursive=True,
  151. use_api_v1=True)
  152. count = 0
  153. for obj in objects:
  154. # Replace any prefix from the object key with out_dir
  155. subdir_object_key = obj.object_name[len(bucket_path):].strip("/")
  156. # fget_object handles directory creation if does not exist
  157. if not obj.is_dir:
  158. local_file = os.path.join(
  159. out_dir,
  160. subdir_object_key or os.path.basename(obj.object_name)
  161. )
  162. LOG.debug("downloading count:%d, file:%s",
  163. count, subdir_object_key)
  164. client.fget_object(bucket_name, obj.object_name, local_file)
  165. _extract_compress(local_file, out_dir)
  166. count += 1
  167. return count
  168. def download_local(uri, out_dir=None):
  169. local_path = uri.replace(_LOCAL_PREFIX, "/", 1)
  170. if not os.path.exists(local_path):
  171. raise RuntimeError("Local path %s does not exist." % (uri))
  172. if out_dir is None:
  173. return local_path
  174. elif not os.path.isdir(out_dir):
  175. os.makedirs(out_dir)
  176. if os.path.isdir(local_path):
  177. local_path = os.path.join(local_path, "*")
  178. for src in glob.glob(local_path):
  179. _, tail = os.path.split(src)
  180. dest_path = os.path.join(out_dir, tail)
  181. LOG.info("Linking: %s to %s", src, dest_path)
  182. os.symlink(src, dest_path)
  183. return out_dir
  184. def download_from_uri(uri, out_dir=None):
  185. url = urlparse(uri)
  186. filename = os.path.basename(url.path)
  187. mimetype, encoding = mimetypes.guess_type(url.path)
  188. local_path = os.path.join(out_dir, filename)
  189. if filename == '':
  190. raise ValueError('No filename contained in URI: %s' % (uri))
  191. # Get header information from host url
  192. headers = {}
  193. host_uri = url.hostname
  194. headers_json = os.getenv(host_uri + _HEADERS_SUFFIX, "{}")
  195. headers = json.loads(headers_json)
  196. with requests.get(uri, stream=True, headers=headers) as response:
  197. if response.status_code != 200:
  198. raise RuntimeError("URI: %s returned a %s response code." %
  199. (uri, response.status_code))
  200. if encoding == 'gzip':
  201. stream = gzip.GzipFile(fileobj=response.raw)
  202. local_path = os.path.join(out_dir, f'{filename}.tar')
  203. else:
  204. stream = response.raw
  205. with open(local_path, 'wb') as out:
  206. shutil.copyfileobj(stream, out)
  207. return _extract_compress(local_path, out_dir)
  208. def _extract_compress(local_path, out_dir):
  209. mimetype, encoding = mimetypes.guess_type(local_path)
  210. if mimetype in ["application/x-tar", "application/zip"]:
  211. if mimetype == "application/x-tar":
  212. archive = tarfile.open(local_path, 'r', encoding='utf-8')
  213. else:
  214. archive = zipfile.ZipFile(local_path, 'r')
  215. archive.extractall(out_dir)
  216. archive.close()
  217. os.remove(local_path)
  218. return out_dir
  219. def _create_minio_client():
  220. url = urlparse(os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com"))
  221. use_ssl = url.scheme == 'https' if url.scheme else True
  222. return minio.Minio(
  223. url.netloc,
  224. access_key=os.getenv("ACCESS_KEY_ID", ""),
  225. secret_key=os.getenv("SECRET_ACCESS_KEY", ""),
  226. secure=use_ssl
  227. )
  228. def main():
  229. setup_logger()
  230. if len(sys.argv) < 2 or len(sys.argv) % 2 == 0:
  231. LOG.error("Usage: download.py "
  232. "src_uri dest_path [src_uri dest_path]")
  233. sys.exit(1)
  234. indirect_mark = os.getenv("INDIRECT_URL_MARK", "@")
  235. for i in range(1, len(sys.argv)-1, 2):
  236. src_uri = sys.argv[i]
  237. dest_path = sys.argv[i+1]
  238. LOG.info("Initializing, args: src_uri [%s] dest_path [%s]" %
  239. (src_uri, dest_path))
  240. if dest_path.startswith(indirect_mark):
  241. indirect_download(src_uri, dest_path[len(indirect_mark):])
  242. else:
  243. download(src_uri, dest_path)
  244. main()