From 0d17eb5b395b0d1a74e1a10ad754843bd6dfc71b Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Tue, 28 Jun 2022 21:12:15 +0800 Subject: [PATCH] [to #42849800 #42822853 #42822836 #42822791 #42822717 #42820011]fix: bug test bugs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复测试bug Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9186775 * [to #42849800 #42822853 #42822836 #42822791 #42822717 #42820011]fix: test bugs --- modelscope/hub/api.py | 84 ++++++++++++++++------- modelscope/hub/errors.py | 4 ++ modelscope/hub/file_download.py | 16 +++-- modelscope/hub/git.py | 8 +++ modelscope/hub/repository.py | 12 ++-- modelscope/hub/snapshot_download.py | 16 ++--- modelscope/hub/utils/caching.py | 8 ++- modelscope/utils/hub.py | 5 +- tests/hub/test_hub_operation.py | 42 ++++++++++-- tests/hub/test_hub_private_files.py | 85 ++++++++++++++++++++++++ tests/hub/test_hub_private_repository.py | 9 ++- tests/hub/test_hub_repository.py | 24 ++----- 12 files changed, 235 insertions(+), 78 deletions(-) create mode 100644 tests/hub/test_hub_private_files.py diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index d102219b..e79bfd41 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -9,7 +9,7 @@ import requests from modelscope.utils.logger import get_logger from .constants import MODELSCOPE_URL_SCHEME -from .errors import NotExistError, is_ok, raise_on_error +from .errors import InvalidParameter, NotExistError, is_ok, raise_on_error from .utils.utils import (get_endpoint, get_gitlab_domain, model_id_to_group_owner_name) @@ -61,17 +61,21 @@ class HubApi: return d['Data']['AccessToken'], cookies - def create_model(self, model_id: str, chinese_name: str, visibility: int, - license: str) -> str: + def create_model( + self, + model_id: str, + visibility: str, + license: str, + chinese_name: Optional[str] = None, + ) -> str: """ Create model repo at ModelScopeHub Args: model_id:(`str`): The model id - chinese_name(`str`): chinese name of the model - visibility(`int`): visibility of the model(1-private, 3-internal, 5-public) - license(`str`): license of the model, candidates can be found at: TBA - + visibility(`int`): visibility of the model(1-private, 5-public), default public. + license(`str`): license of the model, default none. + chinese_name(`str`, *optional*): chinese name of the model Returns: name of the model created @@ -79,6 +83,8 @@ class HubApi: model_id = {owner}/{name} """ + if model_id is None: + raise InvalidParameter('model_id is required!') cookies = ModelScopeConfig.get_cookies() if cookies is None: raise ValueError('Token does not exist, please login first.') @@ -151,11 +157,33 @@ class HubApi: else: r.raise_for_status() + def _check_cookie(self, + use_cookies: Union[bool, + CookieJar] = False) -> CookieJar: + cookies = None + if isinstance(use_cookies, CookieJar): + cookies = use_cookies + elif use_cookies: + cookies = ModelScopeConfig.get_cookies() + if cookies is None: + raise ValueError('Token does not exist, please login first.') + return cookies + def get_model_branches_and_tags( self, model_id: str, + use_cookies: Union[bool, CookieJar] = False ) -> Tuple[List[str], List[str]]: - cookies = ModelScopeConfig.get_cookies() + """Get model branch and tags. + + Args: + model_id (str): The model id + use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will + will load cookie from local. Defaults to False. + Returns: + Tuple[List[str], List[str]]: _description_ + """ + cookies = self._check_cookie(use_cookies) path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' r = requests.get(path, cookies=cookies) @@ -169,23 +197,33 @@ class HubApi: ] if info['RevisionMap']['Tags'] else [] return branches, tags - def get_model_files( - self, - model_id: str, - revision: Optional[str] = 'master', - root: Optional[str] = None, - recursive: Optional[str] = False, - use_cookies: Union[bool, CookieJar] = False) -> List[dict]: + def get_model_files(self, + model_id: str, + revision: Optional[str] = 'master', + root: Optional[str] = None, + recursive: Optional[str] = False, + use_cookies: Union[bool, CookieJar] = False, + is_snapshot: Optional[bool] = True) -> List[dict]: + """List the models files. - cookies = None - if isinstance(use_cookies, CookieJar): - cookies = use_cookies - elif use_cookies: - cookies = ModelScopeConfig.get_cookies() - if cookies is None: - raise ValueError('Token does not exist, please login first.') + Args: + model_id (str): The model id + revision (Optional[str], optional): The branch or tag name. Defaults to 'master'. + root (Optional[str], optional): The root path. Defaults to None. + recursive (Optional[str], optional): Is recurive list files. Defaults to False. + use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will + will load cookie from local. Defaults to False. + is_snapshot(Optional[bool], optional): when snapshot_download set to True, otherwise False. - path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}' + Raises: + ValueError: If user_cookies is True, but no local cookie. + + Returns: + List[dict]: Model file list. + """ + path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s&Snapshot=%s' % ( + self.endpoint, model_id, revision, recursive, is_snapshot) + cookies = self._check_cookie(use_cookies) if root is not None: path = path + f'&Root={root}' diff --git a/modelscope/hub/errors.py b/modelscope/hub/errors.py index d39036a0..9a19fdb5 100644 --- a/modelscope/hub/errors.py +++ b/modelscope/hub/errors.py @@ -10,6 +10,10 @@ class GitError(Exception): pass +class InvalidParameter(Exception): + pass + + def is_ok(rsp): """ Check the request is ok diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index b92bf89c..60aae3b6 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -7,6 +7,7 @@ import tempfile import time from functools import partial from hashlib import sha256 +from http.cookiejar import CookieJar from pathlib import Path from typing import BinaryIO, Dict, Optional, Union from uuid import uuid4 @@ -107,7 +108,9 @@ def model_file_download( _api = HubApi() headers = {'user-agent': http_user_agent(user_agent=user_agent, )} - branches, tags = _api.get_model_branches_and_tags(model_id) + cookies = ModelScopeConfig.get_cookies() + branches, tags = _api.get_model_branches_and_tags( + model_id, use_cookies=False if cookies is None else cookies) file_to_download_info = None is_commit_id = False if revision in branches or revision in tags: # The revision is version or tag, @@ -117,18 +120,19 @@ def model_file_download( model_id=model_id, revision=revision, recursive=True, - ) + use_cookies=False if cookies is None else cookies, + is_snapshot=False) for model_file in model_files: if model_file['Type'] == 'tree': continue if model_file['Path'] == file_path: - model_file['Branch'] = revision if cache.exists(model_file): return cache.get_file_by_info(model_file) else: file_to_download_info = model_file + break if file_to_download_info is None: raise NotExistError('The file path: %s not exist in: %s' % @@ -141,8 +145,6 @@ def model_file_download( return cached_file_path # the file is in cache. is_commit_id = True # we need to download again - # TODO: skip using JWT for authorization, use cookie instead - cookies = ModelScopeConfig.get_cookies() url_to_download = get_file_download_url(model_id, file_path, revision) file_to_download_info = { 'Path': file_path, @@ -202,7 +204,7 @@ def http_get_file( url: str, local_dir: str, file_name: str, - cookies: Dict[str, str], + cookies: CookieJar, headers: Optional[Dict[str, str]] = None, ): """ @@ -217,7 +219,7 @@ def http_get_file( local directory where the downloaded file stores file_name(`str`): name of the file stored in `local_dir` - cookies(`Dict[str, str]`): + cookies(`CookieJar`): cookies used to authentication the user, which is used for downloading private repos headers(`Optional[Dict[str, str]] = None`): http headers to carry necessary info when requesting the remote file diff --git a/modelscope/hub/git.py b/modelscope/hub/git.py index 37f61814..54161f1c 100644 --- a/modelscope/hub/git.py +++ b/modelscope/hub/git.py @@ -70,6 +70,14 @@ class GitCommandWrapper(metaclass=Singleton): except GitError: return False + def git_lfs_install(self, repo_dir): + cmd = ['git', '-C', repo_dir, 'lfs', 'install'] + try: + self._run_git_command(*cmd) + return True + except GitError: + return False + def clone(self, repo_base_dir: str, token: str, diff --git a/modelscope/hub/repository.py b/modelscope/hub/repository.py index d9322144..37dec571 100644 --- a/modelscope/hub/repository.py +++ b/modelscope/hub/repository.py @@ -1,7 +1,7 @@ import os from typing import List, Optional -from modelscope.hub.errors import GitError +from modelscope.hub.errors import GitError, InvalidParameter from modelscope.utils.logger import get_logger from .api import ModelScopeConfig from .constants import MODELSCOPE_URL_SCHEME @@ -49,6 +49,8 @@ class Repository: git_wrapper = GitCommandWrapper() if not git_wrapper.is_lfs_installed(): logger.error('git lfs is not installed, please install.') + else: + git_wrapper.git_lfs_install(self.model_dir) # init repo lfs self.git_wrapper = GitCommandWrapper(git_path) os.makedirs(self.model_dir, exist_ok=True) @@ -74,8 +76,6 @@ class Repository: def push(self, commit_message: str, - files: List[str] = list(), - all_files: bool = False, branch: Optional[str] = 'master', force: bool = False): """Push local to remote, this method will do. @@ -86,8 +86,12 @@ class Repository: commit_message (str): commit message revision (Optional[str], optional): which branch to push. Defaults to 'master'. """ + if commit_message is None: + msg = 'commit_message must be provided!' + raise InvalidParameter(msg) url = self.git_wrapper.get_repo_remote_url(self.model_dir) - self.git_wrapper.add(self.model_dir, files, all_files) + self.git_wrapper.pull(self.model_dir) + self.git_wrapper.add(self.model_dir, all_files=True) self.git_wrapper.commit(self.model_dir, commit_message) self.git_wrapper.push( repo_dir=self.model_dir, diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 90d850f4..91463f76 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -20,8 +20,7 @@ def snapshot_download(model_id: str, revision: Optional[str] = 'master', cache_dir: Union[str, Path, None] = None, user_agent: Optional[Union[Dict, str]] = None, - local_files_only: Optional[bool] = False, - private: Optional[bool] = False) -> str: + local_files_only: Optional[bool] = False) -> str: """Download all files of a repo. Downloads a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from a repo, because you don't know which @@ -79,8 +78,10 @@ def snapshot_download(model_id: str, # make headers headers = {'user-agent': http_user_agent(user_agent=user_agent, )} _api = HubApi() + cookies = ModelScopeConfig.get_cookies() # get file list from model repo - branches, tags = _api.get_model_branches_and_tags(model_id) + branches, tags = _api.get_model_branches_and_tags( + model_id, use_cookies=False if cookies is None else cookies) if revision not in branches and revision not in tags: raise NotExistError('The specified branch or tag : %s not exist!' % revision) @@ -89,11 +90,8 @@ def snapshot_download(model_id: str, model_id=model_id, revision=revision, recursive=True, - use_cookies=private) - - cookies = None - if private: - cookies = ModelScopeConfig.get_cookies() + use_cookies=False if cookies is None else cookies, + is_snapshot=True) for model_file in model_files: if model_file['Type'] == 'tree': @@ -116,7 +114,7 @@ def snapshot_download(model_id: str, local_dir=tempfile.gettempdir(), file_name=model_file['Name'], headers=headers, - cookies=None if cookies is None else cookies.get_dict()) + cookies=cookies) # put file to cache cache.put_file( model_file, diff --git a/modelscope/hub/utils/caching.py b/modelscope/hub/utils/caching.py index ac258385..7675e49b 100644 --- a/modelscope/hub/utils/caching.py +++ b/modelscope/hub/utils/caching.py @@ -101,8 +101,9 @@ class FileSystemCache(object): Args: key (dict): The cache key. """ - self.cached_files.remove(key) - self.save_cached_files() + if key in self.cached_files: + self.cached_files.remove(key) + self.save_cached_files() def exists(self, key): for cache_file in self.cached_files: @@ -204,6 +205,7 @@ class ModelFileSystemCache(FileSystemCache): return orig_path else: self.remove_key(cached_file) + break return None @@ -230,6 +232,7 @@ class ModelFileSystemCache(FileSystemCache): cached_key['Revision'].startswith(key['Revision']) or key['Revision'].startswith(cached_key['Revision'])): is_exists = True + break file_path = os.path.join(self.cache_root_location, model_file_info['Path']) if is_exists: @@ -253,6 +256,7 @@ class ModelFileSystemCache(FileSystemCache): cached_file['Path']) if os.path.exists(file_path): os.remove(file_path) + break def put_file(self, model_file_info, model_file_location): """Put model on model_file_location to cache, the model first download to /tmp, and move to cache. diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py index c427b7a3..3b7e80ef 100644 --- a/modelscope/utils/hub.py +++ b/modelscope/utils/hub.py @@ -31,9 +31,10 @@ def create_model_if_not_exist( else: api.create_model( model_id=model_id, - chinese_name=chinese_name, visibility=visibility, - license=license) + license=license, + chinese_name=chinese_name, + ) print(f'model {model_id} successfully created.') return True diff --git a/tests/hub/test_hub_operation.py b/tests/hub/test_hub_operation.py index 035b183e..d193ce32 100644 --- a/tests/hub/test_hub_operation.py +++ b/tests/hub/test_hub_operation.py @@ -3,6 +3,7 @@ import os import tempfile import unittest import uuid +from shutil import rmtree from modelscope.hub.api import HubApi, ModelScopeConfig from modelscope.hub.constants import Licenses, ModelVisibility @@ -23,7 +24,6 @@ download_model_file_name = 'test.bin' class HubOperationTest(unittest.TestCase): def setUp(self): - self.old_cwd = os.getcwd() self.api = HubApi() # note this is temporary before official account management is ready self.api.login(USER_NAME, PASSWORD) @@ -31,19 +31,18 @@ class HubOperationTest(unittest.TestCase): self.model_id = '%s/%s' % (model_org, self.model_name) self.api.create_model( model_id=self.model_id, - chinese_name=model_chinese_name, visibility=ModelVisibility.PUBLIC, - license=Licenses.APACHE_V2) + license=Licenses.APACHE_V2, + chinese_name=model_chinese_name, + ) temporary_dir = tempfile.mkdtemp() self.model_dir = os.path.join(temporary_dir, self.model_name) repo = Repository(self.model_dir, clone_from=self.model_id) - os.chdir(self.model_dir) os.system("echo 'testtest'>%s" - % os.path.join(self.model_dir, 'test.bin')) - repo.push('add model', all_files=True) + % os.path.join(self.model_dir, download_model_file_name)) + repo.push('add model') def tearDown(self): - os.chdir(self.old_cwd) self.api.delete_model(model_id=self.model_id) def test_model_repo_creation(self): @@ -79,6 +78,35 @@ class HubOperationTest(unittest.TestCase): mdtime2 = os.path.getmtime(downloaded_file_path) assert mdtime1 == mdtime2 + def test_download_public_without_login(self): + rmtree(ModelScopeConfig.path_credential) + snapshot_path = snapshot_download(model_id=self.model_id) + downloaded_file_path = os.path.join(snapshot_path, + download_model_file_name) + assert os.path.exists(downloaded_file_path) + temporary_dir = tempfile.mkdtemp() + downloaded_file = model_file_download( + model_id=self.model_id, + file_path=download_model_file_name, + cache_dir=temporary_dir) + assert os.path.exists(downloaded_file) + self.api.login(USER_NAME, PASSWORD) + + def test_snapshot_delete_download_cache_file(self): + snapshot_path = snapshot_download(model_id=self.model_id) + downloaded_file_path = os.path.join(snapshot_path, + download_model_file_name) + assert os.path.exists(downloaded_file_path) + os.remove(downloaded_file_path) + # download again in cache + file_download_path = model_file_download( + model_id=self.model_id, file_path='README.md') + assert os.path.exists(file_download_path) + # deleted file need download again + file_download_path = model_file_download( + model_id=self.model_id, file_path=download_model_file_name) + assert os.path.exists(file_download_path) + if __name__ == '__main__': unittest.main() diff --git a/tests/hub/test_hub_private_files.py b/tests/hub/test_hub_private_files.py new file mode 100644 index 00000000..b9c71456 --- /dev/null +++ b/tests/hub/test_hub_private_files.py @@ -0,0 +1,85 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import tempfile +import unittest +import uuid + +from requests.exceptions import HTTPError + +from modelscope.hub.api import HubApi +from modelscope.hub.constants import Licenses, ModelVisibility +from modelscope.hub.errors import GitError +from modelscope.hub.file_download import model_file_download +from modelscope.hub.repository import Repository +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.utils.constant import ModelFile + +USER_NAME = 'maasadmin' +PASSWORD = '12345678' +USER_NAME2 = 'sdkdev' + +model_chinese_name = '达摩卡通化模型' +model_org = 'unittest' + + +class HubPrivateFileDownloadTest(unittest.TestCase): + + def setUp(self): + self.old_cwd = os.getcwd() + self.api = HubApi() + # note this is temporary before official account management is ready + self.token, _ = self.api.login(USER_NAME, PASSWORD) + self.model_name = uuid.uuid4().hex + self.model_id = '%s/%s' % (model_org, self.model_name) + self.api.create_model( + model_id=self.model_id, + visibility=ModelVisibility.PRIVATE, # 1-private, 5-public + license=Licenses.APACHE_V2, + chinese_name=model_chinese_name, + ) + + def tearDown(self): + os.chdir(self.old_cwd) + self.api.delete_model(model_id=self.model_id) + + def test_snapshot_download_private_model(self): + snapshot_path = snapshot_download(self.model_id) + assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) + + def test_snapshot_download_private_model_no_permission(self): + self.token, _ = self.api.login(USER_NAME2, PASSWORD) + with self.assertRaises(HTTPError): + snapshot_download(self.model_id) + self.api.login(USER_NAME, PASSWORD) + + def test_download_file_private_model(self): + file_path = model_file_download(self.model_id, ModelFile.README) + assert os.path.exists(file_path) + + def test_download_file_private_model_no_permission(self): + self.token, _ = self.api.login(USER_NAME2, PASSWORD) + with self.assertRaises(HTTPError): + model_file_download(self.model_id, ModelFile.README) + self.api.login(USER_NAME, PASSWORD) + + def test_snapshot_download_local_only(self): + with self.assertRaises(ValueError): + snapshot_download(self.model_id, local_files_only=True) + snapshot_path = snapshot_download(self.model_id) + assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) + snapshot_path = snapshot_download(self.model_id, local_files_only=True) + assert os.path.exists(snapshot_path) + + def test_file_download_local_only(self): + with self.assertRaises(ValueError): + model_file_download( + self.model_id, ModelFile.README, local_files_only=True) + file_path = model_file_download(self.model_id, ModelFile.README) + assert os.path.exists(file_path) + file_path = model_file_download( + self.model_id, ModelFile.README, local_files_only=True) + assert os.path.exists(file_path) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_private_repository.py b/tests/hub/test_hub_private_repository.py index b6e3536c..01a89586 100644 --- a/tests/hub/test_hub_private_repository.py +++ b/tests/hub/test_hub_private_repository.py @@ -5,6 +5,7 @@ import unittest import uuid from modelscope.hub.api import HubApi +from modelscope.hub.constants import Licenses, ModelVisibility from modelscope.hub.errors import GitError from modelscope.hub.repository import Repository @@ -16,9 +17,6 @@ model_chinese_name = '达摩卡通化模型' model_org = 'unittest' DEFAULT_GIT_PATH = 'git' -sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' -download_model_file_name = 'mnist-12.onnx' - class HubPrivateRepositoryTest(unittest.TestCase): @@ -31,9 +29,10 @@ class HubPrivateRepositoryTest(unittest.TestCase): self.model_id = '%s/%s' % (model_org, self.model_name) self.api.create_model( model_id=self.model_id, + visibility=ModelVisibility.PRIVATE, # 1-private, 5-public + license=Licenses.APACHE_V2, chinese_name=model_chinese_name, - visibility=1, # 1-private, 5-public - license='apache-2.0') + ) def tearDown(self): self.api.login(USER_NAME, PASSWORD) diff --git a/tests/hub/test_hub_repository.py b/tests/hub/test_hub_repository.py index 7b1cc751..99f63eca 100644 --- a/tests/hub/test_hub_repository.py +++ b/tests/hub/test_hub_repository.py @@ -2,7 +2,6 @@ import os import shutil import tempfile -import time import unittest import uuid from os.path import expanduser @@ -10,6 +9,7 @@ from os.path import expanduser from requests import delete from modelscope.hub.api import HubApi +from modelscope.hub.constants import Licenses, ModelVisibility from modelscope.hub.errors import NotExistError from modelscope.hub.file_download import model_file_download from modelscope.hub.repository import Repository @@ -55,9 +55,10 @@ class HubRepositoryTest(unittest.TestCase): self.model_id = '%s/%s' % (model_org, self.model_name) self.api.create_model( model_id=self.model_id, + visibility=ModelVisibility.PUBLIC, # 1-private, 5-public + license=Licenses.APACHE_V2, chinese_name=model_chinese_name, - visibility=5, # 1-private, 5-public - license='apache-2.0') + ) temporary_dir = tempfile.mkdtemp() self.model_dir = os.path.join(temporary_dir, self.model_name) @@ -81,27 +82,12 @@ class HubRepositoryTest(unittest.TestCase): os.chdir(self.model_dir) os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) - repo.push('test', all_files=True) + repo.push('test') add1 = model_file_download(self.model_id, 'add1.py') assert os.path.exists(add1) add2 = model_file_download(self.model_id, 'add2.py') assert os.path.exists(add2) - def test_push_files(self): - repo = Repository(self.model_dir, clone_from=self.model_id) - assert os.path.exists(os.path.join(self.model_dir, 'README.md')) - os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) - os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) - os.system("echo '333'>%s" % os.path.join(self.model_dir, 'add3.py')) - repo.push('test', files=['add1.py', 'add2.py'], all_files=False) - add1 = model_file_download(self.model_id, 'add1.py') - assert os.path.exists(add1) - add2 = model_file_download(self.model_id, 'add2.py') - assert os.path.exists(add2) - with self.assertRaises(NotExistError) as cm: - model_file_download(self.model_id, 'add3.py') - print(cm.exception) - if __name__ == '__main__': unittest.main()