|
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import tempfile
- import time
- import unittest
- import uuid
- from datetime import datetime
- from unittest import mock
-
- from modelscope import version
- from modelscope.hub.api import HubApi
- from modelscope.hub.constants import (MODELSCOPE_SDK_DEBUG, Licenses,
- ModelVisibility)
- from modelscope.hub.errors import NotExistError
- from modelscope.hub.file_download import model_file_download
- from modelscope.hub.repository import Repository
- from modelscope.hub.snapshot_download import snapshot_download
- from modelscope.utils.logger import get_logger
- from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
- TEST_MODEL_ORG)
-
- logger = get_logger()
- logger.setLevel('DEBUG')
- download_model_file_name = 'test.bin'
- download_model_file_name2 = 'test2.bin'
-
-
- class HubRevisionTest(unittest.TestCase):
-
- def setUp(self):
- self.api = HubApi()
- self.api.login(TEST_ACCESS_TOKEN1)
- self.model_name = 'rvr-%s' % (uuid.uuid4().hex)
- self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
- self.revision = 'v0.1_test_revision'
- self.revision2 = 'v0.2_test_revision'
- self.api.create_model(
- model_id=self.model_id,
- visibility=ModelVisibility.PUBLIC,
- license=Licenses.APACHE_V2,
- chinese_name=TEST_MODEL_CHINESE_NAME,
- )
- names_to_remove = {MODELSCOPE_SDK_DEBUG}
- self.modified_environ = {
- k: v
- for k, v in os.environ.items() if k not in names_to_remove
- }
-
- def tearDown(self):
- self.api.delete_model(model_id=self.model_id)
-
- def prepare_repo_data(self):
- temporary_dir = tempfile.mkdtemp()
- self.model_dir = os.path.join(temporary_dir, self.model_name)
- self.repo = Repository(self.model_dir, clone_from=self.model_id)
- os.system("echo 'testtest'>%s"
- % os.path.join(self.model_dir, download_model_file_name))
- self.repo.push('add model')
-
- def prepare_repo_data_and_tag(self):
- self.prepare_repo_data()
- self.repo.tag_and_push(self.revision, 'Test revision')
-
- def add_new_file_and_tag_to_repo(self):
- os.system("echo 'testtest'>%s"
- % os.path.join(self.model_dir, download_model_file_name2))
- self.repo.push('add new file')
- self.repo.tag_and_push(self.revision2, 'Test revision')
-
- def add_new_file_and_branch_to_repo(self, branch_name):
- os.system("echo 'testtest'>%s"
- % os.path.join(self.model_dir, download_model_file_name2))
- self.repo.push('add new file', remote_branch=branch_name)
-
- def test_dev_mode_default_master(self):
- with mock.patch.dict(os.environ, self.modified_environ, clear=True):
- self.prepare_repo_data() # no tag, default get master
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- snapshot_path = snapshot_download(
- self.model_id, cache_dir=temp_cache_dir)
- assert os.path.exists(
- os.path.join(snapshot_path, download_model_file_name))
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- file_path = model_file_download(
- self.model_id,
- download_model_file_name,
- cache_dir=temp_cache_dir)
- assert os.path.exists(file_path)
-
- def test_dev_mode_specify_branch(self):
- with mock.patch.dict(os.environ, self.modified_environ, clear=True):
- self.prepare_repo_data() # no tag, default get master
- branch_name = 'test'
- self.add_new_file_and_branch_to_repo(branch_name)
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- snapshot_path = snapshot_download(
- self.model_id,
- revision=branch_name,
- cache_dir=temp_cache_dir)
- assert os.path.exists(
- os.path.join(snapshot_path, download_model_file_name))
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- file_path = model_file_download(
- self.model_id,
- download_model_file_name,
- revision=branch_name,
- cache_dir=temp_cache_dir)
- assert os.path.exists(file_path)
-
- def test_snapshot_download_revision(self):
- with mock.patch.dict(os.environ, self.modified_environ, clear=True):
- self.prepare_repo_data_and_tag()
- t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
- logger.info('First time: %s' % t1)
- time.sleep(10)
- self.add_new_file_and_tag_to_repo()
- t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
- logger.info('Second time: %s' % t2)
- # set
- release_datetime_backup = version.__release_datetime__
- logger.info('Origin __release_datetime__: %s'
- % version.__release_datetime__)
- try:
- logger.info('Setting __release_datetime__ to: %s' % t1)
- version.__release_datetime__ = t1
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- snapshot_path = snapshot_download(
- self.model_id, cache_dir=temp_cache_dir)
- assert os.path.exists(
- os.path.join(snapshot_path, download_model_file_name))
- assert not os.path.exists(
- os.path.join(snapshot_path, download_model_file_name2))
- version.__release_datetime__ = t2
- logger.info('Setting __release_datetime__ to: %s' % t2)
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- snapshot_path = snapshot_download(
- self.model_id, cache_dir=temp_cache_dir)
- assert os.path.exists(
- os.path.join(snapshot_path, download_model_file_name))
- assert os.path.exists(
- os.path.join(snapshot_path, download_model_file_name2))
- finally:
- version.__release_datetime__ = release_datetime_backup
-
- def test_snapshot_download_revision_user_set_revision(self):
- with mock.patch.dict(os.environ, self.modified_environ, clear=True):
- self.prepare_repo_data_and_tag()
- t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
- logger.info('First time: %s' % t1)
- time.sleep(10)
- self.add_new_file_and_tag_to_repo()
- t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
- logger.info('Secnod time: %s' % t2)
- # set
- release_datetime_backup = version.__release_datetime__
- logger.info('Origin __release_datetime__: %s'
- % version.__release_datetime__)
- try:
- logger.info('Setting __release_datetime__ to: %s' % t1)
- version.__release_datetime__ = t1
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- snapshot_path = snapshot_download(
- self.model_id,
- revision=self.revision,
- cache_dir=temp_cache_dir)
- assert os.path.exists(
- os.path.join(snapshot_path, download_model_file_name))
- assert not os.path.exists(
- os.path.join(snapshot_path, download_model_file_name2))
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- snapshot_path = snapshot_download(
- self.model_id,
- revision=self.revision2,
- cache_dir=temp_cache_dir)
- assert os.path.exists(
- os.path.join(snapshot_path, download_model_file_name))
- assert os.path.exists(
- os.path.join(snapshot_path, download_model_file_name2))
- finally:
- version.__release_datetime__ = release_datetime_backup
-
- def test_file_download_revision(self):
- with mock.patch.dict(os.environ, self.modified_environ, clear=True):
- self.prepare_repo_data_and_tag()
- t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
- logger.info('First time stamp: %s' % t1)
- time.sleep(10)
- self.add_new_file_and_tag_to_repo()
- t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
- logger.info('Second time: %s' % t2)
- release_datetime_backup = version.__release_datetime__
- logger.info('Origin __release_datetime__: %s'
- % version.__release_datetime__)
- try:
- version.__release_datetime__ = t1
- logger.info('Setting __release_datetime__ to: %s' % t1)
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- file_path = model_file_download(
- self.model_id,
- download_model_file_name,
- cache_dir=temp_cache_dir)
- assert os.path.exists(file_path)
- with self.assertRaises(NotExistError):
- model_file_download(
- self.model_id,
- download_model_file_name2,
- cache_dir=temp_cache_dir)
- version.__release_datetime__ = t2
- logger.info('Setting __release_datetime__ to: %s' % t2)
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- file_path = model_file_download(
- self.model_id,
- download_model_file_name,
- cache_dir=temp_cache_dir)
- assert os.path.exists(file_path)
- file_path = model_file_download(
- self.model_id,
- download_model_file_name2,
- cache_dir=temp_cache_dir)
- assert os.path.exists(file_path)
- finally:
- version.__release_datetime__ = release_datetime_backup
-
- def test_file_download_revision_user_set_revision(self):
- with mock.patch.dict(os.environ, self.modified_environ, clear=True):
- self.prepare_repo_data_and_tag()
- t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
- logger.info('First time stamp: %s' % t1)
- time.sleep(10)
- self.add_new_file_and_tag_to_repo()
- t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
- logger.info('Second time: %s' % t2)
- release_datetime_backup = version.__release_datetime__
- logger.info('Origin __release_datetime__: %s'
- % version.__release_datetime__)
- try:
- version.__release_datetime__ = t1
- logger.info('Setting __release_datetime__ to: %s' % t1)
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- file_path = model_file_download(
- self.model_id,
- download_model_file_name,
- revision=self.revision,
- cache_dir=temp_cache_dir)
- assert os.path.exists(file_path)
- with self.assertRaises(NotExistError):
- model_file_download(
- self.model_id,
- download_model_file_name2,
- revision=self.revision,
- cache_dir=temp_cache_dir)
- with tempfile.TemporaryDirectory() as temp_cache_dir:
- file_path = model_file_download(
- self.model_id,
- download_model_file_name,
- revision=self.revision2,
- cache_dir=temp_cache_dir)
- assert os.path.exists(file_path)
- file_path = model_file_download(
- self.model_id,
- download_model_file_name2,
- revision=self.revision2,
- cache_dir=temp_cache_dir)
- assert os.path.exists(file_path)
- finally:
- version.__release_datetime__ = release_datetime_backup
-
-
- if __name__ == '__main__':
- unittest.main()
|