mulin.lyh 3 years ago
parent
commit
0d17eb5b39
12 changed files with 235 additions and 78 deletions
  1. +61
    -23
      modelscope/hub/api.py
  2. +4
    -0
      modelscope/hub/errors.py
  3. +9
    -7
      modelscope/hub/file_download.py
  4. +8
    -0
      modelscope/hub/git.py
  5. +8
    -4
      modelscope/hub/repository.py
  6. +7
    -9
      modelscope/hub/snapshot_download.py
  7. +6
    -2
      modelscope/hub/utils/caching.py
  8. +3
    -2
      modelscope/utils/hub.py
  9. +35
    -7
      tests/hub/test_hub_operation.py
  10. +85
    -0
      tests/hub/test_hub_private_files.py
  11. +4
    -5
      tests/hub/test_hub_private_repository.py
  12. +5
    -19
      tests/hub/test_hub_repository.py

+ 61
- 23
modelscope/hub/api.py View File

@@ -9,7 +9,7 @@ import requests

from modelscope.utils.logger import get_logger
from .constants import MODELSCOPE_URL_SCHEME
from .errors import NotExistError, is_ok, raise_on_error
from .errors import InvalidParameter, NotExistError, is_ok, raise_on_error
from .utils.utils import (get_endpoint, get_gitlab_domain,
model_id_to_group_owner_name)

@@ -61,17 +61,21 @@ class HubApi:

return d['Data']['AccessToken'], cookies

def create_model(self, model_id: str, chinese_name: str, visibility: int,
license: str) -> str:
def create_model(
self,
model_id: str,
visibility: str,
license: str,
chinese_name: Optional[str] = None,
) -> str:
"""
Create model repo at ModelScopeHub

Args:
model_id:(`str`): The model id
chinese_name(`str`): chinese name of the model
visibility(`int`): visibility of the model(1-private, 3-internal, 5-public)
license(`str`): license of the model, candidates can be found at: TBA

visibility(`int`): visibility of the model(1-private, 5-public), default public.
license(`str`): license of the model, default none.
chinese_name(`str`, *optional*): chinese name of the model
Returns:
name of the model created

@@ -79,6 +83,8 @@ class HubApi:
model_id = {owner}/{name}
</Tip>
"""
if model_id is None:
raise InvalidParameter('model_id is required!')
cookies = ModelScopeConfig.get_cookies()
if cookies is None:
raise ValueError('Token does not exist, please login first.')
@@ -151,11 +157,33 @@ class HubApi:
else:
r.raise_for_status()

def _check_cookie(self,
                  use_cookies: Union[bool,
                                     CookieJar] = False) -> Optional[CookieJar]:
    """Resolve the ``use_cookies`` argument into the cookies to send, if any.

    Args:
        use_cookies (Union[bool, CookieJar], optional): A CookieJar is
            returned unchanged; any other truthy value loads the cookies
            through ``ModelScopeConfig.get_cookies()``; a falsy value means
            no cookies are used. Defaults to False.

    Raises:
        ValueError: If cookies were requested with a truthy non-CookieJar
            value but ``ModelScopeConfig.get_cookies()`` returned None.

    Returns:
        Optional[CookieJar]: The cookies to use, or None when ``use_cookies``
            is falsy.
    """
    # The isinstance check must come first: an empty CookieJar is falsy but
    # is still an explicit jar supplied by the caller.
    if isinstance(use_cookies, CookieJar):
        return use_cookies
    if not use_cookies:
        return None
    cookies = ModelScopeConfig.get_cookies()
    if cookies is None:
        raise ValueError('Token does not exist, please login first.')
    return cookies

def get_model_branches_and_tags(
self,
model_id: str,
use_cookies: Union[bool, CookieJar] = False
) -> Tuple[List[str], List[str]]:
cookies = ModelScopeConfig.get_cookies()
"""Get model branch and tags.

Args:
model_id (str): The model id
use_cookies (Union[bool, CookieJar], optional): If a CookieJar is given, it is used directly; if True,
the locally saved cookies are loaded. Defaults to False.
Returns:
Tuple[List[str], List[str]]: The branch names and the tag names.
"""
cookies = self._check_cookie(use_cookies)

