You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_hub_upload.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. import uuid
  7. from modelscope.hub.api import HubApi
  8. from modelscope.hub.constants import Licenses, ModelVisibility
  9. from modelscope.hub.errors import HTTPError, NotLoginException
  10. from modelscope.hub.repository import Repository
  11. from modelscope.utils.constant import ModelFile
  12. from modelscope.utils.logger import get_logger
  13. from modelscope.utils.test_utils import test_level
  14. from .test_utils import TEST_ACCESS_TOKEN1, TEST_MODEL_ORG, delete_credential
  15. logger = get_logger()
  16. class HubUploadTest(unittest.TestCase):
  17. def setUp(self):
  18. logger.info('SetUp')
  19. self.api = HubApi()
  20. self.user = TEST_MODEL_ORG
  21. logger.info(self.user)
  22. self.create_model_name = '%s/%s_%s' % (self.user, 'test_model_upload',
  23. uuid.uuid4().hex)
  24. logger.info('create %s' % self.create_model_name)
  25. temporary_dir = tempfile.mkdtemp()
  26. self.work_dir = temporary_dir
  27. self.model_dir = os.path.join(temporary_dir, self.create_model_name)
  28. self.finetune_path = os.path.join(self.work_dir, 'finetune_path')
  29. self.repo_path = os.path.join(self.work_dir, 'repo_path')
  30. os.mkdir(self.finetune_path)
  31. os.system("echo '{}'>%s"
  32. % os.path.join(self.finetune_path, ModelFile.CONFIGURATION))
  33. def tearDown(self):
  34. logger.info('TearDown')
  35. shutil.rmtree(self.model_dir, ignore_errors=True)
  36. try:
  37. self.api.delete_model(model_id=self.create_model_name)
  38. except Exception:
  39. pass
  40. def test_upload_exits_repo_master(self):
  41. logger.info('basic test for upload!')
  42. self.api.login(TEST_ACCESS_TOKEN1)
  43. self.api.create_model(
  44. model_id=self.create_model_name,
  45. visibility=ModelVisibility.PUBLIC,
  46. license=Licenses.APACHE_V2)
  47. os.system("echo '111'>%s"
  48. % os.path.join(self.finetune_path, 'add1.py'))
  49. self.api.push_model(
  50. model_id=self.create_model_name, model_dir=self.finetune_path)
  51. Repository(model_dir=self.repo_path, clone_from=self.create_model_name)
  52. assert os.path.exists(os.path.join(self.repo_path, 'add1.py'))
  53. shutil.rmtree(self.repo_path, ignore_errors=True)
  54. os.system("echo '222'>%s"
  55. % os.path.join(self.finetune_path, 'add2.py'))
  56. self.api.push_model(
  57. model_id=self.create_model_name,
  58. model_dir=self.finetune_path,
  59. revision='new_revision/version1')
  60. Repository(
  61. model_dir=self.repo_path,
  62. clone_from=self.create_model_name,
  63. revision='new_revision/version1')
  64. assert os.path.exists(os.path.join(self.repo_path, 'add2.py'))
  65. shutil.rmtree(self.repo_path, ignore_errors=True)
  66. os.system("echo '333'>%s"
  67. % os.path.join(self.finetune_path, 'add3.py'))
  68. self.api.push_model(
  69. model_id=self.create_model_name,
  70. model_dir=self.finetune_path,
  71. revision='new_revision/version2',
  72. commit_message='add add3.py')
  73. Repository(
  74. model_dir=self.repo_path,
  75. clone_from=self.create_model_name,
  76. revision='new_revision/version2')
  77. assert os.path.exists(os.path.join(self.repo_path, 'add2.py'))
  78. assert os.path.exists(os.path.join(self.repo_path, 'add3.py'))
  79. shutil.rmtree(self.repo_path, ignore_errors=True)
  80. add4_path = os.path.join(self.finetune_path, 'temp')
  81. os.mkdir(add4_path)
  82. os.system("echo '444'>%s" % os.path.join(add4_path, 'add4.py'))
  83. self.api.push_model(
  84. model_id=self.create_model_name,
  85. model_dir=self.finetune_path,
  86. revision='new_revision/version1')
  87. Repository(
  88. model_dir=self.repo_path,
  89. clone_from=self.create_model_name,
  90. revision='new_revision/version1')
  91. assert os.path.exists(os.path.join(add4_path, 'add4.py'))
  92. shutil.rmtree(self.repo_path, ignore_errors=True)
  93. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  94. def test_upload_non_exists_repo(self):
  95. logger.info('test upload non exists repo!')
  96. self.api.login(TEST_ACCESS_TOKEN1)
  97. os.system("echo '111'>%s"
  98. % os.path.join(self.finetune_path, 'add1.py'))
  99. self.api.push_model(
  100. model_id=self.create_model_name,
  101. model_dir=self.finetune_path,
  102. revision='new_model_new_revision',
  103. visibility=ModelVisibility.PUBLIC,
  104. license=Licenses.APACHE_V2)
  105. Repository(
  106. model_dir=self.repo_path,
  107. clone_from=self.create_model_name,
  108. revision='new_model_new_revision')
  109. assert os.path.exists(os.path.join(self.repo_path, 'add1.py'))
  110. shutil.rmtree(self.repo_path, ignore_errors=True)
  111. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  112. def test_upload_without_token(self):
  113. logger.info('test upload without login!')
  114. self.api.login(TEST_ACCESS_TOKEN1)
  115. delete_credential()
  116. with self.assertRaises(NotLoginException):
  117. self.api.push_model(
  118. model_id=self.create_model_name,
  119. model_dir=self.finetune_path,
  120. visibility=ModelVisibility.PUBLIC,
  121. license=Licenses.APACHE_V2)
  122. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  123. def test_upload_invalid_repo(self):
  124. logger.info('test upload to invalid repo!')
  125. self.api.login(TEST_ACCESS_TOKEN1)
  126. with self.assertRaises(HTTPError):
  127. self.api.push_model(
  128. model_id='%s/%s' % ('speech_tts', 'invalid_model_test'),
  129. model_dir=self.finetune_path,
  130. visibility=ModelVisibility.PUBLIC,
  131. license=Licenses.APACHE_V2)
  132. if __name__ == '__main__':
  133. unittest.main()