|
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import tempfile
- import unittest
- import uuid
- from shutil import rmtree
-
- import requests
-
- from modelscope.hub.api import HubApi, ModelScopeConfig
- from modelscope.hub.constants import Licenses, ModelVisibility
- from modelscope.hub.file_download import model_file_download
- from modelscope.hub.repository import Repository
- from modelscope.hub.snapshot_download import snapshot_download
- from modelscope.utils.constant import ModelFile
- from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
- TEST_MODEL_ORG)
-
# Name of the git executable; not referenced by the tests in this module.
DEFAULT_GIT_PATH: str = 'git'

# File pushed to the temporary model repo in setUp and downloaded by the tests.
download_model_file_name: str = 'test.bin'
-
-
@unittest.skip(
    "Access token is always change, we can't login with same access token, so skip!"
)
class HubOperationTest(unittest.TestCase):
    """Integration tests for model creation, listing and download on the hub.

    Every test runs against a freshly created public model (random UUID name
    under ``TEST_MODEL_ORG``) that holds a single file,
    ``download_model_file_name``. The remote model and the local clone are
    removed by cleanups registered in ``setUp``.
    """

    def setUp(self):
        """Log in, create a temporary model repo and push one file to it."""
        self.api = HubApi()
        # note this is temporary before official account management is ready
        self.api.login(TEST_ACCESS_TOKEN1)
        self.model_name = uuid.uuid4().hex
        self.model_id = f'{TEST_MODEL_ORG}/{self.model_name}'
        self.api.create_model(
            model_id=self.model_id,
            visibility=ModelVisibility.PUBLIC,
            license=Licenses.APACHE_V2,
            chinese_name=TEST_MODEL_CHINESE_NAME,
        )
        # tearDown() is not called when setUp() raises, so register the
        # remote deletion as a cleanup right away; cleanups always run.
        self.addCleanup(self.api.delete_model, model_id=self.model_id)

        temporary_dir = tempfile.mkdtemp()
        # Cleanups run LIFO: the local clone is removed before the model.
        self.addCleanup(rmtree, temporary_dir, ignore_errors=True)
        self.model_dir = os.path.join(temporary_dir, self.model_name)
        repo = Repository(self.model_dir, clone_from=self.model_id)
        # Write the payload directly instead of shelling out to `echo`:
        # no quoting/injection concerns, same 'testtest\n' content.
        payload_path = os.path.join(self.model_dir, download_model_file_name)
        with open(payload_path, 'w', encoding='utf-8') as f:
            f.write('testtest\n')
        repo.push('add model')

    def test_model_repo_creation(self):
        """The model created in setUp is returned by get_model under its name."""
        info = self.api.get_model(model_id=self.model_id)
        self.assertEqual(info['Name'], self.model_name)

    def test_download_single_file(self):
        """Downloading the same file twice does not rewrite it (same mtime)."""
        downloaded_file = model_file_download(
            model_id=self.model_id, file_path=download_model_file_name)
        self.assertTrue(os.path.exists(downloaded_file))
        mdtime1 = os.path.getmtime(downloaded_file)
        # download again
        downloaded_file = model_file_download(
            model_id=self.model_id, file_path=download_model_file_name)
        mdtime2 = os.path.getmtime(downloaded_file)
        self.assertEqual(mdtime1, mdtime2)

    def test_snapshot_download(self):
        """A repeated snapshot download leaves the existing file untouched."""
        snapshot_path = snapshot_download(model_id=self.model_id)
        downloaded_file_path = os.path.join(snapshot_path,
                                            download_model_file_name)
        self.assertTrue(os.path.exists(downloaded_file_path))
        mdtime1 = os.path.getmtime(downloaded_file_path)
        # download again
        snapshot_download(model_id=self.model_id)
        mdtime2 = os.path.getmtime(downloaded_file_path)
        self.assertEqual(mdtime1, mdtime2)
        model_file_download(
            model_id=self.model_id,
            file_path=download_model_file_name)  # not add counter

    def test_download_public_without_login(self):
        """A public model can be downloaded after local credentials are removed."""
        rmtree(ModelScopeConfig.path_credential)
        try:
            snapshot_path = snapshot_download(model_id=self.model_id)
            downloaded_file_path = os.path.join(snapshot_path,
                                                download_model_file_name)
            self.assertTrue(os.path.exists(downloaded_file_path))
            with tempfile.TemporaryDirectory() as temporary_dir:
                downloaded_file = model_file_download(
                    model_id=self.model_id,
                    file_path=download_model_file_name,
                    cache_dir=temporary_dir)
                self.assertTrue(os.path.exists(downloaded_file))
        finally:
            # Restore the login even if an assertion above failed; deleting
            # the remote model afterwards presumably needs an authenticated
            # session.
            self.api.login(TEST_ACCESS_TOKEN1)

    def test_snapshot_delete_download_cache_file(self):
        """A file deleted from the snapshot is fetched again on request."""
        snapshot_path = snapshot_download(model_id=self.model_id)
        downloaded_file_path = os.path.join(snapshot_path,
                                            download_model_file_name)
        self.assertTrue(os.path.exists(downloaded_file_path))
        os.remove(downloaded_file_path)
        # download again in cache
        file_download_path = model_file_download(
            model_id=self.model_id, file_path=ModelFile.README)
        self.assertTrue(os.path.exists(file_download_path))
        # deleted file need download again
        file_download_path = model_file_download(
            model_id=self.model_id, file_path=download_model_file_name)
        self.assertTrue(os.path.exists(file_download_path))

    def get_model_download_times(self, timeout=30):
        """Return the hub's download counter for ``self.model_id``.

        Args:
            timeout: Seconds to wait for the hub. ``requests`` waits forever
                by default, which could hang the suite.

        Returns:
            ``r.json()['Data']['Downloads']`` on HTTP 200; ``None`` for any
            other status that ``raise_for_status`` does not reject.

        Raises:
            requests.HTTPError: If the hub answers with a 4xx/5xx status.
        """
        url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads'
        cookies = ModelScopeConfig.get_cookies()
        r = requests.get(url, cookies=cookies, timeout=timeout)
        if r.status_code == 200:
            return r.json()['Data']['Downloads']
        r.raise_for_status()
        return None

    def test_list_model(self):
        """Listing the test organisation returns at least one model."""
        data = self.api.list_model(TEST_MODEL_ORG)
        self.assertGreaterEqual(len(data['Models']), 1)
-
if __name__ == '__main__':
    # Allow running this module directly with the standard unittest runner.
    unittest.main()
|