You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

file_download.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import copy
  3. import os
  4. import sys
  5. import tempfile
  6. from functools import partial
  7. from http.cookiejar import CookieJar
  8. from pathlib import Path
  9. from typing import Dict, Optional, Union
  10. from uuid import uuid4
  11. import requests
  12. from filelock import FileLock
  13. from tqdm import tqdm
  14. from modelscope import __version__
  15. from modelscope.utils.constant import DEFAULT_MODEL_REVISION
  16. from modelscope.utils.logger import get_logger
  17. from .api import HubApi, ModelScopeConfig
  18. from .constants import FILE_HASH
  19. from .errors import FileDownloadError, NotExistError
  20. from .utils.caching import ModelFileSystemCache
  21. from .utils.utils import (file_integrity_validation, get_cache_dir,
  22. get_endpoint, model_id_to_group_owner_name)
  23. SESSION_ID = uuid4().hex
  24. logger = get_logger()
  25. def model_file_download(
  26. model_id: str,
  27. file_path: str,
  28. revision: Optional[str] = DEFAULT_MODEL_REVISION,
  29. cache_dir: Optional[str] = None,
  30. user_agent: Union[Dict, str, None] = None,
  31. local_files_only: Optional[bool] = False,
  32. ) -> Optional[str]: # pragma: no cover
  33. """
  34. Download from a given URL and cache it if it's not already present in the
  35. local cache.
  36. Given a URL, this function looks for the corresponding file in the local
  37. cache. If it's not there, download it. Then return the path to the cached
  38. file.
  39. Args:
  40. model_id (`str`):
  41. The model to whom the file to be downloaded belongs.
  42. file_path(`str`):
  43. Path of the file to be downloaded, relative to the root of model repo
  44. revision(`str`, *optional*):
  45. revision of the model file to be downloaded.
  46. Can be any of a branch, tag or commit hash
  47. cache_dir (`str`, `Path`, *optional*):
  48. Path to the folder where cached files are stored.
  49. user_agent (`dict`, `str`, *optional*):
  50. The user-agent info in the form of a dictionary or a string.
  51. local_files_only (`bool`, *optional*, defaults to `False`):
  52. If `True`, avoid downloading the file and return the path to the
  53. local cached file if it exists.
  54. if `False`, download the file anyway even it exists
  55. Returns:
  56. Local path (string) of file or if networking is off, last version of
  57. file cached on disk.
  58. <Tip>
  59. Raises the following errors:
  60. - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
  61. if `use_auth_token=True` and the token cannot be found.
  62. - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
  63. if ETag cannot be determined.
  64. - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  65. if some parameter value is invalid
  66. </Tip>
  67. """
  68. if cache_dir is None:
  69. cache_dir = get_cache_dir()
  70. if isinstance(cache_dir, Path):
  71. cache_dir = str(cache_dir)
  72. temporary_cache_dir = os.path.join(cache_dir, 'temp')
  73. os.makedirs(temporary_cache_dir, exist_ok=True)
  74. group_or_owner, name = model_id_to_group_owner_name(model_id)
  75. cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
  76. # if local_files_only is `True` and the file already exists in cached_path
  77. # return the cached path
  78. if local_files_only:
  79. cached_file_path = cache.get_file_by_path(file_path)
  80. if cached_file_path is not None:
  81. logger.warning(
  82. "File exists in local cache, but we're not sure it's up to date"
  83. )
  84. return cached_file_path
  85. else:
  86. raise ValueError(
  87. 'Cannot find the requested files in the cached path and outgoing'
  88. ' traffic has been disabled. To enable model look-ups and downloads'
  89. " online, set 'local_files_only' to False.")
  90. _api = HubApi()
  91. headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
  92. cookies = ModelScopeConfig.get_cookies()
  93. branches, tags = _api.get_model_branches_and_tags(
  94. model_id, use_cookies=False if cookies is None else cookies)
  95. file_to_download_info = None
  96. is_commit_id = False
  97. if revision in branches or revision in tags: # The revision is version or tag,
  98. # we need to confirm the version is up to date
  99. # we need to get the file list to check if the lateast version is cached, if so return, otherwise download
  100. model_files = _api.get_model_files(
  101. model_id=model_id,
  102. revision=revision,
  103. recursive=True,
  104. use_cookies=False if cookies is None else cookies)
  105. for model_file in model_files:
  106. if model_file['Type'] == 'tree':
  107. continue
  108. if model_file['Path'] == file_path:
  109. if cache.exists(model_file):
  110. return cache.get_file_by_info(model_file)
  111. else:
  112. file_to_download_info = model_file
  113. break
  114. if file_to_download_info is None:
  115. raise NotExistError('The file path: %s not exist in: %s' %
  116. (file_path, model_id))
  117. else: # the revision is commit id.
  118. cached_file_path = cache.get_file_by_path_and_commit_id(
  119. file_path, revision)
  120. if cached_file_path is not None:
  121. file_name = os.path.basename(cached_file_path)
  122. logger.info(
  123. f'File {file_name} already in cache, skip downloading!')
  124. return cached_file_path # the file is in cache.
  125. is_commit_id = True
  126. # we need to download again
  127. url_to_download = get_file_download_url(model_id, file_path, revision)
  128. file_to_download_info = {
  129. 'Path':
  130. file_path,
  131. 'Revision':
  132. revision if is_commit_id else file_to_download_info['Revision'],
  133. FILE_HASH:
  134. None if (is_commit_id or FILE_HASH not in file_to_download_info) else
  135. file_to_download_info[FILE_HASH]
  136. }
  137. temp_file_name = next(tempfile._get_candidate_names())
  138. http_get_file(
  139. url_to_download,
  140. temporary_cache_dir,
  141. temp_file_name,
  142. headers=headers,
  143. cookies=None if cookies is None else cookies.get_dict())
  144. temp_file_path = os.path.join(temporary_cache_dir, temp_file_name)
  145. # for download with commit we can't get Sha256
  146. if file_to_download_info[FILE_HASH] is not None:
  147. file_integrity_validation(temp_file_path,
  148. file_to_download_info[FILE_HASH])
  149. return cache.put_file(file_to_download_info,
  150. os.path.join(temporary_cache_dir, temp_file_name))
  151. def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
  152. """Formats a user-agent string with basic info about a request.
  153. Args:
  154. user_agent (`str`, `dict`, *optional*):
  155. The user agent info in the form of a dictionary or a single string.
  156. Returns:
  157. The formatted user-agent string.
  158. """
  159. ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}'
  160. if isinstance(user_agent, dict):
  161. ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items())
  162. elif isinstance(user_agent, str):
  163. ua = user_agent
  164. return ua
  165. def get_file_download_url(model_id: str, file_path: str, revision: str):
  166. """
  167. Format file download url according to `model_id`, `revision` and `file_path`.
  168. e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
  169. the resulted download url is: https://modelscope.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
  170. """
  171. download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
  172. return download_url_template.format(
  173. endpoint=get_endpoint(),
  174. model_id=model_id,
  175. revision=revision,
  176. file_path=file_path,
  177. )
  178. def http_get_file(
  179. url: str,
  180. local_dir: str,
  181. file_name: str,
  182. cookies: CookieJar,
  183. headers: Optional[Dict[str, str]] = None,
  184. ):
  185. """
  186. Download remote file. Do not gobble up errors.
  187. This method is only used by snapshot_download, since the behavior is quite different with single file download
  188. TODO: consolidate with http_get_file() to avoild duplicate code
  189. Args:
  190. url(`str`):
  191. actual download url of the file
  192. local_dir(`str`):
  193. local directory where the downloaded file stores
  194. file_name(`str`):
  195. name of the file stored in `local_dir`
  196. cookies(`CookieJar`):
  197. cookies used to authentication the user, which is used for downloading private repos
  198. headers(`Optional[Dict[str, str]] = None`):
  199. http headers to carry necessary info when requesting the remote file
  200. """
  201. total = -1
  202. temp_file_manager = partial(
  203. tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
  204. with temp_file_manager() as temp_file:
  205. logger.info('downloading %s to %s', url, temp_file.name)
  206. headers = copy.deepcopy(headers)
  207. r = requests.get(url, stream=True, headers=headers, cookies=cookies)
  208. r.raise_for_status()
  209. content_length = r.headers.get('Content-Length')
  210. total = int(content_length) if content_length is not None else None
  211. progress = tqdm(
  212. unit='B',
  213. unit_scale=True,
  214. unit_divisor=1024,
  215. total=total,
  216. initial=0,
  217. desc='Downloading',
  218. )
  219. for chunk in r.iter_content(chunk_size=1024):
  220. if chunk: # filter out keep-alive new chunks
  221. progress.update(len(chunk))
  222. temp_file.write(chunk)
  223. progress.close()
  224. logger.info('storing %s in cache at %s', url, local_dir)
  225. downloaded_length = os.path.getsize(temp_file.name)
  226. if total != downloaded_length:
  227. os.remove(temp_file.name)
  228. msg = 'File %s download incomplete, content_length: %s but the \
  229. file downloaded length: %s, please download again' % (
  230. file_name, total, downloaded_length)
  231. logger.error(msg)
  232. raise FileDownloadError(msg)
  233. os.replace(temp_file.name, os.path.join(local_dir, file_name))