添加文件下载完整性验证 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9700279 * [to #43913168] fix: add file download integrity check (branch: master)
@@ -4,7 +4,7 @@ DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DO | |||
DEFAULT_MODELSCOPE_GROUP = 'damo' | |||
MODEL_ID_SEPARATOR = '/' | |||
FILE_HASH = 'Sha256' | |||
LOGGER_NAME = 'ModelScopeHub' | |||
DEFAULT_CREDENTIALS_PATH = '~/.modelscope/credentials' | |||
API_RESPONSE_FIELD_DATA = 'Data' | |||
@@ -23,6 +23,14 @@ class NotLoginException(Exception): | |||
pass | |||
class FileIntegrityError(Exception): | |||
pass | |||
class FileDownloadError(Exception): | |||
pass | |||
def is_ok(rsp): | |||
""" Check the request is ok | |||
@@ -16,10 +16,11 @@ from modelscope import __version__ | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
from modelscope.utils.logger import get_logger | |||
from .api import HubApi, ModelScopeConfig | |||
from .errors import NotExistError | |||
from .constants import FILE_HASH | |||
from .errors import FileDownloadError, NotExistError | |||
from .utils.caching import ModelFileSystemCache | |||
from .utils.utils import (get_cache_dir, get_endpoint, | |||
model_id_to_group_owner_name) | |||
from .utils.utils import (file_integrity_validation, get_cache_dir, | |||
get_endpoint, model_id_to_group_owner_name) | |||
SESSION_ID = uuid4().hex | |||
logger = get_logger() | |||
@@ -143,24 +144,29 @@ def model_file_download( | |||
# we need to download again | |||
url_to_download = get_file_download_url(model_id, file_path, revision) | |||
file_to_download_info = { | |||
'Path': file_path, | |||
'Path': | |||
file_path, | |||
'Revision': | |||
revision if is_commit_id else file_to_download_info['Revision'] | |||
revision if is_commit_id else file_to_download_info['Revision'], | |||
FILE_HASH: | |||
None if (is_commit_id or FILE_HASH not in file_to_download_info) else | |||
file_to_download_info[FILE_HASH] | |||
} | |||
# Prevent parallel downloads of the same file with a lock. | |||
lock_path = cache.get_root_location() + '.lock' | |||
with FileLock(lock_path): | |||
temp_file_name = next(tempfile._get_candidate_names()) | |||
http_get_file( | |||
url_to_download, | |||
temporary_cache_dir, | |||
temp_file_name, | |||
headers=headers, | |||
cookies=None if cookies is None else cookies.get_dict()) | |||
return cache.put_file( | |||
file_to_download_info, | |||
os.path.join(temporary_cache_dir, temp_file_name)) | |||
temp_file_name = next(tempfile._get_candidate_names()) | |||
http_get_file( | |||
url_to_download, | |||
temporary_cache_dir, | |||
temp_file_name, | |||
headers=headers, | |||
cookies=None if cookies is None else cookies.get_dict()) | |||
temp_file_path = os.path.join(temporary_cache_dir, temp_file_name) | |||
# for download with commit we can't get Sha256 | |||
if file_to_download_info[FILE_HASH] is not None: | |||
file_integrity_validation(temp_file_path, | |||
file_to_download_info[FILE_HASH]) | |||
return cache.put_file(file_to_download_info, | |||
os.path.join(temporary_cache_dir, temp_file_name)) | |||
def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | |||
@@ -222,6 +228,7 @@ def http_get_file( | |||
http headers to carry necessary info when requesting the remote file | |||
""" | |||
total = -1 | |||
temp_file_manager = partial( | |||
tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False) | |||
@@ -250,4 +257,12 @@ def http_get_file( | |||
progress.close() | |||
logger.info('storing %s in cache at %s', url, local_dir) | |||
downloaded_length = os.path.getsize(temp_file.name) | |||
if total != downloaded_length: | |||
os.remove(temp_file.name) | |||
msg = 'File %s download incomplete, content_length: %s but the \ | |||
file downloaded length: %s, please download again' % ( | |||
file_name, total, downloaded_length) | |||
logger.error(msg) | |||
raise FileDownloadError(msg) | |||
os.replace(temp_file.name, os.path.join(local_dir, file_name)) |
@@ -6,11 +6,13 @@ from typing import Dict, Optional, Union | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
from modelscope.utils.logger import get_logger | |||
from .api import HubApi, ModelScopeConfig | |||
from .constants import FILE_HASH | |||
from .errors import NotExistError | |||
from .file_download import (get_file_download_url, http_get_file, | |||
http_user_agent) | |||
from .utils.caching import ModelFileSystemCache | |||
from .utils.utils import get_cache_dir, model_id_to_group_owner_name | |||
from .utils.utils import (file_integrity_validation, get_cache_dir, | |||
model_id_to_group_owner_name) | |||
logger = get_logger() | |||
@@ -127,9 +129,11 @@ def snapshot_download(model_id: str, | |||
file_name=model_file['Name'], | |||
headers=headers, | |||
cookies=cookies) | |||
# check file integrity | |||
temp_file = os.path.join(temp_cache_dir, model_file['Name']) | |||
if FILE_HASH in model_file: | |||
file_integrity_validation(temp_file, model_file[FILE_HASH]) | |||
# put file to cache | |||
cache.put_file( | |||
model_file, os.path.join(temp_cache_dir, | |||
model_file['Name'])) | |||
cache.put_file(model_file, temp_file) | |||
return os.path.join(cache.get_root_location()) |
@@ -1,10 +1,15 @@ | |||
import hashlib | |||
import os | |||
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | |||
DEFAULT_MODELSCOPE_GROUP, | |||
MODEL_ID_SEPARATOR, | |||
MODELSCOPE_URL_SCHEME) | |||
from modelscope.hub.errors import FileIntegrityError | |||
from modelscope.utils.file_utils import get_default_cache_dir | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
def model_id_to_group_owner_name(model_id): | |||
@@ -31,3 +36,34 @@ def get_endpoint(): | |||
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | |||
DEFAULT_MODELSCOPE_DOMAIN) | |||
return MODELSCOPE_URL_SCHEME + modelscope_domain | |||
def compute_hash(file_path, buffer_size=1024 * 64):
    """Compute the SHA-256 hex digest of a file.

    The file is read in fixed-size chunks so arbitrarily large files can
    be hashed without loading them fully into memory.

    Args:
        file_path (str): Path of the file to hash.
        buffer_size (int): Read chunk size in bytes (default 64 KiB).

    Returns:
        str: Lowercase hexadecimal SHA-256 digest of the file contents.
    """
    sha256_hash = hashlib.sha256()
    with open(file_path, 'rb') as f:
        while True:
            data = f.read(buffer_size)
            if not data:  # b'' at EOF terminates the loop
                break
            sha256_hash.update(data)
    return sha256_hash.hexdigest()
def file_integrity_validation(file_path, expected_sha256):
    """Validate that a file's SHA-256 hash matches the expected value.

    On mismatch the file is deleted first, so a corrupted or incomplete
    download is never left behind to be picked up by the cache on a
    later run.

    Args:
        file_path (str): The file to validate.
        expected_sha256 (str): The expected SHA-256 hex digest.

    Raises:
        FileIntegrityError: If the computed hash differs from
            ``expected_sha256``; the offending file is removed before
            the error is raised.
    """
    file_sha256 = compute_hash(file_path)
    if file_sha256 != expected_sha256:
        # Remove the corrupt file so a retry starts from a clean state.
        os.remove(file_path)
        msg = 'File %s integrity check failed, expected sha256: %s, actual sha256: %s, the download may be incomplete, please try again.' % (
            file_path, expected_sha256, file_sha256)
        logger.error(msg)
        raise FileIntegrityError(msg)