|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import time
import unittest
import uuid
from datetime import datetime
from unittest import mock

from modelscope import version
from modelscope.hub.api import HubApi
from modelscope.hub.constants import (MODELSCOPE_SDK_DEBUG, Licenses,
                                      ModelVisibility)
from modelscope.hub.errors import NotExistError
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.logger import get_logger
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
                         TEST_MODEL_ORG)
-
# Names of the files the tests push to the throwaway model repository and
# later look for in the downloaded results.
download_model_file_name = 'test.bin'
download_model_file_name2 = 'test2.bin'

# Module logger for the test run, switched to verbose (DEBUG) output.
logger = get_logger()
logger.setLevel('DEBUG')
-
-
class HubRevisionTest(unittest.TestCase):
    """Revision-resolution tests for snapshot and single-file downloads.

    Every test runs against a freshly created public model repository
    (``setUp``) that is deleted again in ``tearDown``. Files and tags are
    pushed through a local git clone, then ``snapshot_download`` and
    ``model_file_download`` are checked for which files they return,
    depending on the requested revision and on
    ``version.__release_datetime__``.

    NOTE(review): these tests presumably need network access to the hub and
    a valid TEST_ACCESS_TOKEN1; several of them sleep for 10 seconds.
    """

    def setUp(self):
        """Log in and create an empty, uniquely named test model."""
        self.api = HubApi()
        self.api.login(TEST_ACCESS_TOKEN1)
        self.model_name = 'rvr-%s' % (uuid.uuid4().hex)
        self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
        self.revision = 'v0.1_test_revision'
        self.revision2 = 'v0.2_test_revision'
        self.api.create_model(
            model_id=self.model_id,
            visibility=ModelVisibility.PUBLIC,
            license=Licenses.APACHE_V2,
            chinese_name=TEST_MODEL_CHINESE_NAME,
        )
        # Copy of the current environment without MODELSCOPE_SDK_DEBUG. The
        # tests install it with ``mock.patch.dict(..., clear=True)`` so the
        # downloads run with that variable unset.
        names_to_remove = {MODELSCOPE_SDK_DEBUG}
        self.modified_environ = {
            k: v
            for k, v in os.environ.items() if k not in names_to_remove
        }

    def tearDown(self):
        """Delete the model repository created in ``setUp``."""
        self.api.delete_model(model_id=self.model_id)

    @staticmethod
    def _timestamp():
        """Return the current local time as 'YYYY-MM-DD HH:MM:SS'."""
        return datetime.now().isoformat(sep=' ', timespec='seconds')

    def _assert_exists(self, path):
        """Fail unless ``path`` exists (unlike ``assert``, kept under -O)."""
        self.assertTrue(os.path.exists(path), '%s does not exist' % path)

    def _assert_not_exists(self, path):
        """Fail if ``path`` exists (unlike ``assert``, kept under -O)."""
        self.assertFalse(
            os.path.exists(path), '%s unexpectedly exists' % path)

    def _write_file(self, file_name):
        """Create ``file_name`` in the local clone holding the line testtest.

        Plain file I/O is used instead of a shell ``echo`` redirection, so
        paths containing spaces or shell metacharacters are handled and a
        failed write raises instead of being silently ignored.
        """
        path = os.path.join(self.model_dir, file_name)
        with open(path, 'w', encoding='utf-8') as f:
            f.write('testtest\n')

    def prepare_repo_data(self):
        """Clone the test model into a temp dir and push ``test.bin``.

        Sets ``self.model_dir`` (the clone location) and ``self.repo``.
        """
        temporary_dir = tempfile.mkdtemp()
        # Remove the clone once the test is done rather than leaking one
        # directory per run; best effort so cleanup never fails a test.
        self.addCleanup(shutil.rmtree, temporary_dir, ignore_errors=True)
        self.model_dir = os.path.join(temporary_dir, self.model_name)
        self.repo = Repository(self.model_dir, clone_from=self.model_id)
        self._write_file(download_model_file_name)
        self.repo.push('add model')

    def prepare_repo_data_and_tag(self):
        """Push ``test.bin`` and tag that commit as ``self.revision``."""
        self.prepare_repo_data()
        self.repo.tag_and_push(self.revision, 'Test revision')

    def add_new_file_and_tag_to_repo(self):
        """Push ``test2.bin`` and tag that commit as ``self.revision2``."""
        self._write_file(download_model_file_name2)
        self.repo.push('add new file')
        self.repo.tag_and_push(self.revision2, 'Test revision')

    def add_new_file_and_branch_to_repo(self, branch_name):
        """Push ``test2.bin`` to the remote branch ``branch_name``."""
        self._write_file(download_model_file_name2)
        self.repo.push('add new file', remote_branch=branch_name)

    def test_dev_mode_default_master(self):
        with mock.patch.dict(os.environ, self.modified_environ, clear=True):
            self.prepare_repo_data()  # no tag, default get master
            with tempfile.TemporaryDirectory() as temp_cache_dir:
                snapshot_path = snapshot_download(
                    self.model_id, cache_dir=temp_cache_dir)
                self._assert_exists(
                    os.path.join(snapshot_path, download_model_file_name))
            with tempfile.TemporaryDirectory() as temp_cache_dir:
                file_path = model_file_download(
                    self.model_id,
                    download_model_file_name,
                    cache_dir=temp_cache_dir)
                self._assert_exists(file_path)

    def test_dev_mode_specify_branch(self):
        with mock.patch.dict(os.environ, self.modified_environ, clear=True):
            self.prepare_repo_data()  # no tag, default get master
            branch_name = 'test'
            self.add_new_file_and_branch_to_repo(branch_name)
            with tempfile.TemporaryDirectory() as temp_cache_dir:
                snapshot_path = snapshot_download(
                    self.model_id,
                    revision=branch_name,
                    cache_dir=temp_cache_dir)
                self._assert_exists(
                    os.path.join(snapshot_path, download_model_file_name))
            with tempfile.TemporaryDirectory() as temp_cache_dir:
                file_path = model_file_download(
                    self.model_id,
                    download_model_file_name,
                    revision=branch_name,
                    cache_dir=temp_cache_dir)
                self._assert_exists(file_path)

    def test_snapshot_download_revision(self):
        with mock.patch.dict(os.environ, self.modified_environ, clear=True):
            self.prepare_repo_data_and_tag()
            t1 = self._timestamp()
            logger.info('First time: %s', t1)
            # Leave a clear gap so t1 predates the second tag.
            time.sleep(10)
            self.add_new_file_and_tag_to_repo()
            t2 = self._timestamp()
            logger.info('Second time: %s', t2)
            logger.info('Origin __release_datetime__: %s',
                        version.__release_datetime__)
            # mock.patch.object restores the original value on exit, even
            # when an assertion inside fails.
            logger.info('Setting __release_datetime__ to: %s', t1)
            with mock.patch.object(version, '__release_datetime__', t1):
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    snapshot_path = snapshot_download(
                        self.model_id, cache_dir=temp_cache_dir)
                    # Expected to resolve to the first tag (no test2.bin).
                    self._assert_exists(
                        os.path.join(snapshot_path, download_model_file_name))
                    self._assert_not_exists(
                        os.path.join(snapshot_path,
                                     download_model_file_name2))
            logger.info('Setting __release_datetime__ to: %s', t2)
            with mock.patch.object(version, '__release_datetime__', t2):
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    snapshot_path = snapshot_download(
                        self.model_id, cache_dir=temp_cache_dir)
                    self._assert_exists(
                        os.path.join(snapshot_path, download_model_file_name))
                    self._assert_exists(
                        os.path.join(snapshot_path,
                                     download_model_file_name2))

    def test_snapshot_download_revision_user_set_revision(self):
        with mock.patch.dict(os.environ, self.modified_environ, clear=True):
            self.prepare_repo_data_and_tag()
            t1 = self._timestamp()
            logger.info('First time: %s', t1)
            # Leave a clear gap so t1 predates the second tag.
            time.sleep(10)
            self.add_new_file_and_tag_to_repo()
            t2 = self._timestamp()
            logger.info('Second time: %s', t2)
            logger.info('Origin __release_datetime__: %s',
                        version.__release_datetime__)
            logger.info('Setting __release_datetime__ to: %s', t1)
            # Both downloads run with the release datetime at t1; an explicit
            # revision is expected to be honoured regardless.
            with mock.patch.object(version, '__release_datetime__', t1):
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    snapshot_path = snapshot_download(
                        self.model_id,
                        revision=self.revision,
                        cache_dir=temp_cache_dir)
                    self._assert_exists(
                        os.path.join(snapshot_path, download_model_file_name))
                    self._assert_not_exists(
                        os.path.join(snapshot_path,
                                     download_model_file_name2))
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    snapshot_path = snapshot_download(
                        self.model_id,
                        revision=self.revision2,
                        cache_dir=temp_cache_dir)
                    self._assert_exists(
                        os.path.join(snapshot_path, download_model_file_name))
                    self._assert_exists(
                        os.path.join(snapshot_path,
                                     download_model_file_name2))

    def test_file_download_revision(self):
        with mock.patch.dict(os.environ, self.modified_environ, clear=True):
            self.prepare_repo_data_and_tag()
            t1 = self._timestamp()
            logger.info('First time stamp: %s', t1)
            # Leave a clear gap so t1 predates the second tag.
            time.sleep(10)
            self.add_new_file_and_tag_to_repo()
            t2 = self._timestamp()
            logger.info('Second time: %s', t2)
            logger.info('Origin __release_datetime__: %s',
                        version.__release_datetime__)
            logger.info('Setting __release_datetime__ to: %s', t1)
            with mock.patch.object(version, '__release_datetime__', t1):
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    file_path = model_file_download(
                        self.model_id,
                        download_model_file_name,
                        cache_dir=temp_cache_dir)
                    self._assert_exists(file_path)
                    # test2.bin only exists from the second tag onwards.
                    with self.assertRaises(NotExistError):
                        model_file_download(
                            self.model_id,
                            download_model_file_name2,
                            cache_dir=temp_cache_dir)
            logger.info('Setting __release_datetime__ to: %s', t2)
            with mock.patch.object(version, '__release_datetime__', t2):
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    file_path = model_file_download(
                        self.model_id,
                        download_model_file_name,
                        cache_dir=temp_cache_dir)
                    self._assert_exists(file_path)
                    file_path = model_file_download(
                        self.model_id,
                        download_model_file_name2,
                        cache_dir=temp_cache_dir)
                    self._assert_exists(file_path)

    def test_file_download_revision_user_set_revision(self):
        with mock.patch.dict(os.environ, self.modified_environ, clear=True):
            self.prepare_repo_data_and_tag()
            t1 = self._timestamp()
            logger.info('First time stamp: %s', t1)
            # Leave a clear gap so t1 predates the second tag.
            time.sleep(10)
            self.add_new_file_and_tag_to_repo()
            t2 = self._timestamp()
            logger.info('Second time: %s', t2)
            logger.info('Origin __release_datetime__: %s',
                        version.__release_datetime__)
            logger.info('Setting __release_datetime__ to: %s', t1)
            # All downloads run with the release datetime at t1 and pass an
            # explicit revision.
            with mock.patch.object(version, '__release_datetime__', t1):
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    file_path = model_file_download(
                        self.model_id,
                        download_model_file_name,
                        revision=self.revision,
                        cache_dir=temp_cache_dir)
                    self._assert_exists(file_path)
                    with self.assertRaises(NotExistError):
                        model_file_download(
                            self.model_id,
                            download_model_file_name2,
                            revision=self.revision,
                            cache_dir=temp_cache_dir)
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    file_path = model_file_download(
                        self.model_id,
                        download_model_file_name,
                        revision=self.revision2,
                        cache_dir=temp_cache_dir)
                    self._assert_exists(file_path)
                    file_path = model_file_download(
                        self.model_id,
                        download_model_file_name2,
                        revision=self.revision2,
                        cache_dir=temp_cache_dir)
                    self._assert_exists(file_path)
-
-
if __name__ == '__main__':
    # Executed as a script: discover and run the test cases in this module.
    unittest.main(module='__main__')
|