You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_hub_revision.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import tempfile
  4. import unittest
  5. import uuid
  6. from datetime import datetime
  7. from modelscope.hub.api import HubApi
  8. from modelscope.hub.constants import Licenses, ModelVisibility
  9. from modelscope.hub.errors import NotExistError, NoValidRevisionError
  10. from modelscope.hub.file_download import model_file_download
  11. from modelscope.hub.repository import Repository
  12. from modelscope.hub.snapshot_download import snapshot_download
  13. from modelscope.utils.constant import ModelFile
  14. from modelscope.utils.logger import get_logger
  15. from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
  16. TEST_MODEL_ORG)
  # File committed and tagged as the first revision of the test repo.
  download_model_file_name = 'test.bin'
  # File added only in the second tagged revision.
  download_model_file_name2 = 'test2.bin'

  # Shared logger for this module, switched to verbose (DEBUG) output.
  logger = get_logger()
  logger.setLevel('DEBUG')
  class HubRevisionTest(unittest.TestCase):
      """Tests for downloading model files pinned to tagged revisions.

      Each test creates a fresh public model under ``TEST_MODEL_ORG`` in
      ``setUp`` and deletes it again in ``tearDown``. All operations go through
      ``HubApi`` / ``Repository`` against the hub the access token logs into.
      """

      def setUp(self):
          self.api = HubApi()
          self.api.login(TEST_ACCESS_TOKEN1)
          # Random suffix so repeated or parallel runs never collide on a model id.
          self.model_name = 'rv-%s' % (uuid.uuid4().hex)
          self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
          self.revision = 'v0.1_test_revision'
          self.revision2 = 'v0.2_test_revision'
          self.api.create_model(
              model_id=self.model_id,
              visibility=ModelVisibility.PUBLIC,
              license=Licenses.APACHE_V2,
              chinese_name=TEST_MODEL_CHINESE_NAME,
          )

      def tearDown(self):
          self.api.delete_model(model_id=self.model_id)

      def _write_repo_file(self, file_name):
          """Write a small text file named ``file_name`` into the local clone.

          Produces the same content as the shell command
          ``echo 'testtest' > path``, but without a shell. Paths that contain
          spaces or quotes are therefore handled correctly on every platform.
          """
          path = os.path.join(self.model_dir, file_name)
          with open(path, 'w', encoding='utf-8') as f:
              f.write('testtest\n')

      def prepare_repo_data(self):
          """Clone the model, push one file and tag it as ``self.revision``.

          Sets ``self.model_dir`` (the local clone) and ``self.repo``.
          """
          import shutil  # only needed to remove the clone after the test

          temporary_dir = tempfile.mkdtemp()
          # mkdtemp() never deletes its directory; remove the clone once the
          # test finishes. ignore_errors avoids failing on read-only git objects.
          self.addCleanup(shutil.rmtree, temporary_dir, ignore_errors=True)
          self.model_dir = os.path.join(temporary_dir, self.model_name)
          self.repo = Repository(self.model_dir, clone_from=self.model_id)
          self._write_repo_file(download_model_file_name)
          self.repo.push('add model')
          self.repo.tag_and_push(self.revision, 'Test revision')

      def test_no_tag(self):
          """Downloads fail with NoValidRevisionError while no tag exists."""
          with self.assertRaises(NoValidRevisionError):
              snapshot_download(self.model_id, None)
          with self.assertRaises(NoValidRevisionError):
              model_file_download(self.model_id, ModelFile.README)

      def test_with_only_one_tag(self):
          """With a single tag, downloads without an explicit revision work."""
          self.prepare_repo_data()
          with tempfile.TemporaryDirectory() as temp_cache_dir:
              snapshot_path = snapshot_download(
                  self.model_id, cache_dir=temp_cache_dir)
              self.assertTrue(
                  os.path.exists(
                      os.path.join(snapshot_path, download_model_file_name)))
          with tempfile.TemporaryDirectory() as temp_cache_dir:
              file_path = model_file_download(
                  self.model_id, ModelFile.README, cache_dir=temp_cache_dir)
              self.assertTrue(os.path.exists(file_path))

      def add_new_file_and_tag(self):
          """Push a second file and tag the new commit as ``self.revision2``."""
          self._write_repo_file(download_model_file_name2)
          self.repo.push('add new file')
          self.repo.tag_and_push(self.revision2, 'Test revision')

      def test_snapshot_download_different_revision(self):
          """A snapshot contains exactly the files present at its revision."""
          self.prepare_repo_data()
          t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
          logger.info('First time stamp: %s', t1)
          snapshot_path = snapshot_download(self.model_id, self.revision)
          self.assertTrue(
              os.path.exists(
                  os.path.join(snapshot_path, download_model_file_name)))
          self.add_new_file_and_tag()
          with tempfile.TemporaryDirectory() as temp_cache_dir:
              snapshot_path = snapshot_download(
                  self.model_id,
                  revision=self.revision,
                  cache_dir=temp_cache_dir)
              self.assertTrue(
                  os.path.exists(
                      os.path.join(snapshot_path, download_model_file_name)))
              # The first revision predates the second file.
              self.assertFalse(
                  os.path.exists(
                      os.path.join(snapshot_path, download_model_file_name2)))
          with tempfile.TemporaryDirectory() as temp_cache_dir:
              snapshot_path = snapshot_download(
                  self.model_id,
                  revision=self.revision2,
                  cache_dir=temp_cache_dir)
              self.assertTrue(
                  os.path.exists(
                      os.path.join(snapshot_path, download_model_file_name)))
              self.assertTrue(
                  os.path.exists(
                      os.path.join(snapshot_path, download_model_file_name2)))

      def test_file_download_different_revision(self):
          """Single-file download respects the requested revision."""
          self.prepare_repo_data()
          t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
          logger.info('First time stamp: %s', t1)
          file_path = model_file_download(self.model_id,
                                          download_model_file_name,
                                          self.revision)
          self.assertTrue(os.path.exists(file_path))
          self.add_new_file_and_tag()
          with tempfile.TemporaryDirectory() as temp_cache_dir:
              file_path = model_file_download(
                  self.model_id,
                  download_model_file_name,
                  revision=self.revision,
                  cache_dir=temp_cache_dir)
              self.assertTrue(os.path.exists(file_path))
              # The second file does not exist at the first revision.
              with self.assertRaises(NotExistError):
                  model_file_download(
                      self.model_id,
                      download_model_file_name2,
                      revision=self.revision,
                      cache_dir=temp_cache_dir)
          with tempfile.TemporaryDirectory() as temp_cache_dir:
              file_path = model_file_download(
                  self.model_id,
                  download_model_file_name,
                  revision=self.revision2,
                  cache_dir=temp_cache_dir)
              logger.info('Downloaded file path: %s', file_path)
              self.assertTrue(os.path.exists(file_path))
              file_path = model_file_download(
                  self.model_id,
                  download_model_file_name2,
                  revision=self.revision2,
                  cache_dir=temp_cache_dir)
              self.assertTrue(os.path.exists(file_path))
  # Entry point when this module is executed as a script rather than imported.
  if __name__ == '__main__':
      unittest.main()