class FileIntegrityError(Exception):
    """Raised when a downloaded file's sha256 digest does not match the
    hash reported by the hub, i.e. the local copy is corrupt or truncated."""


class FileDownloadError(Exception):
    """Raised when a file download fails or completes with fewer bytes
    than the Content-Length announced by the server."""
file_to_download_info['Revision'], + FILE_HASH: + None if (is_commit_id or FILE_HASH not in file_to_download_info) else + file_to_download_info[FILE_HASH] } - # Prevent parallel downloads of the same file with a lock. - lock_path = cache.get_root_location() + '.lock' - - with FileLock(lock_path): - temp_file_name = next(tempfile._get_candidate_names()) - http_get_file( - url_to_download, - temporary_cache_dir, - temp_file_name, - headers=headers, - cookies=None if cookies is None else cookies.get_dict()) - return cache.put_file( - file_to_download_info, - os.path.join(temporary_cache_dir, temp_file_name)) + + temp_file_name = next(tempfile._get_candidate_names()) + http_get_file( + url_to_download, + temporary_cache_dir, + temp_file_name, + headers=headers, + cookies=None if cookies is None else cookies.get_dict()) + temp_file_path = os.path.join(temporary_cache_dir, temp_file_name) + # for download with commit we can't get Sha256 + if file_to_download_info[FILE_HASH] is not None: + file_integrity_validation(temp_file_path, + file_to_download_info[FILE_HASH]) + return cache.put_file(file_to_download_info, + os.path.join(temporary_cache_dir, temp_file_name)) def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: @@ -222,6 +228,7 @@ def http_get_file( http headers to carry necessary info when requesting the remote file """ + total = -1 temp_file_manager = partial( tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False) @@ -250,4 +257,12 @@ def http_get_file( progress.close() logger.info('storing %s in cache at %s', url, local_dir) + downloaded_length = os.path.getsize(temp_file.name) + if total != downloaded_length: + os.remove(temp_file.name) + msg = 'File %s download incomplete, content_length: %s but the \ + file downloaded length: %s, please download again' % ( + file_name, total, downloaded_length) + logger.error(msg) + raise FileDownloadError(msg) os.replace(temp_file.name, os.path.join(local_dir, file_name)) diff --git 
a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 5f9548e9..c63d8956 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -6,11 +6,13 @@ from typing import Dict, Optional, Union from modelscope.utils.constant import DEFAULT_MODEL_REVISION from modelscope.utils.logger import get_logger from .api import HubApi, ModelScopeConfig +from .constants import FILE_HASH from .errors import NotExistError from .file_download import (get_file_download_url, http_get_file, http_user_agent) from .utils.caching import ModelFileSystemCache -from .utils.utils import get_cache_dir, model_id_to_group_owner_name +from .utils.utils import (file_integrity_validation, get_cache_dir, + model_id_to_group_owner_name) logger = get_logger() @@ -127,9 +129,11 @@ def snapshot_download(model_id: str, file_name=model_file['Name'], headers=headers, cookies=cookies) + # check file integrity + temp_file = os.path.join(temp_cache_dir, model_file['Name']) + if FILE_HASH in model_file: + file_integrity_validation(temp_file, model_file[FILE_HASH]) # put file to cache - cache.put_file( - model_file, os.path.join(temp_cache_dir, - model_file['Name'])) + cache.put_file(model_file, temp_file) return os.path.join(cache.get_root_location()) diff --git a/modelscope/hub/utils/utils.py b/modelscope/hub/utils/utils.py index fff88cca..1a55c9f9 100644 --- a/modelscope/hub/utils/utils.py +++ b/modelscope/hub/utils/utils.py @@ -1,10 +1,15 @@ +import hashlib import os from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, DEFAULT_MODELSCOPE_GROUP, MODEL_ID_SEPARATOR, MODELSCOPE_URL_SCHEME) +from modelscope.hub.errors import FileIntegrityError from modelscope.utils.file_utils import get_default_cache_dir +from modelscope.utils.logger import get_logger + +logger = get_logger() def model_id_to_group_owner_name(model_id): @@ -31,3 +36,34 @@ def get_endpoint(): modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', DEFAULT_MODELSCOPE_DOMAIN) return 
def compute_hash(file_path):
    """Compute the sha256 digest of a file, reading it in 64 KiB chunks.

    Args:
        file_path (str): Path of the file to hash.

    Returns:
        str: Lowercase hexadecimal sha256 digest of the file contents.
    """
    BUFFER_SIZE = 1024 * 64  # 64k buffer: avoid loading large model files into memory
    sha256_hash = hashlib.sha256()
    with open(file_path, 'rb') as f:
        # iter() with a b'' sentinel stops cleanly at EOF.
        for data in iter(lambda: f.read(BUFFER_SIZE), b''):
            sha256_hash.update(data)
    return sha256_hash.hexdigest()


def file_integrity_validation(file_path, expected_sha256):
    """Validate the file hash is expected, if not, delete the file

    Args:
        file_path (str): The file to validate
        expected_sha256 (str): The expected sha256 hash

    Raises:
        FileIntegrityError: If file_path hash is not expected.

    """
    file_sha256 = compute_hash(file_path)
    # NOTE(review): comparison is case-sensitive; hexdigest() is lowercase,
    # so this assumes the hub also reports lowercase hex — confirm.
    if file_sha256 != expected_sha256:
        # Remove the corrupt download so a retry starts from scratch.
        os.remove(file_path)
        msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
        logger.error(msg)
        raise FileIntegrityError(msg)