You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_hub_revision_release_mode.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import tempfile
  4. import time
  5. import unittest
  6. import uuid
  7. from datetime import datetime
  8. from unittest import mock
  9. from modelscope import version
  10. from modelscope.hub.api import HubApi
  11. from modelscope.hub.constants import (MODELSCOPE_SDK_DEBUG, Licenses,
  12. ModelVisibility)
  13. from modelscope.hub.errors import NotExistError
  14. from modelscope.hub.file_download import model_file_download
  15. from modelscope.hub.repository import Repository
  16. from modelscope.hub.snapshot_download import snapshot_download
  17. from modelscope.utils.logger import get_logger
  18. from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
  19. TEST_MODEL_ORG)
  # Names of the files the tests push into the throw-away model repo:
  # the first goes in before the first tag, the second only after it.
  download_model_file_name, download_model_file_name2 = 'test.bin', 'test2.bin'

  # Module logger at DEBUG so this file's debug/info messages are emitted.
  logger = get_logger()
  logger.setLevel('DEBUG')
  class HubRevisionTest(unittest.TestCase):
      """Revision-resolution tests for ``snapshot_download``/``model_file_download``.

      Each test works against a freshly created public model on the hub, pushes
      files (and tags or branches) to it, and checks which files the download
      helpers return. Inside every test the process environment is replaced by
      a copy without the ``MODELSCOPE_SDK_DEBUG`` variable.
      """

      def setUp(self):
          """Log in, create a uniquely named test model and prepare the env copy."""
          self.api = HubApi()
          self.api.login(TEST_ACCESS_TOKEN1)
          self.model_name = 'rvr-%s' % (uuid.uuid4().hex)
          self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
          self.revision = 'v0.1_test_revision'
          self.revision2 = 'v0.2_test_revision'
          self.api.create_model(
              model_id=self.model_id,
              visibility=ModelVisibility.PUBLIC,
              license=Licenses.APACHE_V2,
              chinese_name=TEST_MODEL_CHINESE_NAME,
          )
          # Copy of os.environ minus MODELSCOPE_SDK_DEBUG; the tests install it
          # with mock.patch.dict(..., clear=True).
          names_to_remove = {MODELSCOPE_SDK_DEBUG}
          self.modified_environ = {
              k: v
              for k, v in os.environ.items() if k not in names_to_remove
          }

      def tearDown(self):
          """Delete the model created in setUp."""
          self.api.delete_model(model_id=self.model_id)

      # ----------------------------------------------------------------- helpers

      def _write_test_file(self, file_name):
          """Write ``testtest`` plus a newline to ``file_name`` in the local clone.

          Uses Python file I/O rather than a shell ``echo`` so paths containing
          spaces or shell metacharacters work and the bytes written are the same
          on every OS.
          """
          path = os.path.join(self.model_dir, file_name)
          with open(path, 'w', encoding='utf-8', newline='\n') as f:
              f.write('testtest\n')

      def _assert_exists(self, *parts):
          """Fail unless the path formed by joining ``parts`` exists."""
          path = os.path.join(*parts)
          self.assertTrue(os.path.exists(path), '%s does not exist' % path)

      def _assert_not_exists(self, *parts):
          """Fail if the path formed by joining ``parts`` exists."""
          path = os.path.join(*parts)
          self.assertFalse(os.path.exists(path), '%s unexpectedly exists' % path)

      def _patch_release_datetime(self, value):
          """Return a context manager setting ``version.__release_datetime__``.

          The previous value is restored when the ``with`` block exits, even if
          an assertion inside it fails.
          """
          logger.info('Setting __release_datetime__ to: %s' % value)
          return mock.patch.object(version, '__release_datetime__', value)

      def _push_two_tagged_revisions(self, first_label='First time'):
          """Push two tagged revisions separated in time; return their timestamps.

          ``self.revision`` contains only ``download_model_file_name``;
          ``self.revision2`` additionally contains ``download_model_file_name2``.

          Args:
              first_label: prefix used when logging the first timestamp.

          Returns:
              ``(t1, t2)``: local wall-clock times formatted as
              ``YYYY-MM-DD HH:MM:SS``, taken after the first and after the
              second tag was pushed.
          """
          self.prepare_repo_data_and_tag()
          t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
          logger.info('%s: %s' % (first_label, t1))
          # Leave a clear gap so t1 falls strictly before the second push/tag.
          time.sleep(10)
          self.add_new_file_and_tag_to_repo()
          t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
          logger.info('Second time: %s' % t2)
          return t1, t2

      # ------------------------------------------------------- repo preparation

      def prepare_repo_data(self):
          """Clone the model repo into a fresh temp dir and push ``test.bin``."""
          # The TemporaryDirectory is removed after the test via addCleanup, so
          # clones do not accumulate on disk across runs.
          temporary_dir = tempfile.TemporaryDirectory()
          self.addCleanup(temporary_dir.cleanup)
          self.model_dir = os.path.join(temporary_dir.name, self.model_name)
          self.repo = Repository(self.model_dir, clone_from=self.model_id)
          self._write_test_file(download_model_file_name)
          self.repo.push('add model')

      def prepare_repo_data_and_tag(self):
          """Push the initial file and tag that commit as ``self.revision``."""
          self.prepare_repo_data()
          self.repo.tag_and_push(self.revision, 'Test revision')

      def add_new_file_and_tag_to_repo(self):
          """Push ``test2.bin`` and tag that commit as ``self.revision2``."""
          self._write_test_file(download_model_file_name2)
          self.repo.push('add new file')
          self.repo.tag_and_push(self.revision2, 'Test revision')

      def add_new_file_and_branch_to_repo(self, branch_name):
          """Write ``test2.bin`` and push the commit to remote ``branch_name``."""
          self._write_test_file(download_model_file_name2)
          self.repo.push('add new file', remote_branch=branch_name)

      # ------------------------------------------------------------------ tests

      def test_dev_mode_default_master(self):
          with mock.patch.dict(os.environ, self.modified_environ, clear=True):
              self.prepare_repo_data()  # no tag, default get master
              with tempfile.TemporaryDirectory() as temp_cache_dir:
                  snapshot_path = snapshot_download(
                      self.model_id, cache_dir=temp_cache_dir)
                  self._assert_exists(snapshot_path, download_model_file_name)
              with tempfile.TemporaryDirectory() as temp_cache_dir:
                  file_path = model_file_download(
                      self.model_id,
                      download_model_file_name,
                      cache_dir=temp_cache_dir)
                  self._assert_exists(file_path)

      def test_dev_mode_specify_branch(self):
          with mock.patch.dict(os.environ, self.modified_environ, clear=True):
              self.prepare_repo_data()  # no tag, default get master
              branch_name = 'test'
              self.add_new_file_and_branch_to_repo(branch_name)
              with tempfile.TemporaryDirectory() as temp_cache_dir:
                  snapshot_path = snapshot_download(
                      self.model_id,
                      revision=branch_name,
                      cache_dir=temp_cache_dir)
                  self._assert_exists(snapshot_path, download_model_file_name)
              with tempfile.TemporaryDirectory() as temp_cache_dir:
                  file_path = model_file_download(
                      self.model_id,
                      download_model_file_name,
                      revision=branch_name,
                      cache_dir=temp_cache_dir)
                  self._assert_exists(file_path)

      def test_snapshot_download_revision(self):
          with mock.patch.dict(os.environ, self.modified_environ, clear=True):
              t1, t2 = self._push_two_tagged_revisions()
              logger.info('Origin __release_datetime__: %s'
                          % version.__release_datetime__)
              # Release datetime t1 was taken before the second tag was pushed.
              with self._patch_release_datetime(t1), \
                      tempfile.TemporaryDirectory() as temp_cache_dir:
                  snapshot_path = snapshot_download(
                      self.model_id, cache_dir=temp_cache_dir)
                  self._assert_exists(snapshot_path, download_model_file_name)
                  self._assert_not_exists(snapshot_path,
                                          download_model_file_name2)
              # Release datetime t2 was taken after the second tag was pushed.
              with self._patch_release_datetime(t2), \
                      tempfile.TemporaryDirectory() as temp_cache_dir:
                  snapshot_path = snapshot_download(
                      self.model_id, cache_dir=temp_cache_dir)
                  self._assert_exists(snapshot_path, download_model_file_name)
                  self._assert_exists(snapshot_path, download_model_file_name2)

      def test_snapshot_download_revision_user_set_revision(self):
          with mock.patch.dict(os.environ, self.modified_environ, clear=True):
              t1, _ = self._push_two_tagged_revisions()
              logger.info('Origin __release_datetime__: %s'
                          % version.__release_datetime__)
              # An explicit revision is passed for both downloads while the
              # release datetime stays at t1.
              with self._patch_release_datetime(t1):
                  with tempfile.TemporaryDirectory() as temp_cache_dir:
                      snapshot_path = snapshot_download(
                          self.model_id,
                          revision=self.revision,
                          cache_dir=temp_cache_dir)
                      self._assert_exists(snapshot_path,
                                          download_model_file_name)
                      self._assert_not_exists(snapshot_path,
                                              download_model_file_name2)
                  with tempfile.TemporaryDirectory() as temp_cache_dir:
                      snapshot_path = snapshot_download(
                          self.model_id,
                          revision=self.revision2,
                          cache_dir=temp_cache_dir)
                      self._assert_exists(snapshot_path,
                                          download_model_file_name)
                      self._assert_exists(snapshot_path,
                                          download_model_file_name2)

      def test_file_download_revision(self):
          with mock.patch.dict(os.environ, self.modified_environ, clear=True):
              t1, t2 = self._push_two_tagged_revisions(
                  first_label='First time stamp')
              logger.info('Origin __release_datetime__: %s'
                          % version.__release_datetime__)
              with self._patch_release_datetime(t1), \
                      tempfile.TemporaryDirectory() as temp_cache_dir:
                  file_path = model_file_download(
                      self.model_id,
                      download_model_file_name,
                      cache_dir=temp_cache_dir)
                  self._assert_exists(file_path)
                  with self.assertRaises(NotExistError):
                      model_file_download(
                          self.model_id,
                          download_model_file_name2,
                          cache_dir=temp_cache_dir)
              with self._patch_release_datetime(t2), \
                      tempfile.TemporaryDirectory() as temp_cache_dir:
                  file_path = model_file_download(
                      self.model_id,
                      download_model_file_name,
                      cache_dir=temp_cache_dir)
                  self._assert_exists(file_path)
                  file_path = model_file_download(
                      self.model_id,
                      download_model_file_name2,
                      cache_dir=temp_cache_dir)
                  self._assert_exists(file_path)

      def test_file_download_revision_user_set_revision(self):
          with mock.patch.dict(os.environ, self.modified_environ, clear=True):
              t1, _ = self._push_two_tagged_revisions(
                  first_label='First time stamp')
              logger.info('Origin __release_datetime__: %s'
                          % version.__release_datetime__)
              with self._patch_release_datetime(t1):
                  with tempfile.TemporaryDirectory() as temp_cache_dir:
                      file_path = model_file_download(
                          self.model_id,
                          download_model_file_name,
                          revision=self.revision,
                          cache_dir=temp_cache_dir)
                      self._assert_exists(file_path)
                      with self.assertRaises(NotExistError):
                          model_file_download(
                              self.model_id,
                              download_model_file_name2,
                              revision=self.revision,
                              cache_dir=temp_cache_dir)
                  with tempfile.TemporaryDirectory() as temp_cache_dir:
                      file_path = model_file_download(
                          self.model_id,
                          download_model_file_name,
                          revision=self.revision2,
                          cache_dir=temp_cache_dir)
                      self._assert_exists(file_path)
                      file_path = model_file_download(
                          self.model_id,
                          download_model_file_name2,
                          revision=self.revision2,
                          cache_dir=temp_cache_dir)
                      self._assert_exists(file_path)
  if __name__ == '__main__':
      # Direct execution: discover and run the TestCase classes in this module.
      unittest.main(module='__main__')