Browse Source

[to #43913168]fix: add file download integrity check

添加文件下载完整性验证
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9700279

    * [to #43913168]fix: add file download integrity check
master
mulin.lyh 3 years ago
parent
commit
6cf0a56ade
5 changed files with 87 additions and 24 deletions
  1. +1
    -1
      modelscope/hub/constants.py
  2. +8
    -0
      modelscope/hub/errors.py
  3. +34
    -19
      modelscope/hub/file_download.py
  4. +8
    -4
      modelscope/hub/snapshot_download.py
  5. +36
    -0
      modelscope/hub/utils/utils.py

+ 1
- 1
modelscope/hub/constants.py View File

@@ -4,7 +4,7 @@ DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DO

DEFAULT_MODELSCOPE_GROUP = 'damo'
MODEL_ID_SEPARATOR = '/'
FILE_HASH = 'Sha256'
LOGGER_NAME = 'ModelScopeHub'
DEFAULT_CREDENTIALS_PATH = '~/.modelscope/credentials'
API_RESPONSE_FIELD_DATA = 'Data'


+ 8
- 0
modelscope/hub/errors.py View File

@@ -23,6 +23,14 @@ class NotLoginException(Exception):
pass


class FileIntegrityError(Exception):
    """Raised when a file's SHA-256 hash does not match the expected hash."""


class FileDownloadError(Exception):
    """Raised when a downloaded file's size differs from the expected length."""


def is_ok(rsp):
""" Check the request is ok



+ 34
- 19
modelscope/hub/file_download.py View File

@@ -16,10 +16,11 @@ from modelscope import __version__
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.logger import get_logger
from .api import HubApi, ModelScopeConfig
from .errors import NotExistError
from .constants import FILE_HASH
from .errors import FileDownloadError, NotExistError
from .utils.caching import ModelFileSystemCache
from .utils.utils import (get_cache_dir, get_endpoint,
model_id_to_group_owner_name)
from .utils.utils import (file_integrity_validation, get_cache_dir,
get_endpoint, model_id_to_group_owner_name)

SESSION_ID = uuid4().hex
logger = get_logger()
@@ -143,24 +144,29 @@ def model_file_download(
# we need to download again
url_to_download = get_file_download_url(model_id, file_path, revision)
file_to_download_info = {
'Path': file_path,
'Path':
file_path,
'Revision':
revision if is_commit_id else file_to_download_info['Revision']
revision if is_commit_id else file_to_download_info['Revision'],
FILE_HASH:
None if (is_commit_id or FILE_HASH not in file_to_download_info) else
file_to_download_info[FILE_HASH]
}
# Prevent parallel downloads of the same file with a lock.
lock_path = cache.get_root_location() + '.lock'

with FileLock(lock_path):
temp_file_name = next(tempfile._get_candidate_names())
http_get_file(
url_to_download,
temporary_cache_dir,
temp_file_name,
headers=headers,
cookies=None if cookies is None else cookies.get_dict())
return cache.put_file(
file_to_download_info,
os.path.join(temporary_cache_dir, temp_file_name))

temp_file_name = next(tempfile._get_candidate_names())
http_get_file(
url_to_download,
temporary_cache_dir,
temp_file_name,
headers=headers,
cookies=None if cookies is None else cookies.get_dict())
temp_file_path = os.path.join(temporary_cache_dir, temp_file_name)
# for download with commit we can't get Sha256
if file_to_download_info[FILE_HASH] is not None:
file_integrity_validation(temp_file_path,
file_to_download_info[FILE_HASH])
return cache.put_file(file_to_download_info,
os.path.join(temporary_cache_dir, temp_file_name))


def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
@@ -222,6 +228,7 @@ def http_get_file(
http headers to carry necessary info when requesting the remote file

"""
total = -1
temp_file_manager = partial(
tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)

@@ -250,4 +257,12 @@ def http_get_file(
progress.close()

logger.info('storing %s in cache at %s', url, local_dir)
downloaded_length = os.path.getsize(temp_file.name)
if total != downloaded_length:
os.remove(temp_file.name)
msg = 'File %s download incomplete, content_length: %s but the \
file downloaded length: %s, please download again' % (
file_name, total, downloaded_length)
logger.error(msg)
raise FileDownloadError(msg)
os.replace(temp_file.name, os.path.join(local_dir, file_name))

+ 8
- 4
modelscope/hub/snapshot_download.py View File

@@ -6,11 +6,13 @@ from typing import Dict, Optional, Union
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.logger import get_logger
from .api import HubApi, ModelScopeConfig
from .constants import FILE_HASH
from .errors import NotExistError
from .file_download import (get_file_download_url, http_get_file,
http_user_agent)
from .utils.caching import ModelFileSystemCache
from .utils.utils import get_cache_dir, model_id_to_group_owner_name
from .utils.utils import (file_integrity_validation, get_cache_dir,
model_id_to_group_owner_name)

logger = get_logger()

@@ -127,9 +129,11 @@ def snapshot_download(model_id: str,
file_name=model_file['Name'],
headers=headers,
cookies=cookies)
# check file integrity
temp_file = os.path.join(temp_cache_dir, model_file['Name'])
if FILE_HASH in model_file:
file_integrity_validation(temp_file, model_file[FILE_HASH])
# put file to cache
cache.put_file(
model_file, os.path.join(temp_cache_dir,
model_file['Name']))
cache.put_file(model_file, temp_file)

return os.path.join(cache.get_root_location())

+ 36
- 0
modelscope/hub/utils/utils.py View File

@@ -1,10 +1,15 @@
import hashlib
import os

from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
DEFAULT_MODELSCOPE_GROUP,
MODEL_ID_SEPARATOR,
MODELSCOPE_URL_SCHEME)
from modelscope.hub.errors import FileIntegrityError
from modelscope.utils.file_utils import get_default_cache_dir
from modelscope.utils.logger import get_logger

logger = get_logger()


def model_id_to_group_owner_name(model_id):
@@ -31,3 +36,34 @@ def get_endpoint():
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN',
DEFAULT_MODELSCOPE_DOMAIN)
return MODELSCOPE_URL_SCHEME + modelscope_domain


def compute_hash(file_path):
    """Return the SHA-256 hex digest of a file's content.

    The file is opened in binary mode and consumed in 64 KiB chunks, each
    of which is fed into the running hash.

    Args:
        file_path (str): Path of the file to hash.

    Returns:
        str: The hexadecimal SHA-256 digest of the file content.
    """
    chunk_size = 1024 * 64  # 64k buffer size
    digest = hashlib.sha256()
    with open(file_path, 'rb') as stream:
        # The two-argument form of iter() stops at the first empty read,
        # which is how a binary file signals end-of-file.
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def file_integrity_validation(file_path, expected_sha256):
    """Check a file against its expected SHA-256 digest.

    The digest of ``file_path`` is computed with ``compute_hash`` and
    compared to ``expected_sha256``. When they are equal the function
    returns without doing anything else. Otherwise the file is deleted,
    an error is logged and ``FileIntegrityError`` is raised.

    Args:
        file_path (str): The file to validate.
        expected_sha256 (str): The expected sha256 hash.

    Raises:
        FileIntegrityError: If the hash of ``file_path`` is not the
            expected one.

    """
    actual_sha256 = compute_hash(file_path)
    if actual_sha256 == expected_sha256:
        return

    # Mismatch: delete the file before reporting the failure.
    os.remove(file_path)
    msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
    logger.error(msg)
    raise FileIntegrityError(msg)

Loading…
Cancel
Save