# Copyright (c) Alibaba, Inc. and its affiliates. import os import shutil import tempfile import unittest import uuid from modelscope.hub.api import HubApi from modelscope.hub.constants import Licenses, ModelVisibility from modelscope.hub.errors import GitError, HTTPError, NotLoginException from modelscope.hub.repository import Repository from modelscope.utils.constant import ModelFile from modelscope.utils.logger import get_logger from modelscope.utils.test_utils import test_level from .test_utils import TEST_ACCESS_TOKEN1, TEST_MODEL_ORG, delete_credential logger = get_logger() class HubUploadTest(unittest.TestCase): def setUp(self): logger.info('SetUp') self.api = HubApi() self.user = TEST_MODEL_ORG logger.info(self.user) self.create_model_name = '%s/%s_%s' % (self.user, 'test_model_upload', uuid.uuid4().hex) logger.info('create %s' % self.create_model_name) temporary_dir = tempfile.mkdtemp() self.work_dir = temporary_dir self.model_dir = os.path.join(temporary_dir, self.create_model_name) self.finetune_path = os.path.join(self.work_dir, 'finetune_path') self.repo_path = os.path.join(self.work_dir, 'repo_path') os.mkdir(self.finetune_path) os.system("echo '{}'>%s" % os.path.join(self.finetune_path, ModelFile.CONFIGURATION)) def tearDown(self): logger.info('TearDown') shutil.rmtree(self.model_dir, ignore_errors=True) try: self.api.delete_model(model_id=self.create_model_name) except Exception: pass def test_upload_exits_repo_master(self): logger.info('basic test for upload!') self.api.login(TEST_ACCESS_TOKEN1) self.api.create_model( model_id=self.create_model_name, visibility=ModelVisibility.PUBLIC, license=Licenses.APACHE_V2) os.system("echo '111'>%s" % os.path.join(self.finetune_path, 'add1.py')) self.api.push_model( model_id=self.create_model_name, model_dir=self.finetune_path) Repository(model_dir=self.repo_path, clone_from=self.create_model_name) assert os.path.exists(os.path.join(self.repo_path, 'add1.py')) shutil.rmtree(self.repo_path, ignore_errors=True) os.system("echo '222'>%s" % os.path.join(self.finetune_path, 'add2.py')) self.api.push_model( model_id=self.create_model_name, model_dir=self.finetune_path, revision='new_revision/version1') Repository( model_dir=self.repo_path, clone_from=self.create_model_name, revision='new_revision/version1') assert os.path.exists(os.path.join(self.repo_path, 'add2.py')) shutil.rmtree(self.repo_path, ignore_errors=True) os.system("echo '333'>%s" % os.path.join(self.finetune_path, 'add3.py')) self.api.push_model( model_id=self.create_model_name, model_dir=self.finetune_path, revision='new_revision/version2', commit_message='add add3.py') Repository( model_dir=self.repo_path, clone_from=self.create_model_name, revision='new_revision/version2') assert os.path.exists(os.path.join(self.repo_path, 'add2.py')) assert os.path.exists(os.path.join(self.repo_path, 'add3.py')) shutil.rmtree(self.repo_path, ignore_errors=True) add4_path = os.path.join(self.finetune_path, 'temp') os.mkdir(add4_path) os.system("echo '444'>%s" % os.path.join(add4_path, 'add4.py')) self.api.push_model( model_id=self.create_model_name, model_dir=self.finetune_path, revision='new_revision/version1') Repository( model_dir=self.repo_path, clone_from=self.create_model_name, revision='new_revision/version1') assert os.path.exists(os.path.join(add4_path, 'add4.py')) shutil.rmtree(self.repo_path, ignore_errors=True) assert os.path.exists(os.path.join(self.finetune_path, 'add3.py')) os.remove(os.path.join(self.finetune_path, 'add3.py')) self.api.push_model( model_id=self.create_model_name, model_dir=self.finetune_path, revision='new_revision/version1') Repository( model_dir=self.repo_path, clone_from=self.create_model_name, revision='new_revision/version1') assert not os.path.exists(os.path.join(self.repo_path, 'add3.py')) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_upload_non_exists_repo(self): logger.info('test upload non exists repo!') self.api.login(TEST_ACCESS_TOKEN1) os.system("echo '111'>%s" % os.path.join(self.finetune_path, 'add1.py')) self.api.push_model( model_id=self.create_model_name, model_dir=self.finetune_path, revision='new_model_new_revision', visibility=ModelVisibility.PUBLIC, license=Licenses.APACHE_V2) Repository( model_dir=self.repo_path, clone_from=self.create_model_name, revision='new_model_new_revision') assert os.path.exists(os.path.join(self.repo_path, 'add1.py')) shutil.rmtree(self.repo_path, ignore_errors=True) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_upload_without_token(self): logger.info('test upload without login!') self.api.login(TEST_ACCESS_TOKEN1) delete_credential() with self.assertRaises(NotLoginException): self.api.push_model( model_id=self.create_model_name, model_dir=self.finetune_path, visibility=ModelVisibility.PUBLIC, license=Licenses.APACHE_V2) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_upload_invalid_repo(self): logger.info('test upload to invalid repo!') self.api.login(TEST_ACCESS_TOKEN1) with self.assertRaises((HTTPError, GitError)): self.api.push_model( model_id='%s/%s' % ('speech_tts', 'invalid_model_test'), model_dir=self.finetune_path, visibility=ModelVisibility.PUBLIC, license=Licenses.APACHE_V2) if __name__ == '__main__': unittest.main()