path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
r = requests.get(path, cookies=cookies)
@@ -169,23 +197,33 @@ class HubApi:
] if info['RevisionMap']['Tags'] else []
return branches, tags

def get_model_files(
self,
model_id: str,
revision: Optional[str] = 'master',
root: Optional[str] = None,
recursive: Optional[str] = False,
use_cookies: Union[bool, CookieJar] = False) -> List[dict]:
def get_model_files(self,
model_id: str,
revision: Optional[str] = 'master',
root: Optional[str] = None,
recursive: Optional[str] = False,
use_cookies: Union[bool, CookieJar] = False,
is_snapshot: Optional[bool] = True) -> List[dict]:
"""List the models files.

cookies = None
if isinstance(use_cookies, CookieJar):
cookies = use_cookies
elif use_cookies:
cookies = ModelScopeConfig.get_cookies()
if cookies is None:
raise ValueError('Token does not exist, please login first.')
Args:
model_id (str): The model id
revision (Optional[str], optional): The branch or tag name. Defaults to 'master'.
root (Optional[str], optional): The root path. Defaults to None.
recursive (Optional[str], optional): Whether to list files recursively. Defaults to False.
use_cookies (Union[bool, CookieJar], optional): If a CookieJar is given, it is used directly; if True,
the locally saved cookies are loaded. Defaults to False.
is_snapshot(Optional[bool], optional): when snapshot_download set to True, otherwise False.

path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}'
Raises:
ValueError: If use_cookies is True, but no local cookie exists.

Returns:
List[dict]: Model file list.
"""
path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s&Snapshot=%s' % (
self.endpoint, model_id, revision, recursive, is_snapshot)
cookies = self._check_cookie(use_cookies)
if root is not None:
path = path + f'&Root={root}'



+ 4
- 0
modelscope/hub/errors.py View File

@@ -10,6 +10,10 @@ class GitError(Exception):
pass


class InvalidParameter(Exception):
    """Raised when a required parameter is not provided.

    For example, it is raised when ``model_id`` or ``commit_message`` is None.
    """


