添加文件下载完整性验证 (Add file download integrity verification) Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9700279 * [to #43913168] fix: add file download integrity check (branch: master)
@@ -4,7 +4,7 @@ DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DO | |||||
DEFAULT_MODELSCOPE_GROUP = 'damo' | DEFAULT_MODELSCOPE_GROUP = 'damo' | ||||
MODEL_ID_SEPARATOR = '/' | MODEL_ID_SEPARATOR = '/' | ||||
FILE_HASH = 'Sha256' | |||||
LOGGER_NAME = 'ModelScopeHub' | LOGGER_NAME = 'ModelScopeHub' | ||||
DEFAULT_CREDENTIALS_PATH = '~/.modelscope/credentials' | DEFAULT_CREDENTIALS_PATH = '~/.modelscope/credentials' | ||||
API_RESPONSE_FIELD_DATA = 'Data' | API_RESPONSE_FIELD_DATA = 'Data' | ||||
@@ -23,6 +23,14 @@ class NotLoginException(Exception): | |||||
pass | pass | ||||
class FileIntegrityError(Exception):
    """Raised when a downloaded file fails its hash (integrity) check."""
    pass
class FileDownloadError(Exception):
    """Raised when a file download is incomplete or otherwise fails."""
    pass
def is_ok(rsp): | def is_ok(rsp): | ||||
""" Check the request is ok | """ Check the request is ok | ||||
@@ -16,10 +16,11 @@ from modelscope import __version__ | |||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from .api import HubApi, ModelScopeConfig | from .api import HubApi, ModelScopeConfig | ||||
from .errors import NotExistError | |||||
from .constants import FILE_HASH | |||||
from .errors import FileDownloadError, NotExistError | |||||
from .utils.caching import ModelFileSystemCache | from .utils.caching import ModelFileSystemCache | ||||
from .utils.utils import (get_cache_dir, get_endpoint, | |||||
model_id_to_group_owner_name) | |||||
from .utils.utils import (file_integrity_validation, get_cache_dir, | |||||
get_endpoint, model_id_to_group_owner_name) | |||||
SESSION_ID = uuid4().hex | SESSION_ID = uuid4().hex | ||||
logger = get_logger() | logger = get_logger() | ||||
@@ -143,24 +144,29 @@ def model_file_download( | |||||
# we need to download again | # we need to download again | ||||
url_to_download = get_file_download_url(model_id, file_path, revision) | url_to_download = get_file_download_url(model_id, file_path, revision) | ||||
file_to_download_info = { | file_to_download_info = { | ||||
'Path': file_path, | |||||
'Path': | |||||
file_path, | |||||
'Revision': | 'Revision': | ||||
revision if is_commit_id else file_to_download_info['Revision'] | |||||
revision if is_commit_id else file_to_download_info['Revision'], | |||||
FILE_HASH: | |||||
None if (is_commit_id or FILE_HASH not in file_to_download_info) else | |||||
file_to_download_info[FILE_HASH] | |||||
} | } | ||||
# Prevent parallel downloads of the same file with a lock. | |||||
lock_path = cache.get_root_location() + '.lock' | |||||
with FileLock(lock_path): | |||||
temp_file_name = next(tempfile._get_candidate_names()) | |||||
http_get_file( | |||||
url_to_download, | |||||
temporary_cache_dir, | |||||
temp_file_name, | |||||
headers=headers, | |||||
cookies=None if cookies is None else cookies.get_dict()) | |||||
return cache.put_file( | |||||
file_to_download_info, | |||||
os.path.join(temporary_cache_dir, temp_file_name)) | |||||
temp_file_name = next(tempfile._get_candidate_names()) | |||||
http_get_file( | |||||
url_to_download, | |||||
temporary_cache_dir, | |||||
temp_file_name, | |||||
headers=headers, | |||||
cookies=None if cookies is None else cookies.get_dict()) | |||||
temp_file_path = os.path.join(temporary_cache_dir, temp_file_name) | |||||
# for download with commit we can't get Sha256 | |||||
if file_to_download_info[FILE_HASH] is not None: | |||||
file_integrity_validation(temp_file_path, | |||||
file_to_download_info[FILE_HASH]) | |||||
return cache.put_file(file_to_download_info, | |||||
os.path.join(temporary_cache_dir, temp_file_name)) | |||||
def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | ||||
@@ -222,6 +228,7 @@ def http_get_file( | |||||
http headers to carry necessary info when requesting the remote file | http headers to carry necessary info when requesting the remote file | ||||
""" | """ | ||||
total = -1 | |||||
temp_file_manager = partial( | temp_file_manager = partial( | ||||
tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False) | tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False) | ||||
@@ -250,4 +257,12 @@ def http_get_file( | |||||
progress.close() | progress.close() | ||||
logger.info('storing %s in cache at %s', url, local_dir) | logger.info('storing %s in cache at %s', url, local_dir) | ||||
downloaded_length = os.path.getsize(temp_file.name) | |||||
if total != downloaded_length: | |||||
os.remove(temp_file.name) | |||||
msg = 'File %s download incomplete, content_length: %s but the \ | |||||
file downloaded length: %s, please download again' % ( | |||||
file_name, total, downloaded_length) | |||||
logger.error(msg) | |||||
raise FileDownloadError(msg) | |||||
os.replace(temp_file.name, os.path.join(local_dir, file_name)) | os.replace(temp_file.name, os.path.join(local_dir, file_name)) |
@@ -6,11 +6,13 @@ from typing import Dict, Optional, Union | |||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from .api import HubApi, ModelScopeConfig | from .api import HubApi, ModelScopeConfig | ||||
from .constants import FILE_HASH | |||||
from .errors import NotExistError | from .errors import NotExistError | ||||
from .file_download import (get_file_download_url, http_get_file, | from .file_download import (get_file_download_url, http_get_file, | ||||
http_user_agent) | http_user_agent) | ||||
from .utils.caching import ModelFileSystemCache | from .utils.caching import ModelFileSystemCache | ||||
from .utils.utils import get_cache_dir, model_id_to_group_owner_name | |||||
from .utils.utils import (file_integrity_validation, get_cache_dir, | |||||
model_id_to_group_owner_name) | |||||
logger = get_logger() | logger = get_logger() | ||||
@@ -127,9 +129,11 @@ def snapshot_download(model_id: str, | |||||
file_name=model_file['Name'], | file_name=model_file['Name'], | ||||
headers=headers, | headers=headers, | ||||
cookies=cookies) | cookies=cookies) | ||||
# check file integrity | |||||
temp_file = os.path.join(temp_cache_dir, model_file['Name']) | |||||
if FILE_HASH in model_file: | |||||
file_integrity_validation(temp_file, model_file[FILE_HASH]) | |||||
# put file to cache | # put file to cache | ||||
cache.put_file( | |||||
model_file, os.path.join(temp_cache_dir, | |||||
model_file['Name'])) | |||||
cache.put_file(model_file, temp_file) | |||||
return os.path.join(cache.get_root_location()) | return os.path.join(cache.get_root_location()) |
@@ -1,10 +1,15 @@ | |||||
import hashlib | |||||
import os | import os | ||||
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | ||||
DEFAULT_MODELSCOPE_GROUP, | DEFAULT_MODELSCOPE_GROUP, | ||||
MODEL_ID_SEPARATOR, | MODEL_ID_SEPARATOR, | ||||
MODELSCOPE_URL_SCHEME) | MODELSCOPE_URL_SCHEME) | ||||
from modelscope.hub.errors import FileIntegrityError | |||||
from modelscope.utils.file_utils import get_default_cache_dir | from modelscope.utils.file_utils import get_default_cache_dir | ||||
from modelscope.utils.logger import get_logger | |||||
logger = get_logger() | |||||
def model_id_to_group_owner_name(model_id): | def model_id_to_group_owner_name(model_id): | ||||
@@ -31,3 +36,34 @@ def get_endpoint(): | |||||
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | ||||
DEFAULT_MODELSCOPE_DOMAIN) | DEFAULT_MODELSCOPE_DOMAIN) | ||||
return MODELSCOPE_URL_SCHEME + modelscope_domain | return MODELSCOPE_URL_SCHEME + modelscope_domain | ||||
def compute_hash(file_path):
    """Compute the sha256 hex digest of the file at ``file_path``.

    The file is consumed in 64 KiB chunks so arbitrarily large files can
    be hashed without loading them fully into memory.

    Args:
        file_path (str): Path of the file to hash.

    Returns:
        str: Hex-encoded sha256 digest of the file contents.
    """
    chunk_size = 1024 * 64  # 64 KiB read buffer
    digest = hashlib.sha256()
    with open(file_path, 'rb') as stream:
        # iter() with a sentinel yields chunks until read() returns b''.
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
def file_integrity_validation(file_path, expected_sha256):
    """Validate the file hash is expected, if not, delete the file.

    Args:
        file_path (str): The file to validate.
        expected_sha256 (str): The expected sha256 hex digest.

    Raises:
        FileIntegrityError: If the hash of ``file_path`` does not match
            ``expected_sha256``. The mismatching file is removed first so
            that a subsequent retry starts from a clean state.
    """
    file_sha256 = compute_hash(file_path)
    # Idiomatic inequality test (was `if not file_sha256 == expected_sha256`).
    if file_sha256 != expected_sha256:
        # Drop the corrupted/incomplete download before reporting the error.
        os.remove(file_path)
        msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
        logger.error(msg)
        raise FileIntegrityError(msg)