|
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import tempfile
- import unittest
- import uuid
- from datetime import datetime
-
- from modelscope.hub.api import HubApi
- from modelscope.hub.constants import Licenses, ModelVisibility
- from modelscope.hub.errors import NotExistError, NoValidRevisionError
- from modelscope.hub.file_download import model_file_download
- from modelscope.hub.repository import Repository
- from modelscope.hub.snapshot_download import snapshot_download
- from modelscope.utils.constant import ModelFile
- from modelscope.utils.logger import get_logger
- from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
- TEST_MODEL_ORG)
-
# Module-level logger shared by every test in this file, at DEBUG verbosity.
logger = get_logger()
logger.setLevel('DEBUG')

# Names of the files pushed to the test repository: the first is committed
# before the first tag, the second one before the second tag.
download_model_file_name, download_model_file_name2 = 'test.bin', 'test2.bin'
-
-
class HubRevisionTest(unittest.TestCase):
    """Integration tests for revision (tag) pinned downloads from the hub.

    ``setUp`` logs in with ``TEST_ACCESS_TOKEN1`` and creates a fresh public
    model under ``TEST_MODEL_ORG`` with a random name; ``tearDown`` deletes
    it again. The tests talk to the real hub service.
    """

    def setUp(self):
        self.api = HubApi()
        self.api.login(TEST_ACCESS_TOKEN1)
        # Random suffix so repeated or concurrent runs do not clash.
        self.model_name = 'rv-%s' % (uuid.uuid4().hex)
        self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
        self.revision = 'v0.1_test_revision'
        self.revision2 = 'v0.2_test_revision'
        self.api.create_model(
            model_id=self.model_id,
            visibility=ModelVisibility.PUBLIC,
            license=Licenses.APACHE_V2,
            chinese_name=TEST_MODEL_CHINESE_NAME,
        )

    def tearDown(self):
        self.api.delete_model(model_id=self.model_id)

    @staticmethod
    def _write_test_file(path):
        """Write the fixed payload ``'testtest\\n'`` to ``path``.

        Done in Python instead of ``os.system("echo ...>path")``: no shell is
        spawned (the path was interpolated unquoted), the bytes written are
        the same on every platform, and a failed write raises instead of
        being silently ignored.
        """
        with open(path, 'w', encoding='utf-8', newline='\n') as f:
            f.write('testtest\n')

    def prepare_repo_data(self):
        """Clone the model, push ``download_model_file_name`` and tag it.

        Sets ``self.model_dir`` and ``self.repo`` (used later by
        ``add_new_file_and_tag``) and pushes the tag ``self.revision``.
        """
        import shutil

        temporary_dir = tempfile.mkdtemp()
        # mkdtemp() is never removed automatically; delete the clone after
        # the test and tearDown have run. Best-effort: a leftover file must
        # not turn a passing test into an error.
        self.addCleanup(shutil.rmtree, temporary_dir, ignore_errors=True)
        self.model_dir = os.path.join(temporary_dir, self.model_name)
        self.repo = Repository(self.model_dir, clone_from=self.model_id)
        self._write_test_file(
            os.path.join(self.model_dir, download_model_file_name))
        self.repo.push('add model')
        self.repo.tag_and_push(self.revision, 'Test revision')

    def test_no_tag(self):
        """Downloads of a model that has no tag raise NoValidRevisionError."""
        with self.assertRaises(NoValidRevisionError):
            snapshot_download(self.model_id, None)

        with self.assertRaises(NoValidRevisionError):
            model_file_download(self.model_id, ModelFile.README)

    def test_with_only_one_tag(self):
        """With one tag, downloads without an explicit revision succeed."""
        self.prepare_repo_data()
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            snapshot_path = snapshot_download(
                self.model_id, cache_dir=temp_cache_dir)
            assert os.path.exists(
                os.path.join(snapshot_path, download_model_file_name))
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            file_path = model_file_download(
                self.model_id, ModelFile.README, cache_dir=temp_cache_dir)
            assert os.path.exists(file_path)

    def add_new_file_and_tag(self):
        """Push ``download_model_file_name2`` and tag it ``self.revision2``.

        Must be called after ``prepare_repo_data`` (uses ``self.repo`` and
        ``self.model_dir``).
        """
        self._write_test_file(
            os.path.join(self.model_dir, download_model_file_name2))
        self.repo.push('add new file')
        self.repo.tag_and_push(self.revision2, 'Test revision')

    def test_snapshot_download_different_revision(self):
        """Each tag's snapshot holds exactly the files pushed up to that tag."""
        self.prepare_repo_data()
        t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
        logger.info('First time stamp: %s', t1)
        snapshot_path = snapshot_download(self.model_id, self.revision)
        assert os.path.exists(
            os.path.join(snapshot_path, download_model_file_name))
        self.add_new_file_and_tag()
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            snapshot_path = snapshot_download(
                self.model_id,
                revision=self.revision,
                cache_dir=temp_cache_dir)
            assert os.path.exists(
                os.path.join(snapshot_path, download_model_file_name))
            # The second file was pushed after the first tag.
            assert not os.path.exists(
                os.path.join(snapshot_path, download_model_file_name2))
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            snapshot_path = snapshot_download(
                self.model_id,
                revision=self.revision2,
                cache_dir=temp_cache_dir)
            assert os.path.exists(
                os.path.join(snapshot_path, download_model_file_name))
            assert os.path.exists(
                os.path.join(snapshot_path, download_model_file_name2))

    def test_file_download_different_revision(self):
        """Single-file downloads honour the requested tag."""
        self.prepare_repo_data()
        t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
        logger.info('First time stamp: %s', t1)
        file_path = model_file_download(self.model_id,
                                        download_model_file_name,
                                        self.revision)
        assert os.path.exists(file_path)
        self.add_new_file_and_tag()
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            file_path = model_file_download(
                self.model_id,
                download_model_file_name,
                revision=self.revision,
                cache_dir=temp_cache_dir)
            assert os.path.exists(file_path)
            # The second file does not exist at the first tag.
            with self.assertRaises(NotExistError):
                model_file_download(
                    self.model_id,
                    download_model_file_name2,
                    revision=self.revision,
                    cache_dir=temp_cache_dir)

        with tempfile.TemporaryDirectory() as temp_cache_dir:
            file_path = model_file_download(
                self.model_id,
                download_model_file_name,
                revision=self.revision2,
                cache_dir=temp_cache_dir)
            logger.info('Downloaded file path: %s', file_path)
            assert os.path.exists(file_path)
            file_path = model_file_download(
                self.model_id,
                download_model_file_name2,
                revision=self.revision2,
                cache_dir=temp_cache_dir)
            assert os.path.exists(file_path)
-
-
if __name__ == '__main__':
    # Executed directly (not imported): hand over to unittest's CLI runner,
    # spelling out its default verbosity of 1 (``-v`` on argv still wins).
    unittest.main(verbosity=1)
|