|
- import hashlib
- import os
- from typing import Optional
-
- from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DATA_ENDPOINT,
- DEFAULT_MODELSCOPE_DOMAIN,
- DEFAULT_MODELSCOPE_GROUP,
- MODEL_ID_SEPARATOR,
- MODELSCOPE_URL_SCHEME)
- from modelscope.hub.errors import FileIntegrityError
- from modelscope.utils.file_utils import get_default_cache_dir
- from modelscope.utils.logger import get_logger
-
# Module-level logger obtained from modelscope's logging helper; used below to
# report integrity-check failures.
logger = get_logger()
-
-
def model_id_to_group_owner_name(model_id):
    """Split a model id into its group/owner part and its name part.

    Args:
        model_id (str): Either ``'<owner><sep><name>'`` or a bare name, where
            ``<sep>`` is ``MODEL_ID_SEPARATOR``.

    Returns:
        tuple: ``(group_or_owner, name)``. A bare name (one without the
        separator) is paired with ``DEFAULT_MODELSCOPE_GROUP``. If the id
        contains more than one separator, only the first two segments are
        returned.
    """
    if MODEL_ID_SEPARATOR not in model_id:
        return DEFAULT_MODELSCOPE_GROUP, model_id
    # Split once and reuse the pieces; any segments after the second are dropped.
    segments = model_id.split(MODEL_ID_SEPARATOR)
    return segments[0], segments[1]
-
-
def get_cache_dir(model_id: Optional[str] = None):
    """Return the hub cache directory, optionally scoped to one model.

    The base directory comes from the ``MODELSCOPE_CACHE`` environment
    variable when it is set. Otherwise it is ``<default cache dir>/hub``,
    where the default cache dir is whatever ``get_default_cache_dir()``
    returns (presumably ``~/.cache/modelscope`` — TODO confirm).

    Args:
        model_id (str, optional): When given, it is joined onto the base
            directory with a trailing ``'/'``.

    Returns:
        str: The base directory, or ``<base>/<model_id>/`` when ``model_id``
        is provided.
    """
    fallback_dir = os.path.join(get_default_cache_dir(), 'hub')
    base_dir = os.getenv('MODELSCOPE_CACHE', fallback_dir)
    if model_id is None:
        return base_dir
    return os.path.join(base_dir, model_id + '/')
-
-
def get_endpoint():
    """Return the ModelScope hub endpoint URL.

    The domain is read from the ``MODELSCOPE_DOMAIN`` environment variable,
    falling back to ``DEFAULT_MODELSCOPE_DOMAIN``, and is prefixed with
    ``MODELSCOPE_URL_SCHEME``.
    """
    domain = os.getenv('MODELSCOPE_DOMAIN', DEFAULT_MODELSCOPE_DOMAIN)
    endpoint = MODELSCOPE_URL_SCHEME + domain
    return endpoint
-
-
def get_dataset_hub_endpoint():
    """Return the dataset hub endpoint.

    Uses the ``HUB_DATASET_ENDPOINT`` environment variable when it is set,
    otherwise ``DEFAULT_MODELSCOPE_DATA_ENDPOINT``.
    """
    endpoint = os.getenv('HUB_DATASET_ENDPOINT',
                         DEFAULT_MODELSCOPE_DATA_ENDPOINT)
    return endpoint
-
-
def compute_hash(file_path):
    """Return the SHA-256 hex digest of a file's contents.

    The file is read in binary mode in 64 KiB chunks, so it is never loaded
    into memory all at once.

    Args:
        file_path (str): Path of the file to hash.

    Returns:
        str: Lowercase hexadecimal SHA-256 digest.
    """
    chunk_size = 64 * 1024  # 64k buffer size
    digest = hashlib.sha256()
    with open(file_path, 'rb') as stream:
        # iter() with a b'' sentinel stops at EOF, replacing the explicit break.
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
-
-
def file_integrity_validation(file_path, expected_sha256):
    """Check a file's SHA-256 against an expected value; delete it on mismatch.

    Args:
        file_path (str): Path of the file to validate.
        expected_sha256 (str): Expected SHA-256 hex digest. It is compared by
            exact string equality with ``compute_hash``'s output (lowercase
            hex). NOTE(review): callers presumably pass lowercase digests —
            confirm.

    Raises:
        FileIntegrityError: If the computed hash differs from
            ``expected_sha256``. Before raising, the file is removed and the
            failure is logged.
    """
    actual_sha256 = compute_hash(file_path)
    if actual_sha256 == expected_sha256:
        return
    # Drop the mismatching file so a retry (as the message suggests) does not reuse it.
    os.remove(file_path)
    msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
    logger.error(msg)
    raise FileIntegrityError(msg)
|