def is_ok(rsp):
""" Check the request is ok



+ 9
- 7
modelscope/hub/file_download.py View File

@@ -7,6 +7,7 @@ import tempfile
import time
from functools import partial
from hashlib import sha256
from http.cookiejar import CookieJar
from pathlib import Path
from typing import BinaryIO, Dict, Optional, Union
from uuid import uuid4
@@ -107,7 +108,9 @@ def model_file_download(

_api = HubApi()
headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
branches, tags = _api.get_model_branches_and_tags(model_id)
cookies = ModelScopeConfig.get_cookies()
branches, tags = _api.get_model_branches_and_tags(
model_id, use_cookies=False if cookies is None else cookies)
file_to_download_info = None
is_commit_id = False
if revision in branches or revision in tags: # The revision is version or tag,
@@ -117,18 +120,19 @@ def model_file_download(
model_id=model_id,
revision=revision,
recursive=True,
)
use_cookies=False if cookies is None else cookies,
is_snapshot=False)

for model_file in model_files:
if model_file['Type'] == 'tree':
continue

if model_file['Path'] == file_path:
model_file['Branch'] = revision
if cache.exists(model_file):
return cache.get_file_by_info(model_file)
else:
file_to_download_info = model_file
break

if file_to_download_info is None:
raise NotExistError('The file path: %s not exist in: %s' %
@@ -141,8 +145,6 @@ def model_file_download(
return cached_file_path # the file is in cache.
is_commit_id = True
# we need to download again
# TODO: skip using JWT for authorization, use cookie instead
cookies = ModelScopeConfig.get_cookies()
url_to_download = get_file_download_url(model_id, file_path, revision)
file_to_download_info = {
'Path': file_path,
@@ -202,7 +204,7 @@ def http_get_file(
url: str,
local_dir: str,
file_name: str,
cookies: Dict[str, str],
cookies: CookieJar,
headers: Optional[Dict[str, str]] = None,
):
"""
@@ -217,7 +219,7 @@ def http_get_file(
local directory where the downloaded file stores
file_name(`str`):
name of the file stored in `local_dir`
cookies(`Dict[str, str]`):
cookies(`CookieJar`):
cookies used to authenticate the user, which is used for downloading private repos
headers(`Optional[Dict[str, str]] = None`):
http headers to carry necessary info when requesting the remote file


+ 8
- 0
modelscope/hub/git.py View File

@@ -70,6 +70,14 @@ class GitCommandWrapper(metaclass=Singleton):
except GitError:
return False

def git_lfs_install(self, repo_dir):
    """Run ``lfs install`` for the repository located at ``repo_dir``.

    Args:
        repo_dir: Directory handed to git through the ``-C`` option.

    Returns:
        bool: True when the command finished without raising ``GitError``,
            False otherwise.
    """
    # NOTE(review): if `_run_git_command` already prepends the git
    # executable, the leading 'git' here would be duplicated - confirm.
    lfs_install_args = ['git', '-C', repo_dir, 'lfs', 'install']
    try:
        self._run_git_command(*lfs_install_args)
    except GitError:
        return False
    else:
        return True

def clone(self,
repo_base_dir: str,
token: str,


+ 8
- 4
modelscope/hub/repository.py View File

@@ -1,7 +1,7 @@
import os
from typing import List, Optional

from modelscope.hub.errors import GitError
from modelscope.hub.errors import GitError, InvalidParameter
from modelscope.utils.logger import get_logger
from .api import ModelScopeConfig
from .constants import MODELSCOPE_URL_SCHEME
@@ -49,6 +49,8 @@ class Repository:
git_wrapper = GitCommandWrapper()
if not git_wrapper.is_lfs_installed():
logger.error('git lfs is not installed, please install.')
else:
git_wrapper.git_lfs_install(self.model_dir) # init repo lfs

self.git_wrapper = GitCommandWrapper(git_path)
os.makedirs(self.model_dir, exist_ok=True)
@@ -74,8 +76,6 @@ class Repository:

def push(self,
commit_message: str,
files: List[str] = list(),
all_files: bool = False,
branch: Optional[str] = 'master',
force: bool = False):
"""Push local to remote, this method will do.
@@ -86,8 +86,12 @@ class Repository:
commit_message (str): commit message
revision (Optional[str], optional): which branch to push. Defaults to 'master'.
"""
if commit_message is None:
msg = 'commit_message must be provided!'
raise InvalidParameter(msg)
url = self.git_wrapper.get_repo_remote_url(self.model_dir)
self.git_wrapper.add(self.model_dir, files, all_files)
self.git_wrapper.pull(self.model_dir)
self.git_wrapper.add(self.model_dir, all_files=True)
self.git_wrapper.commit(self.model_dir, commit_message)
self.git_wrapper.push(
repo_dir=self.model_dir,


+ 7
- 9
modelscope/hub/snapshot_download.py View File

@@ -20,8 +20,7 @@ def snapshot_download(model_id: str,
revision: Optional[str] = 'master',
cache_dir: Union[str, Path, None] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
private: Optional[bool] = False) -> str:
local_files_only: Optional[bool] = False) -> str:
"""Download all files of a repo.
Downloads a whole snapshot of a repo's files at the specified revision. This
is useful when you want all files from a repo, because you don't know which
@@ -79,8 +78,10 @@ def snapshot_download(model_id: str,
# make headers
headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
_api = HubApi()
cookies = ModelScopeConfig.get_cookies()
# get file list from model repo
branches, tags = _api.get_model_branches_and_tags(model_id)
branches, tags = _api.get_model_branches_and_tags(
model_id, use_cookies=False if cookies is None else cookies)
if revision not in branches and revision not in tags:
raise NotExistError('The specified branch or tag : %s not exist!'
% revision)
@@ -89,11 +90,8 @@ def snapshot_download(model_id: str,
model_id=model_id,
revision=revision,
recursive=True,
use_cookies=private)

cookies = None
if private:
cookies = ModelScopeConfig.get_cookies()
use_cookies=False if cookies is None else cookies,
is_snapshot=True)

for model_file in model_files:
if model_file['Type'] == 'tree':
@@ -116,7 +114,7 @@ def snapshot_download(model_id: str,
local_dir=tempfile.gettempdir(),
file_name=model_file['Name'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict())
cookies=cookies)
# put file to cache
cache.put_file(
model_file,


+ 6
- 2
modelscope/hub/utils/caching.py View File

@@ -101,8 +101,9 @@ class FileSystemCache(object):
Args:
key (dict): The cache key.
"""
self.cached_files.remove(key)
self.save_cached_files()
if key in self.cached_files:
self.cached_files.remove(key)
self.save_cached_files()

def exists(self, key):
for cache_file in self.cached_files:
@@ -204,6 +205,7 @@ class ModelFileSystemCache(FileSystemCache):
return orig_path
else:
self.remove_key(cached_file)
break

return None

@@ -230,6 +232,7 @@ class ModelFileSystemCache(FileSystemCache):
cached_key['Revision'].startswith(key['Revision'])
or key['Revision'].startswith(cached_key['Revision'])):
is_exists = True
break
file_path = os.path.join(self.cache_root_location,
model_file_info['Path'])
if is_exists:
@@ -253,6 +256,7 @@ class ModelFileSystemCache(FileSystemCache):
cached_file['Path'])
if os.path.exists(file_path):
os.remove(file_path)
break

def put_file(self, model_file_info, model_file_location):
"""Put model on model_file_location to cache, the model first download to /tmp, and move to cache.


+ 3
- 2
modelscope/utils/hub.py View File

@@ -31,9 +31,10 @@ def create_model_if_not_exist(
else:
api.create_model(
model_id=model_id,
chinese_name=chinese_name,
visibility=visibility,
license=license)
license=license,
chinese_name=chinese_name,
)
print(f'model {model_id} successfully created.')
return True



+ 35
- 7
tests/hub/test_hub_operation.py View File

@@ -3,6 +3,7 @@ import os
import tempfile
import unittest
import uuid
from shutil import rmtree

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import Licenses, ModelVisibility
@@ -23,7 +24,6 @@ download_model_file_name = 'test.bin'
class HubOperationTest(unittest.TestCase):

def setUp(self):
self.old_cwd = os.getcwd()
self.api = HubApi()
# note this is temporary before official account management is ready
self.api.login(USER_NAME, PASSWORD)
@@ -31,19 +31,18 @@ class HubOperationTest(unittest.TestCase):
self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model(
model_id=self.model_id,
chinese_name=model_chinese_name,
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2)
license=Licenses.APACHE_V2,
chinese_name=model_chinese_name,
)
temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name)
repo = Repository(self.model_dir, clone_from=self.model_id)
os.chdir(self.model_dir)
os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, 'test.bin'))
repo.push('add model', all_files=True)
% os.path.join(self.model_dir, download_model_file_name))
repo.push('add model')

def tearDown(self):
os.chdir(self.old_cwd)
self.api.delete_model(model_id=self.model_id)

def test_model_repo_creation(self):
@@ -79,6 +78,35 @@ class HubOperationTest(unittest.TestCase):
mdtime2 = os.path.getmtime(downloaded_file_path)
assert mdtime1 == mdtime2

def test_download_public_without_login(self):
    """Check that a public model downloads after local credentials are removed."""
    # Remove the saved credentials so both downloads below run logged out.
    rmtree(ModelScopeConfig.path_credential)
    try:
        snapshot_path = snapshot_download(model_id=self.model_id)
        downloaded_file_path = os.path.join(snapshot_path,
                                            download_model_file_name)
        assert os.path.exists(downloaded_file_path)
        # The context manager removes the scratch cache directory afterwards
        # instead of leaving it behind in the system temp location.
        with tempfile.TemporaryDirectory() as temporary_dir:
            downloaded_file = model_file_download(
                model_id=self.model_id,
                file_path=download_model_file_name,
                cache_dir=temporary_dir)
            assert os.path.exists(downloaded_file)
    finally:
        # Log in again even when an assertion above fails; tearDown deletes
        # the model, which presumably requires valid credentials.
        self.api.login(USER_NAME, PASSWORD)

def test_snapshot_delete_download_cache_file(self):
    """Check that a file removed from the snapshot directory can be downloaded again."""
    snapshot_dir = snapshot_download(model_id=self.model_id)
    cached_model_file = os.path.join(snapshot_dir, download_model_file_name)
    assert os.path.exists(cached_model_file)
    # Delete the downloaded file directly on disk.
    os.remove(cached_model_file)
    # README.md was not touched by the removal above, so this request is
    # presumably answered from the cache.
    readme_path = model_file_download(
        model_id=self.model_id, file_path='README.md')
    assert os.path.exists(readme_path)
    # The removed file has to exist again after it is requested.
    restored_path = model_file_download(
        model_id=self.model_id, file_path=download_model_file_name)
    assert os.path.exists(restored_path)


if __name__ == '__main__':
unittest.main()

+ 85
- 0
tests/hub/test_hub_private_files.py View File

@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
import unittest
import uuid

from requests.exceptions import HTTPError

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import GitError
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile

USER_NAME = 'maasadmin'
PASSWORD = '12345678'
USER_NAME2 = 'sdkdev'

model_chinese_name = '达摩卡通化模型'
model_org = 'unittest'


class HubPrivateFileDownloadTest(unittest.TestCase):
    """Download tests that run against a freshly created private model.

    Every test gets its own private model created by USER_NAME in setUp and
    deleted again in tearDown.
    """

    def setUp(self):
        """Log in as USER_NAME and create a uniquely named private model."""
        self.old_cwd = os.getcwd()
        self.api = HubApi()
        # note this is temporary before official account management is ready
        self.token, _ = self.api.login(USER_NAME, PASSWORD)
        self.model_name = uuid.uuid4().hex
        self.model_id = '%s/%s' % (model_org, self.model_name)
        self.api.create_model(
            model_id=self.model_id,
            visibility=ModelVisibility.PRIVATE,  # 1-private, 5-public
            license=Licenses.APACHE_V2,
            chinese_name=model_chinese_name,
        )

    def tearDown(self):
        """Restore the working directory and delete the model as USER_NAME."""
        os.chdir(self.old_cwd)
        # Some tests switch to USER_NAME2. Logging in here, rather than at
        # the end of those tests, means the model is still deleted by the
        # account that created it when such a test fails half-way.
        self.api.login(USER_NAME, PASSWORD)
        self.api.delete_model(model_id=self.model_id)

    def test_snapshot_download_private_model(self):
        """The creating user can snapshot-download the private model."""
        snapshot_path = snapshot_download(self.model_id)
        assert os.path.exists(os.path.join(snapshot_path, ModelFile.README))

    def test_snapshot_download_private_model_no_permission(self):
        """snapshot_download raises HTTPError after logging in as USER_NAME2."""
        self.token, _ = self.api.login(USER_NAME2, PASSWORD)
        with self.assertRaises(HTTPError):
            snapshot_download(self.model_id)

    def test_download_file_private_model(self):
        """The creating user can download a single file of the private model."""
        file_path = model_file_download(self.model_id, ModelFile.README)
        assert os.path.exists(file_path)

    def test_download_file_private_model_no_permission(self):
        """model_file_download raises HTTPError after logging in as USER_NAME2."""
        self.token, _ = self.api.login(USER_NAME2, PASSWORD)
        with self.assertRaises(HTTPError):
            model_file_download(self.model_id, ModelFile.README)

    def test_snapshot_download_local_only(self):
        """local_files_only raises ValueError first and works after a normal download."""
        with self.assertRaises(ValueError):
            snapshot_download(self.model_id, local_files_only=True)
        snapshot_path = snapshot_download(self.model_id)
        assert os.path.exists(os.path.join(snapshot_path, ModelFile.README))
        snapshot_path = snapshot_download(self.model_id, local_files_only=True)
        assert os.path.exists(snapshot_path)

    def test_file_download_local_only(self):
        """local_files_only raises ValueError first and works after a normal download."""
        with self.assertRaises(ValueError):
            model_file_download(
                self.model_id, ModelFile.README, local_files_only=True)
        file_path = model_file_download(self.model_id, ModelFile.README)
        assert os.path.exists(file_path)
        file_path = model_file_download(
            self.model_id, ModelFile.README, local_files_only=True)
        assert os.path.exists(file_path)


if __name__ == '__main__':
unittest.main()

+ 4
- 5
tests/hub/test_hub_private_repository.py View File

@@ -5,6 +5,7 @@ import unittest
import uuid

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import GitError
from modelscope.hub.repository import Repository

@@ -16,9 +17,6 @@ model_chinese_name = '达摩卡通化模型'
model_org = 'unittest'
DEFAULT_GIT_PATH = 'git'

sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx'
download_model_file_name = 'mnist-12.onnx'


class HubPrivateRepositoryTest(unittest.TestCase):

@@ -31,9 +29,10 @@ class HubPrivateRepositoryTest(unittest.TestCase):
self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model(
model_id=self.model_id,
visibility=ModelVisibility.PRIVATE, # 1-private, 5-public
license=Licenses.APACHE_V2,
chinese_name=model_chinese_name,
visibility=1, # 1-private, 5-public
license='apache-2.0')
)

def tearDown(self):
self.api.login(USER_NAME, PASSWORD)


+ 5
- 19
tests/hub/test_hub_repository.py View File

@@ -2,7 +2,6 @@
import os
import shutil
import tempfile
import time
import unittest
import uuid
from os.path import expanduser
@@ -10,6 +9,7 @@ from os.path import expanduser
from requests import delete

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import NotExistError
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
@@ -55,9 +55,10 @@ class HubRepositoryTest(unittest.TestCase):
self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model(
model_id=self.model_id,
visibility=ModelVisibility.PUBLIC, # 1-private, 5-public
license=Licenses.APACHE_V2,
chinese_name=model_chinese_name,
visibility=5, # 1-private, 5-public
license='apache-2.0')
)
temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name)

@@ -81,27 +82,12 @@ class HubRepositoryTest(unittest.TestCase):
os.chdir(self.model_dir)
os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py'))
os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py'))
repo.push('test', all_files=True)
repo.push('test')
add1 = model_file_download(self.model_id, 'add1.py')
assert os.path.exists(add1)
add2 = model_file_download(self.model_id, 'add2.py')
assert os.path.exists(add2)

def test_push_files(self):
repo = Repository(self.model_dir, clone_from=self.model_id)
assert os.path.exists(os.path.join(self.model_dir, 'README.md'))
os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py'))
os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py'))
os.system("echo '333'>%s" % os.path.join(self.model_dir, 'add3.py'))
repo.push('test', files=['add1.py', 'add2.py'], all_files=False)
add1 = model_file_download(self.model_id, 'add1.py')
assert os.path.exists(add1)
add2 = model_file_download(self.model_id, 'add2.py')
assert os.path.exists(add2)
with self.assertRaises(NotExistError) as cm:
model_file_download(self.model_id, 'add3.py')
print(cm.exception)


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save