diff --git a/modelscope/hub/__init__.py b/modelscope/hub/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py
new file mode 100644
index 00000000..104eafbd
--- /dev/null
+++ b/modelscope/hub/api.py
@@ -0,0 +1,265 @@
+import imp
+import os
+import pickle
+import subprocess
+from http.cookiejar import CookieJar
+from os.path import expanduser
+from typing import List, Optional, Tuple, Union
+
+import requests
+
+from modelscope.utils.logger import get_logger
+from .constants import LOGGER_NAME
+from .errors import NotExistError, is_ok, raise_on_error
+from .utils.utils import get_endpoint, model_id_to_group_owner_name
+
+logger = get_logger()
+
+
+class HubApi:
+
+ def __init__(self, endpoint=None):
+ self.endpoint = endpoint if endpoint is not None else get_endpoint()
+
+ def login(
+ self,
+ user_name: str,
+ password: str,
+ ) -> tuple():
+ """
+ Login with username and password
+
+ Args:
+ username(`str`): user name on modelscope
+ password(`str`): password
+
+ Returns:
+ cookies: to authenticate yourself to ModelScope open-api
+ gitlab token: to access private repos
+
+
+ You only have to login once within 30 days.
+
+
+ TODO: handle cookies expire
+
+ """
+ path = f'{self.endpoint}/api/v1/login'
+ r = requests.post(
+ path, json={
+ 'username': user_name,
+ 'password': password
+ })
+ r.raise_for_status()
+ d = r.json()
+ raise_on_error(d)
+
+ token = d['Data']['AccessToken']
+ cookies = r.cookies
+
+ # save token and cookie
+ ModelScopeConfig.save_token(token)
+ ModelScopeConfig.save_cookies(cookies)
+ ModelScopeConfig.write_to_git_credential(user_name, password)
+
+ return d['Data']['AccessToken'], cookies
+
+ def create_model(self, model_id: str, chinese_name: str, visibility: int,
+ license: str) -> str:
+ """
+ Create model repo at ModelScopeHub
+
+ Args:
+ model_id:(`str`): The model id
+ chinese_name(`str`): chinese name of the model
+ visibility(`int`): visibility of the model(1-private, 3-internal, 5-public)
+ license(`str`): license of the model, candidates can be found at: TBA
+
+ Returns:
+ name of the model created
+
+
+ model_id = {owner}/{name}
+
+ """
+ cookies = ModelScopeConfig.get_cookies()
+ if cookies is None:
+ raise ValueError('Token does not exist, please login first.')
+
+ path = f'{self.endpoint}/api/v1/models'
+ owner_or_group, name = model_id_to_group_owner_name(model_id)
+ r = requests.post(
+ path,
+ json={
+ 'Path': owner_or_group,
+ 'Name': name,
+ 'ChineseName': chinese_name,
+ 'Visibility': visibility,
+ 'License': license
+ },
+ cookies=cookies)
+ r.raise_for_status()
+ raise_on_error(r.json())
+ d = r.json()
+ return d['Data']['Name']
+
+ def delete_model(self, model_id):
+ """_summary_
+
+ Args:
+ model_id (str): The model id.
+
+ model_id = {owner}/{name}
+
+ """
+ cookies = ModelScopeConfig.get_cookies()
+ path = f'{self.endpoint}/api/v1/models/{model_id}'
+
+ r = requests.delete(path, cookies=cookies)
+ r.raise_for_status()
+ raise_on_error(r.json())
+
+ def get_model_url(self, model_id):
+ return f'{self.endpoint}/api/v1/models/{model_id}.git'
+
+ def get_model(
+ self,
+ model_id: str,
+ revision: str = 'master',
+ ) -> str:
+ """
+ Get model information at modelscope_hub
+
+ Args:
+ model_id(`str`): The model id.
+ revision(`str`): revision of model
+ Returns:
+ The model details information.
+ Raises:
+ NotExistError: If the model is not exist, will throw NotExistError
+
+ model_id = {owner}/{name}
+
+ """
+ cookies = ModelScopeConfig.get_cookies()
+ owner_or_group, name = model_id_to_group_owner_name(model_id)
+ path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?{revision}'
+
+ r = requests.get(path, cookies=cookies)
+ if r.status_code == 200:
+ if is_ok(r.json()):
+ return r.json()['Data']
+ else:
+ raise NotExistError(r.json()['Message'])
+ else:
+ r.raise_for_status()
+
+ def get_model_branches_and_tags(
+ self,
+ model_id: str,
+ ) -> Tuple[List[str], List[str]]:
+ cookies = ModelScopeConfig.get_cookies()
+
+ path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
+ r = requests.get(path, cookies=cookies)
+ r.raise_for_status()
+ d = r.json()
+ raise_on_error(d)
+ info = d['Data']
+ branches = [x['Revision'] for x in info['RevisionMap']['Branches']
+ ] if info['RevisionMap']['Branches'] else []
+ tags = [x['Revision'] for x in info['RevisionMap']['Tags']
+ ] if info['RevisionMap']['Tags'] else []
+ return branches, tags
+
+ def get_model_files(
+ self,
+ model_id: str,
+ revision: Optional[str] = 'master',
+ root: Optional[str] = None,
+ recursive: Optional[str] = False,
+ use_cookies: Union[bool, CookieJar] = False) -> List[dict]:
+
+ cookies = None
+ if isinstance(use_cookies, CookieJar):
+ cookies = use_cookies
+ elif use_cookies:
+ cookies = ModelScopeConfig.get_cookies()
+ if cookies is None:
+ raise ValueError('Token does not exist, please login first.')
+
+ path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}'
+ if root is not None:
+ path = path + f'&Root={root}'
+
+ r = requests.get(path, cookies=cookies)
+
+ r.raise_for_status()
+ d = r.json()
+ raise_on_error(d)
+
+ files = []
+ for file in d['Data']['Files']:
+ if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes':
+ continue
+
+ files.append(file)
+ return files
+
+
+class ModelScopeConfig:
+ path_credential = expanduser('~/.modelscope/credentials')
+ os.makedirs(path_credential, exist_ok=True)
+
+ @classmethod
+ def save_cookies(cls, cookies: CookieJar):
+ with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f:
+ pickle.dump(cookies, f)
+
+ @classmethod
+ def get_cookies(cls):
+ try:
+ with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f:
+ return pickle.load(f)
+ except FileNotFoundError:
+ logger.warn("Auth token does not exist, you'll get authentication \
+ error when downloading private model files. Please login first"
+ )
+
+ @classmethod
+ def save_token(cls, token: str):
+ with open(os.path.join(cls.path_credential, 'token'), 'w+') as f:
+ f.write(token)
+
+ @classmethod
+ def get_token(cls) -> Optional[str]:
+ """
+ Get token or None if not existent.
+
+ Returns:
+ `str` or `None`: The token, `None` if it doesn't exist.
+
+ """
+ token = None
+ try:
+ with open(os.path.join(cls.path_credential, 'token'), 'r') as f:
+ token = f.read()
+ except FileNotFoundError:
+ pass
+ return token
+
+ @staticmethod
+ def write_to_git_credential(username: str, password: str):
+ with subprocess.Popen(
+ 'git credential-store store'.split(),
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ ) as process:
+ input_username = f'username={username.lower()}'
+ input_password = f'password={password}'
+
+ process.stdin.write(
+ f'url={get_endpoint()}\n{input_username}\n{input_password}\n\n'
+ .encode('utf-8'))
+ process.stdin.flush()
diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py
new file mode 100644
index 00000000..a38f9afb
--- /dev/null
+++ b/modelscope/hub/constants.py
@@ -0,0 +1,8 @@
+MODELSCOPE_URL_SCHEME = 'http://'
+DEFAULT_MODELSCOPE_DOMAIN = '101.201.119.157:32330'
+DEFAULT_MODELSCOPE_GITLAB_DOMAIN = '101.201.119.157:31102'
+
+DEFAULT_MODELSCOPE_GROUP = 'damo'
+MODEL_ID_SEPARATOR = '/'
+
+LOGGER_NAME = 'ModelScopeHub'
diff --git a/modelscope/hub/errors.py b/modelscope/hub/errors.py
new file mode 100644
index 00000000..13ea709f
--- /dev/null
+++ b/modelscope/hub/errors.py
@@ -0,0 +1,30 @@
+class NotExistError(Exception):
+ pass
+
+
+class RequestError(Exception):
+ pass
+
+
+def is_ok(rsp):
+ """ Check the request is ok
+
+ Args:
+ rsp (_type_): The request response body
+ Failed: {'Code': 10010101004, 'Message': 'get model info failed, err: unauthorized permission',
+ 'RequestId': '', 'Success': False}
+ Success: {'Code': 200, 'Data': {}, 'Message': 'success', 'RequestId': '', 'Success': True}
+ """
+ return rsp['Code'] == 200 and rsp['Success']
+
+
+def raise_on_error(rsp):
+ """If response error, raise exception
+
+ Args:
+ rsp (_type_): The server response
+ """
+ if rsp['Code'] == 200 and rsp['Success']:
+ return True
+ else:
+ raise RequestError(rsp['Message'])
diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py
new file mode 100644
index 00000000..e5c64f1c
--- /dev/null
+++ b/modelscope/hub/file_download.py
@@ -0,0 +1,254 @@
+import copy
+import fnmatch
+import logging
+import os
+import sys
+import tempfile
+import time
+from functools import partial
+from hashlib import sha256
+from pathlib import Path
+from typing import BinaryIO, Dict, Optional, Union
+from uuid import uuid4
+
+import json
+import requests
+from filelock import FileLock
+from requests.exceptions import HTTPError
+from tqdm import tqdm
+
+from modelscope import __version__
+from modelscope.utils.logger import get_logger
+from .api import HubApi, ModelScopeConfig
+from .constants import (DEFAULT_MODELSCOPE_GROUP, LOGGER_NAME,
+ MODEL_ID_SEPARATOR)
+from .errors import NotExistError, RequestError, raise_on_error
+from .utils.caching import ModelFileSystemCache
+from .utils.utils import (get_cache_dir, get_endpoint,
+ model_id_to_group_owner_name)
+
+SESSION_ID = uuid4().hex
+logger = get_logger()
+
+
+def model_file_download(
+ model_id: str,
+ file_path: str,
+ revision: Optional[str] = 'master',
+ cache_dir: Optional[str] = None,
+ user_agent: Union[Dict, str, None] = None,
+ local_files_only: Optional[bool] = False,
+) -> Optional[str]: # pragma: no cover
+ """
+ Download from a given URL and cache it if it's not already present in the
+ local cache.
+
+ Given a URL, this function looks for the corresponding file in the local
+ cache. If it's not there, download it. Then return the path to the cached
+ file.
+
+ Args:
+ model_id (`str`):
+ The model to whom the file to be downloaded belongs.
+ file_path(`str`):
+ Path of the file to be downloaded, relative to the root of model repo
+ revision(`str`, *optional*):
+ revision of the model file to be downloaded.
+ Can be any of a branch, tag or commit hash, default to `master`
+ cache_dir (`str`, `Path`, *optional*):
+ Path to the folder where cached files are stored.
+ user_agent (`dict`, `str`, *optional*):
+ The user-agent info in the form of a dictionary or a string.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, avoid downloading the file and return the path to the
+ local cached file if it exists.
+ if `False`, download the file anyway even it exists
+
+ Returns:
+ Local path (string) of file or if networking is off, last version of
+ file cached on disk.
+
+
+
+ Raises the following errors:
+
+ - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+ if `use_auth_token=True` and the token cannot be found.
+ - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
+ if ETag cannot be determined.
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+ if some parameter value is invalid
+
+
+ """
+ if cache_dir is None:
+ cache_dir = get_cache_dir()
+ if isinstance(cache_dir, Path):
+ cache_dir = str(cache_dir)
+
+ group_or_owner, name = model_id_to_group_owner_name(model_id)
+
+ cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
+
+ # if local_files_only is `True` and the file already exists in cached_path
+ # return the cached path
+ if local_files_only:
+ cached_file_path = cache.get_file_by_path(file_path)
+ if cached_file_path is not None:
+ logger.warning(
+ "File exists in local cache, but we're not sure it's up to date"
+ )
+ return cached_file_path
+ else:
+ raise ValueError(
+ 'Cannot find the requested files in the cached path and outgoing'
+ ' traffic has been disabled. To enable model look-ups and downloads'
+ " online, set 'local_files_only' to False.")
+
+ _api = HubApi()
+ headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
+ branches, tags = _api.get_model_branches_and_tags(model_id)
+ file_to_download_info = None
+ is_commit_id = False
+ if revision in branches or revision in tags: # The revision is version or tag,
+ # we need to confirm the version is up to date
+ # we need to get the file list to check if the lateast version is cached, if so return, otherwise download
+ model_files = _api.get_model_files(
+ model_id=model_id,
+ revision=revision,
+ recursive=True,
+ )
+
+ for model_file in model_files:
+ if model_file['Type'] == 'tree':
+ continue
+
+ if model_file['Path'] == file_path:
+ model_file['Branch'] = revision
+ if cache.exists(model_file):
+ return cache.get_file_by_info(model_file)
+ else:
+ file_to_download_info = model_file
+
+ if file_to_download_info is None:
+ raise NotExistError('The file path: %s not exist in: %s' %
+ (file_path, model_id))
+ else: # the revision is commit id.
+ cached_file_path = cache.get_file_by_path_and_commit_id(
+ file_path, revision)
+ if cached_file_path is not None:
+ logger.info('The specified file is in cache, skip downloading!')
+ return cached_file_path # the file is in cache.
+ is_commit_id = True
+ # we need to download again
+ # TODO: skip using JWT for authorization, use cookie instead
+ cookies = ModelScopeConfig.get_cookies()
+ url_to_download = get_file_download_url(model_id, file_path, revision)
+ file_to_download_info = {
+ 'Path': file_path,
+ 'Revision':
+ revision if is_commit_id else file_to_download_info['Revision']
+ }
+ # Prevent parallel downloads of the same file with a lock.
+ lock_path = cache.get_root_location() + '.lock'
+
+ with FileLock(lock_path):
+ temp_file_name = next(tempfile._get_candidate_names())
+ http_get_file(
+ url_to_download,
+ cache_dir,
+ temp_file_name,
+ headers=headers,
+ cookies=None if cookies is None else cookies.get_dict())
+ return cache.put_file(file_to_download_info,
+ os.path.join(cache_dir, temp_file_name))
+
+
+def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
+ """Formats a user-agent string with basic info about a request.
+
+ Args:
+ user_agent (`str`, `dict`, *optional*):
+ The user agent info in the form of a dictionary or a single string.
+
+ Returns:
+ The formatted user-agent string.
+ """
+ ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}'
+
+ if isinstance(user_agent, dict):
+ ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items())
+ elif isinstance(user_agent, str):
+ ua = user_agent
+ return ua
+
+
+def get_file_download_url(model_id: str, file_path: str, revision: str):
+ """
+ Format file download url according to `model_id`, `revision` and `file_path`.
+ e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
+ the resulted download url is: https://maas.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
+ """
+ download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
+ return download_url_template.format(
+ endpoint=get_endpoint(),
+ model_id=model_id,
+ revision=revision,
+ file_path=file_path,
+ )
+
+
+def http_get_file(
+ url: str,
+ local_dir: str,
+ file_name: str,
+ cookies: Dict[str, str],
+ headers: Optional[Dict[str, str]] = None,
+):
+ """
+ Download remote file. Do not gobble up errors.
+ This method is only used by snapshot_download, since the behavior is quite different with single file download
+ TODO: consolidate with http_get_file() to avoild duplicate code
+
+ Args:
+ url(`str`):
+ actual download url of the file
+ local_dir(`str`):
+ local directory where the downloaded file stores
+ file_name(`str`):
+ name of the file stored in `local_dir`
+ cookies(`Dict[str, str]`):
+ cookies used to authentication the user, which is used for downloading private repos
+ headers(`Optional[Dict[str, str]] = None`):
+ http headers to carry necessary info when requesting the remote file
+
+ """
+ temp_file_manager = partial(
+ tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
+
+ with temp_file_manager() as temp_file:
+ logger.info('downloading %s to %s', url, temp_file.name)
+ headers = copy.deepcopy(headers)
+
+ r = requests.get(url, stream=True, headers=headers, cookies=cookies)
+ r.raise_for_status()
+
+ content_length = r.headers.get('Content-Length')
+ total = int(content_length) if content_length is not None else None
+
+ progress = tqdm(
+ unit='B',
+ unit_scale=True,
+ unit_divisor=1024,
+ total=total,
+ initial=0,
+ desc='Downloading',
+ )
+ for chunk in r.iter_content(chunk_size=1024):
+ if chunk: # filter out keep-alive new chunks
+ progress.update(len(chunk))
+ temp_file.write(chunk)
+ progress.close()
+
+ logger.info('storing %s in cache at %s', url, local_dir)
+ os.replace(temp_file.name, os.path.join(local_dir, file_name))
diff --git a/modelscope/hub/git.py b/modelscope/hub/git.py
new file mode 100644
index 00000000..5f079105
--- /dev/null
+++ b/modelscope/hub/git.py
@@ -0,0 +1,82 @@
+from threading import local
+from tkinter.messagebox import NO
+from typing import Union
+
+from modelscope.utils.logger import get_logger
+from .constants import LOGGER_NAME
+from .utils._subprocess import run_subprocess
+
+logger = get_logger
+
+
+def git_clone(
+ local_dir: str,
+ repo_url: str,
+):
+ # TODO: use "git clone" or "git lfs clone" according to git version
+ # TODO: print stderr when subprocess fails
+ run_subprocess(
+ f'git clone {repo_url}'.split(),
+ local_dir,
+ True,
+ )
+
+
+def git_checkout(
+ local_dir: str,
+ revsion: str,
+):
+ run_subprocess(f'git checkout {revsion}'.split(), local_dir)
+
+
+def git_add(local_dir: str, ):
+ run_subprocess(
+ 'git add .'.split(),
+ local_dir,
+ True,
+ )
+
+
+def git_commit(local_dir: str, commit_message: str):
+ run_subprocess(
+ 'git commit -v -m'.split() + [commit_message],
+ local_dir,
+ True,
+ )
+
+
+def git_push(local_dir: str, branch: str):
+ # check current branch
+ cur_branch = git_current_branch(local_dir)
+ if cur_branch != branch:
+ logger.error(
+ "You're trying to push to a different branch, please double check")
+ return
+
+ run_subprocess(
+ f'git push origin {branch}'.split(),
+ local_dir,
+ True,
+ )
+
+
+def git_current_branch(local_dir: str) -> Union[str, None]:
+ """
+ Get current branch name
+
+ Args:
+ local_dir(`str`): local model repo directory
+
+ Returns
+ branch name you're currently on
+ """
+ try:
+ process = run_subprocess(
+ 'git rev-parse --abbrev-ref HEAD'.split(),
+ local_dir,
+ True,
+ )
+
+ return str(process.stdout).strip()
+ except Exception as e:
+ raise e
diff --git a/modelscope/hub/repository.py b/modelscope/hub/repository.py
new file mode 100644
index 00000000..6367f903
--- /dev/null
+++ b/modelscope/hub/repository.py
@@ -0,0 +1,173 @@
+import os
+import subprocess
+from pathlib import Path
+from typing import Optional, Union
+
+from modelscope.utils.logger import get_logger
+from .api import ModelScopeConfig
+from .constants import MODELSCOPE_URL_SCHEME
+from .git import git_add, git_checkout, git_clone, git_commit, git_push
+from .utils._subprocess import run_subprocess
+from .utils.utils import get_gitlab_domain
+
+logger = get_logger()
+
+
+class Repository:
+
+ def __init__(
+ self,
+ local_dir: str,
+ clone_from: Optional[str] = None,
+ auth_token: Optional[str] = None,
+ private: Optional[bool] = False,
+ revision: Optional[str] = 'master',
+ ):
+ """
+ Instantiate a Repository object by cloning the remote ModelScopeHub repo
+ Args:
+ local_dir(`str`):
+ local directory to store the model files
+ clone_from(`Optional[str] = None`):
+ model id in ModelScope-hub from which git clone
+ You should ignore this parameter when `local_dir` is already a git repo
+ auth_token(`Optional[str]`):
+ token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter
+ as the token is already saved when you login the first time
+ private(`Optional[bool]`):
+ whether the model is private, default to False
+ revision(`Optional[str]`):
+ revision of the model you want to clone from. Can be any of a branch, tag or commit hash
+ """
+ logger.info('Instantiating Repository object...')
+
+ # Create local directory if not exist
+ os.makedirs(local_dir, exist_ok=True)
+ self.local_dir = os.path.join(os.getcwd(), local_dir)
+
+ self.private = private
+
+ # Check git and git-lfs installation
+ self.check_git_versions()
+
+ # Retrieve auth token
+ if not private and isinstance(auth_token, str):
+ logger.warning(
+ 'cloning a public repo with a token, which will be ignored')
+ self.token = None
+ else:
+ if isinstance(auth_token, str):
+ self.token = auth_token
+ else:
+ self.token = ModelScopeConfig.get_token()
+
+ if self.token is None:
+ raise EnvironmentError(
+ 'Token does not exist, the clone will fail for private repo.'
+ 'Please login first.')
+
+ # git clone
+ if clone_from is not None:
+ self.model_id = clone_from
+ logger.info('cloning model repo to %s ...', self.local_dir)
+ git_clone(self.local_dir, self.get_repo_url())
+ else:
+ if is_git_repo(self.local_dir):
+ logger.debug('[Repository] is a valid git repo')
+ else:
+ raise ValueError(
+ 'If not specifying `clone_from`, you need to pass Repository a'
+ ' valid git clone.')
+
+ # git checkout
+ if isinstance(revision, str) and revision != 'master':
+ git_checkout(revision)
+
+ def push_to_hub(self,
+ commit_message: str,
+ revision: Optional[str] = 'master'):
+ """
+ Push changes changes to hub
+
+ Args:
+ commit_message(`str`):
+ commit message describing the changes, it's mandatory
+ revision(`Optional[str]`):
+ remote branch you want to push to, default to `master`
+
+
+ The function complains when local and remote branch are different, please be careful
+
+
+ """
+ git_add(self.local_dir)
+ git_commit(self.local_dir, commit_message)
+
+ logger.info('Pushing changes to repo...')
+ git_push(self.local_dir, revision)
+
+ # TODO: if git push fails, how to retry?
+
+ def check_git_versions(self):
+ """
+ Checks that `git` and `git-lfs` can be run.
+
+ Raises:
+ `EnvironmentError`: if `git` or `git-lfs` are not installed.
+ """
+ try:
+ git_version = run_subprocess('git --version'.split(),
+ self.local_dir).stdout.strip()
+ except FileNotFoundError:
+ raise EnvironmentError(
+ 'Looks like you do not have git installed, please install.')
+
+ try:
+ lfs_version = run_subprocess('git-lfs --version'.split(),
+ self.local_dir).stdout.strip()
+ except FileNotFoundError:
+ raise EnvironmentError(
+ 'Looks like you do not have git-lfs installed, please install.'
+ ' You can install from https://git-lfs.github.com/.'
+ ' Then run `git lfs install` (you only have to do this once).')
+ logger.info(git_version + '\n' + lfs_version)
+
+ def get_repo_url(self) -> str:
+ """
+ Get repo url to clone, according whether the repo is private or not
+ """
+ url = None
+
+ if self.private:
+ url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}'
+ else:
+ url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}'
+
+ if not url:
+ raise ValueError(
+ 'Empty repo url, please check clone_from parameter')
+
+ logger.debug('url to clone: %s', str(url))
+
+ return url
+
+
+def is_git_repo(folder: Union[str, Path]) -> bool:
+ """
+ Check if the folder is the root or part of a git repository
+
+ Args:
+ folder (`str`):
+ The folder in which to run the command.
+
+ Returns:
+ `bool`: `True` if the repository is part of a repository, `False`
+ otherwise.
+ """
+ folder_exists = os.path.exists(os.path.join(folder, '.git'))
+ git_branch = subprocess.run(
+ 'git branch'.split(),
+ cwd=folder,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ return folder_exists and git_branch.returncode == 0
diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py
new file mode 100644
index 00000000..90d850f4
--- /dev/null
+++ b/modelscope/hub/snapshot_download.py
@@ -0,0 +1,125 @@
+import os
+import tempfile
+from glob import glob
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+from modelscope.utils.logger import get_logger
+from .api import HubApi, ModelScopeConfig
+from .constants import DEFAULT_MODELSCOPE_GROUP, MODEL_ID_SEPARATOR
+from .errors import NotExistError, RequestError, raise_on_error
+from .file_download import (get_file_download_url, http_get_file,
+ http_user_agent)
+from .utils.caching import ModelFileSystemCache
+from .utils.utils import get_cache_dir, model_id_to_group_owner_name
+
+logger = get_logger()
+
+
+def snapshot_download(model_id: str,
+ revision: Optional[str] = 'master',
+ cache_dir: Union[str, Path, None] = None,
+ user_agent: Optional[Union[Dict, str]] = None,
+ local_files_only: Optional[bool] = False,
+ private: Optional[bool] = False) -> str:
+ """Download all files of a repo.
+ Downloads a whole snapshot of a repo's files at the specified revision. This
+ is useful when you want all files from a repo, because you don't know which
+ ones you will need a priori. All files are nested inside a folder in order
+ to keep their actual filename relative to that folder.
+
+ An alternative would be to just clone a repo but this would require that the
+ user always has git and git-lfs installed, and properly configured.
+ Args:
+ model_id (`str`):
+ A user or an organization name and a repo name separated by a `/`.
+ revision (`str`, *optional*):
+ An optional Git revision id which can be a branch name, a tag, or a
+ commit hash. NOTE: currently only branch and tag name is supported
+ cache_dir (`str`, `Path`, *optional*):
+ Path to the folder where cached files are stored.
+ user_agent (`str`, `dict`, *optional*):
+ The user-agent info in the form of a dictionary or a string.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, avoid downloading the file and return the path to the
+ local cached file if it exists.
+ Returns:
+ Local folder path (string) of repo snapshot
+
+
+ Raises the following errors:
+ - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+ if `use_auth_token=True` and the token cannot be found.
+ - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
+ ETag cannot be determined.
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+ if some parameter value is invalid
+
+ """
+
+ if cache_dir is None:
+ cache_dir = get_cache_dir()
+ if isinstance(cache_dir, Path):
+ cache_dir = str(cache_dir)
+
+ group_or_owner, name = model_id_to_group_owner_name(model_id)
+
+ cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
+ if local_files_only:
+ if len(cache.cached_files) == 0:
+ raise ValueError(
+ 'Cannot find the requested files in the cached path and outgoing'
+ ' traffic has been disabled. To enable model look-ups and downloads'
+ " online, set 'local_files_only' to False.")
+ logger.warn('We can not confirm the cached file is for revision: %s'
+ % revision)
+ return cache.get_root_location(
+ ) # we can not confirm the cached file is for snapshot 'revision'
+ else:
+ # make headers
+ headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
+ _api = HubApi()
+ # get file list from model repo
+ branches, tags = _api.get_model_branches_and_tags(model_id)
+ if revision not in branches and revision not in tags:
+ raise NotExistError('The specified branch or tag : %s not exist!'
+ % revision)
+
+ model_files = _api.get_model_files(
+ model_id=model_id,
+ revision=revision,
+ recursive=True,
+ use_cookies=private)
+
+ cookies = None
+ if private:
+ cookies = ModelScopeConfig.get_cookies()
+
+ for model_file in model_files:
+ if model_file['Type'] == 'tree':
+ continue
+ # check model_file is exist in cache, if exist, skip download, otherwise download
+ if cache.exists(model_file):
+ logger.info(
+ 'The specified file is in cache, skip downloading!')
+ continue
+
+ # get download url
+ url = get_file_download_url(
+ model_id=model_id,
+ file_path=model_file['Path'],
+ revision=revision)
+
+ # First download to /tmp
+ http_get_file(
+ url=url,
+ local_dir=tempfile.gettempdir(),
+ file_name=model_file['Name'],
+ headers=headers,
+ cookies=None if cookies is None else cookies.get_dict())
+ # put file to cache
+ cache.put_file(
+ model_file,
+ os.path.join(tempfile.gettempdir(), model_file['Name']))
+
+ return os.path.join(cache.get_root_location())
diff --git a/modelscope/hub/utils/__init__.py b/modelscope/hub/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/hub/utils/_subprocess.py b/modelscope/hub/utils/_subprocess.py
new file mode 100644
index 00000000..77e9fc48
--- /dev/null
+++ b/modelscope/hub/utils/_subprocess.py
@@ -0,0 +1,40 @@
+import subprocess
+from typing import List
+
+
+def run_subprocess(command: List[str],
+ folder: str,
+ check=True,
+ **kwargs) -> subprocess.CompletedProcess:
+ """
+ Method to run subprocesses. Calling this will capture the `stderr` and `stdout`,
+ please call `subprocess.run` manually in case you would like for them not to
+ be captured.
+
+ Args:
+ command (`List[str]`):
+ The command to execute as a list of strings.
+ folder (`str`):
+ The folder in which to run the command.
+ check (`bool`, *optional*, defaults to `True`):
+ Setting `check` to `True` will raise a `subprocess.CalledProcessError`
+ when the subprocess has a non-zero exit code.
+ kwargs (`Dict[str]`):
+ Keyword arguments to be passed to the `subprocess.run` underlying command.
+
+ Returns:
+ `subprocess.CompletedProcess`: The completed process.
+ """
+ if isinstance(command, str):
+ raise ValueError(
+ '`run_subprocess` should be called with a list of strings.')
+
+ return subprocess.run(
+ command,
+ stderr=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ check=check,
+ encoding='utf-8',
+ cwd=folder,
+ **kwargs,
+ )
diff --git a/modelscope/hub/utils/caching.py b/modelscope/hub/utils/caching.py
new file mode 100644
index 00000000..ac258385
--- /dev/null
+++ b/modelscope/hub/utils/caching.py
@@ -0,0 +1,294 @@
+import hashlib
+import logging
+import os
+import pickle
+import tempfile
+import time
+from shutil import move, rmtree
+
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class FileSystemCache(object):
+ KEY_FILE_NAME = '.msc'
+ """Local file cache.
+ """
+
+ def __init__(
+ self,
+ cache_root_location: str,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ cache_location: str
+ The root location to store files.
+ """
+ os.makedirs(cache_root_location, exist_ok=True)
+ self.cache_root_location = cache_root_location
+ self.load_cache()
+
+ def get_root_location(self):
+ return self.cache_root_location
+
+ def load_cache(self):
+ """Read set of stored blocks from file
+ Args:
+ owner(`str`): individual or group username at modelscope, can be empty for official models
+ name(`str`): name of the model
+ Returns:
+ The model details information.
+ Raises:
+ NotExistError: If the model is not exist, will throw NotExistError
+ TODO: Error based error code.
+
+ model_id = {owner}/{name}
+
+ """
+ self.cached_files = []
+ cache_keys_file_path = os.path.join(self.cache_root_location,
+ FileSystemCache.KEY_FILE_NAME)
+ if os.path.exists(cache_keys_file_path):
+ with open(cache_keys_file_path, 'rb') as f:
+ self.cached_files = pickle.load(f)
+
+ def save_cached_files(self):
+ """Save cache metadata."""
+ # save new meta to tmp and move to KEY_FILE_NAME
+ cache_keys_file_path = os.path.join(self.cache_root_location,
+ FileSystemCache.KEY_FILE_NAME)
+ # TODO: Sync file write
+ fd, fn = tempfile.mkstemp()
+ with open(fd, 'wb') as f:
+ pickle.dump(self.cached_files, f)
+ move(fn, cache_keys_file_path)
+
+ def get_file(self, key):
+ """Check the key is in the cache, if exist, return the file, otherwise return None.
+ Args:
+ key(`str`): The cache key.
+ Returns:
+ If file exist, return the cached file location, otherwise None.
+ Raises:
+ None
+
+ model_id = {owner}/{name}
+
+ """
+ pass
+
+ def put_file(self, key, location):
+ """Put file to the cache,
+ Args:
+ key(`str`): The cache key
+ location(`str`): Location of the file, we will move the file to cache.
+ Returns:
+ The cached file path of the file.
+ Raises:
+ None
+
+ model_id = {owner}/{name}
+
+ """
+ pass
+
+ def remove_key(self, key):
+ """Remove cache key in index, The file is removed manually
+
+ Args:
+ key (dict): The cache key.
+ """
+ self.cached_files.remove(key)
+ self.save_cached_files()
+
+ def exists(self, key):
+ for cache_file in self.cached_files:
+ if cache_file == key:
+ return True
+
+ return False
+
+ def clear_cache(self):
+ """Remove all files and metadat from the cache
+
+ In the case of multiple cache locations, this clears only the last one,
+ which is assumed to be the read/write one.
+ """
+ rmtree(self.cache_root_location)
+ self.load_cache()
+
+ def hash_name(self, key):
+ return hashlib.sha256(key.encode()).hexdigest()
+
+
+class ModelFileSystemCache(FileSystemCache):
+ """Local cache file layout
+ cache_root/owner/model_name/|individual cached files
+ |.mk: file, The cache index file
+ Save only one version for each file.
+ """
+
+ def __init__(self, cache_root, owner, name):
+ """Put file to the cache
+ Args:
+ cache_root(`str`): The modelscope local cache root(default: ~/.modelscope/cache/models/)
+ owner(`str`): The model owner.
+ name('str'): The name of the model
+ branch('str'): The branch of model
+ tag('str'): The tag of model
+ Returns:
+ Raises:
+ None
+
+ model_id = {owner}/{name}
+
+ """
+ super().__init__(os.path.join(cache_root, owner, name))
+
+ def get_file_by_path(self, file_path):
+ """Retrieve the cache if there is file match the path.
+ Args:
+ file_path (str): The file path in the model.
+ Returns:
+ path: the full path of the file.
+ """
+ for cached_file in self.cached_files:
+ if file_path == cached_file['Path']:
+ cached_file_path = os.path.join(self.cache_root_location,
+ cached_file['Path'])
+ if os.path.exists(cached_file_path):
+ return cached_file_path
+ else:
+ self.remove_key(cached_file)
+
+ return None
+
+ def get_file_by_path_and_commit_id(self, file_path, commit_id):
+ """Retrieve the cache if there is file match the path.
+ Args:
+ file_path (str): The file path in the model.
+ commit_id (str): The commit id of the file
+ Returns:
+ path: the full path of the file.
+ """
+ for cached_file in self.cached_files:
+ if file_path == cached_file['Path'] and \
+ (cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])):
+ cached_file_path = os.path.join(self.cache_root_location,
+ cached_file['Path'])
+ if os.path.exists(cached_file_path):
+ return cached_file_path
+ else:
+ self.remove_key(cached_file)
+
+ return None
+
+ def get_file_by_info(self, model_file_info):
+ """Check if exist cache file.
+
+ Args:
+ model_file_info (ModelFileInfo): The file information of the file.
+
+ Returns:
+ _type_: _description_
+ """
+ cache_key = self.__get_cache_key(model_file_info)
+ for cached_file in self.cached_files:
+ if cached_file == cache_key:
+ orig_path = os.path.join(self.cache_root_location,
+ cached_file['Path'])
+ if os.path.exists(orig_path):
+ return orig_path
+ else:
+ self.remove_key(cached_file)
+
+ return None
+
+ def __get_cache_key(self, model_file_info):
+ cache_key = {
+ 'Path': model_file_info['Path'],
+ 'Revision': model_file_info['Revision'], # commit id
+ }
+ return cache_key
+
+ def exists(self, model_file_info):
+ """Check the file is cached or not.
+
+ Args:
+ model_file_info (CachedFileInfo): The cached file info
+
+ Returns:
+ bool: If exists return True otherwise False
+ """
+ key = self.__get_cache_key(model_file_info)
+ is_exists = False
+ for cached_key in self.cached_files:
+ if cached_key['Path'] == key['Path'] and (
+ cached_key['Revision'].startswith(key['Revision'])
+ or key['Revision'].startswith(cached_key['Revision'])):
+ is_exists = True
+ file_path = os.path.join(self.cache_root_location,
+ model_file_info['Path'])
+ if is_exists:
+ if os.path.exists(file_path):
+ return True
+ else:
+ self.remove_key(
+ model_file_info) # sameone may manual delete the file
+ return False
+
+ def remove_if_exists(self, model_file_info):
+ """We in cache, remove it.
+
+ Args:
+ model_file_info (ModelFileInfo): The model file information from server.
+ """
+ for cached_file in self.cached_files:
+ if cached_file['Path'] == model_file_info['Path']:
+ self.remove_key(cached_file)
+ file_path = os.path.join(self.cache_root_location,
+ cached_file['Path'])
+ if os.path.exists(file_path):
+ os.remove(file_path)
+
+ def put_file(self, model_file_info, model_file_location):
+ """Put model on model_file_location to cache, the model first download to /tmp, and move to cache.
+
+ Args:
+ model_file_info (str): The file description returned by get_model_files
+ sample:
+ {
+ "CommitMessage": "add model\n",
+ "CommittedDate": 1654857567,
+ "CommitterName": "mulin.lyh",
+ "IsLFS": false,
+ "Mode": "100644",
+ "Name": "resnet18.pth",
+ "Path": "resnet18.pth",
+ "Revision": "09b68012b27de0048ba74003690a890af7aff192",
+ "Size": 46827520,
+ "Type": "blob"
+ }
+ model_file_location (str): The location of the temporary file.
+ Raises:
+ NotImplementedError: _description_
+
+ Returns:
+ str: The location of the cached file.
+ """
+ self.remove_if_exists(model_file_info) # backup old revision
+ cache_key = self.__get_cache_key(model_file_info)
+ cache_full_path = os.path.join(
+ self.cache_root_location,
+ cache_key['Path']) # Branch and Tag do not have same name.
+ cache_file_dir = os.path.dirname(cache_full_path)
+ if not os.path.exists(cache_file_dir):
+ os.makedirs(cache_file_dir, exist_ok=True)
+ # We can't make operation transaction
+ move(model_file_location, cache_full_path)
+ self.cached_files.append(cache_key)
+ self.save_cached_files()
+ return cache_full_path
diff --git a/modelscope/hub/utils/utils.py b/modelscope/hub/utils/utils.py
new file mode 100644
index 00000000..d0704de8
--- /dev/null
+++ b/modelscope/hub/utils/utils.py
@@ -0,0 +1,39 @@
+import os
+
+from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
+ DEFAULT_MODELSCOPE_GITLAB_DOMAIN,
+ DEFAULT_MODELSCOPE_GROUP,
+ MODEL_ID_SEPARATOR,
+ MODELSCOPE_URL_SCHEME)
+
+
+def model_id_to_group_owner_name(model_id):
+ if MODEL_ID_SEPARATOR in model_id:
+ group_or_owner = model_id.split(MODEL_ID_SEPARATOR)[0]
+ name = model_id.split(MODEL_ID_SEPARATOR)[1]
+ else:
+ group_or_owner = DEFAULT_MODELSCOPE_GROUP
+ name = model_id
+ return group_or_owner, name
+
+
+def get_cache_dir():
+ """
+ cache dir precedence:
+ function parameter > enviroment > ~/.cache/modelscope/hub
+ """
+ default_cache_dir = os.path.expanduser(
+ os.path.join('~/.cache', 'modelscope'))
+ return os.getenv('MODELSCOPE_CACHE', os.path.join(default_cache_dir,
+ 'hub'))
+
+
+def get_endpoint():
+ modelscope_domain = os.getenv('MODELSCOPE_DOMAIN',
+ DEFAULT_MODELSCOPE_DOMAIN)
+ return MODELSCOPE_URL_SCHEME + modelscope_domain
+
+
+def get_gitlab_domain():
+ return os.getenv('MODELSCOPE_GITLAB_DOMAIN',
+ DEFAULT_MODELSCOPE_GITLAB_DOMAIN)
diff --git a/modelscope/models/base.py b/modelscope/models/base.py
index ab0d22cc..99309a7e 100644
--- a/modelscope/models/base.py
+++ b/modelscope/models/base.py
@@ -4,12 +4,10 @@ import os.path as osp
from abc import ABC, abstractmethod
from typing import Dict, Union
-from maas_hub.snapshot_download import snapshot_download
-
+from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.builder import build_model
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
-from modelscope.utils.hub import get_model_cache_dir
Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -47,9 +45,7 @@ class Model(ABC):
if osp.exists(model_name_or_path):
local_model_dir = model_name_or_path
else:
- cache_path = get_model_cache_dir(model_name_or_path)
- local_model_dir = cache_path if osp.exists(
- cache_path) else snapshot_download(model_name_or_path)
+ local_model_dir = snapshot_download(model_name_or_path)
# else:
# raise ValueError(
# 'Remote model repo {model_name_or_path} does not exists')
diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py
index 5fd1aa21..7a21d5d9 100644
--- a/modelscope/pipelines/base.py
+++ b/modelscope/pipelines/base.py
@@ -4,13 +4,11 @@ import os.path as osp
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Union
-from maas_hub.snapshot_download import snapshot_download
-
+from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.base import Model
from modelscope.preprocessors import Preprocessor
from modelscope.pydatasets import PyDataset
from modelscope.utils.config import Config
-from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.logger import get_logger
from .outputs import TASK_OUTPUTS
from .util import is_model_name
@@ -32,9 +30,7 @@ class Pipeline(ABC):
# TODO @wenmeng.zwm replace model.startswith('damo/') with get_model
if isinstance(model, str) and model.startswith('damo/'):
if not osp.exists(model):
- cache_path = get_model_cache_dir(model)
- model = cache_path if osp.exists(
- cache_path) else snapshot_download(model)
+ model = snapshot_download(model)
return Model.from_pretrained(model) if is_model_name(
model) else model
elif isinstance(model, Model):
diff --git a/modelscope/pipelines/util.py b/modelscope/pipelines/util.py
index 37c9c929..6fe6e9fd 100644
--- a/modelscope/pipelines/util.py
+++ b/modelscope/pipelines/util.py
@@ -2,8 +2,7 @@
import os.path as osp
from typing import List, Union
-from maas_hub.file_download import model_file_download
-
+from modelscope.hub.file_download import model_file_download
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
diff --git a/modelscope/preprocessors/multi_model.py b/modelscope/preprocessors/multi_model.py
index de211611..ea2e7493 100644
--- a/modelscope/preprocessors/multi_model.py
+++ b/modelscope/preprocessors/multi_model.py
@@ -4,11 +4,10 @@ from typing import Any, Dict, Union
import numpy as np
import torch
-from maas_hub.snapshot_download import snapshot_download
from PIL import Image
+from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import Fields, ModelFile
-from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS
@@ -34,9 +33,7 @@ class OfaImageCaptionPreprocessor(Preprocessor):
if osp.exists(model_dir):
local_model_dir = model_dir
else:
- cache_path = get_model_cache_dir(model_dir)
- local_model_dir = cache_path if osp.exists(
- cache_path) else snapshot_download(model_dir)
+ local_model_dir = snapshot_download(model_dir)
local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE)
bpe_dir = local_model_dir
diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py
index 2f61b148..245642d1 100644
--- a/modelscope/utils/hub.py
+++ b/modelscope/utils/hub.py
@@ -2,13 +2,10 @@
import os
-from maas_hub.constants import MODEL_ID_SEPARATOR
+from modelscope.hub.constants import MODEL_ID_SEPARATOR
+from modelscope.hub.utils.utils import get_cache_dir
# temp solution before the hub-cache is in place
-def get_model_cache_dir(model_id: str, branch: str = 'master'):
- model_id_expanded = model_id.replace('/',
- MODEL_ID_SEPARATOR) + '.' + branch
- default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
- return os.getenv('MAAS_CACHE',
- os.path.join(default_cache_dir, 'hub', model_id_expanded))
+def get_model_cache_dir(model_id: str):
+ return os.path.join(get_cache_dir(), model_id)
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index e97352aa..6580de53 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,13 +1,16 @@
addict
datasets
easydict
-https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.4.dev0-py3-none-any.whl
+filelock>=3.3.0
numpy
opencv-python-headless
Pillow>=6.2.0
pyyaml
requests
+requests==2.27.1
scipy
+setuptools==58.0.4
tokenizers<=0.10.3
+tqdm>=4.64.0
transformers<=4.16.2
yapf
diff --git a/tests/hub/__init__.py b/tests/hub/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/hub/test_hub_operation.py b/tests/hub/test_hub_operation.py
new file mode 100644
index 00000000..2277860b
--- /dev/null
+++ b/tests/hub/test_hub_operation.py
@@ -0,0 +1,157 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import os.path as osp
+import subprocess
+import tempfile
+import unittest
+import uuid
+
+from modelscope.hub.api import HubApi, ModelScopeConfig
+from modelscope.hub.file_download import model_file_download
+from modelscope.hub.repository import Repository
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.hub.utils.utils import get_gitlab_domain
+
+USER_NAME = 'maasadmin'
+PASSWORD = '12345678'
+
+model_chinese_name = '达摩卡通化模型'
+model_org = 'unittest'
+DEFAULT_GIT_PATH = 'git'
+
+
+class GitError(Exception):
+ pass
+
+
+# TODO make thest git operation to git library after merge code.
+def run_git_command(git_path, *args) -> subprocess.CompletedProcess:
+ response = subprocess.run([git_path, *args], capture_output=True)
+ try:
+ response.check_returncode()
+ return response.stdout.decode('utf8')
+ except subprocess.CalledProcessError as error:
+ raise GitError(error.stderr.decode('utf8'))
+
+
+# for public project, token can None, private repo, there must token.
+def clone(local_dir: str, token: str, url: str):
+ url = url.replace('//', '//oauth2:%s@' % token)
+ clone_args = '-C %s clone %s' % (local_dir, url)
+ clone_args = clone_args.split(' ')
+ stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args)
+ print('stdout: %s' % stdout)
+
+
+def push(local_dir: str, token: str, url: str):
+ url = url.replace('//', '//oauth2:%s@' % token)
+ push_args = '-C %s push %s' % (local_dir, url)
+ push_args = push_args.split(' ')
+ stdout = run_git_command(DEFAULT_GIT_PATH, *push_args)
+ print('stdout: %s' % stdout)
+
+
+sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx'
+download_model_file_name = 'mnist-12.onnx'
+
+
+class HubOperationTest(unittest.TestCase):
+
+ def setUp(self):
+ self.old_cwd = os.getcwd()
+ self.api = HubApi()
+ # note this is temporary before official account management is ready
+ self.api.login(USER_NAME, PASSWORD)
+ self.model_name = uuid.uuid4().hex
+ self.model_id = '%s/%s' % (model_org, self.model_name)
+ self.api.create_model(
+ model_id=self.model_id,
+ chinese_name=model_chinese_name,
+ visibility=5, # 1-private, 5-public
+ license='apache-2.0')
+
+ def tearDown(self):
+ os.chdir(self.old_cwd)
+ self.api.delete_model(model_id=self.model_id)
+
+ def test_model_repo_creation(self):
+ # change to proper model names before use
+ try:
+ info = self.api.get_model(model_id=self.model_id)
+ assert info['Name'] == self.model_name
+ except KeyError as ke:
+ if ke.args[0] == 'name':
+ print(f'model {self.model_name} already exists, ignore')
+ else:
+ raise
+
+ # Note that this can be done via git operation once model repo
+ # has been created. Git-Op is the RECOMMENDED model upload approach
+ def test_model_upload(self):
+ url = f'http://{get_gitlab_domain()}/{self.model_id}'
+ print(url)
+ temporary_dir = tempfile.mkdtemp()
+ os.chdir(temporary_dir)
+ cmd_args = 'clone %s' % url
+ cmd_args = cmd_args.split(' ')
+ out = run_git_command('git', *cmd_args)
+ print(out)
+ repo_dir = os.path.join(temporary_dir, self.model_name)
+ os.chdir(repo_dir)
+ os.system('touch file1')
+ os.system('git add file1')
+ os.system("git commit -m 'Test'")
+ token = ModelScopeConfig.get_token()
+ push(repo_dir, token, url)
+
+ def test_download_single_file(self):
+ url = f'http://{get_gitlab_domain()}/{self.model_id}'
+ print(url)
+ temporary_dir = tempfile.mkdtemp()
+ os.chdir(temporary_dir)
+ os.system('git clone %s' % url)
+ repo_dir = os.path.join(temporary_dir, self.model_name)
+ os.chdir(repo_dir)
+ os.system('wget %s' % sample_model_url)
+ os.system('git add .')
+ os.system("git commit -m 'Add file'")
+ token = ModelScopeConfig.get_token()
+ push(repo_dir, token, url)
+ assert os.path.exists(
+ os.path.join(temporary_dir, self.model_name,
+ download_model_file_name))
+ downloaded_file = model_file_download(
+ model_id=self.model_id, file_path=download_model_file_name)
+ mdtime1 = os.path.getmtime(downloaded_file)
+ # download again
+ downloaded_file = model_file_download(
+ model_id=self.model_id, file_path=download_model_file_name)
+ mdtime2 = os.path.getmtime(downloaded_file)
+ assert mdtime1 == mdtime2
+
+ def test_snapshot_download(self):
+ url = f'http://{get_gitlab_domain()}/{self.model_id}'
+ print(url)
+ temporary_dir = tempfile.mkdtemp()
+ os.chdir(temporary_dir)
+ os.system('git clone %s' % url)
+ repo_dir = os.path.join(temporary_dir, self.model_name)
+ os.chdir(repo_dir)
+ os.system('wget %s' % sample_model_url)
+ os.system('git add .')
+ os.system("git commit -m 'Add file'")
+ token = ModelScopeConfig.get_token()
+ push(repo_dir, token, url)
+ snapshot_path = snapshot_download(model_id=self.model_id)
+ downloaded_file_path = os.path.join(snapshot_path,
+ download_model_file_name)
+ assert os.path.exists(downloaded_file_path)
+ mdtime1 = os.path.getmtime(downloaded_file_path)
+ # download again
+ snapshot_path = snapshot_download(model_id=self.model_id)
+ mdtime2 = os.path.getmtime(downloaded_file_path)
+ assert mdtime1 == mdtime2
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py
index e557ba86..751b6975 100644
--- a/tests/pipelines/test_image_matting.py
+++ b/tests/pipelines/test_image_matting.py
@@ -10,7 +10,6 @@ from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import ModelFile, Tasks
-from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level
@@ -18,11 +17,6 @@ class ImageMattingTest(unittest.TestCase):
def setUp(self) -> None:
self.model_id = 'damo/cv_unet_image-matting'
- # switch to False if downloading everytime is not desired
- purge_cache = True
- if purge_cache:
- shutil.rmtree(
- get_model_cache_dir(self.model_id), ignore_errors=True)
@unittest.skip('deprecated, download model from model hub instead')
def test_run_with_direct_file_download(self):
diff --git a/tests/pipelines/test_ocr_detection.py b/tests/pipelines/test_ocr_detection.py
index 62fcedd3..986961b7 100644
--- a/tests/pipelines/test_ocr_detection.py
+++ b/tests/pipelines/test_ocr_detection.py
@@ -27,7 +27,7 @@ class OCRDetectionTest(unittest.TestCase):
print('ocr detection results: ')
print(result)
- @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
ocr_detection = pipeline(Tasks.ocr_detection)
self.pipeline_inference(ocr_detection, self.test_image)
diff --git a/tests/pipelines/test_sentence_similarity.py b/tests/pipelines/test_sentence_similarity.py
index ac2ff4fb..43e585ba 100644
--- a/tests/pipelines/test_sentence_similarity.py
+++ b/tests/pipelines/test_sentence_similarity.py
@@ -2,14 +2,12 @@
import shutil
import unittest
-from maas_hub.snapshot_download import snapshot_download
-
+from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForSentenceSimilarity
from modelscope.pipelines import SentenceSimilarityPipeline, pipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
-from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level
@@ -18,13 +16,6 @@ class SentenceSimilarityTest(unittest.TestCase):
sentence1 = '今天气温比昨天高么?'
sentence2 = '今天湿度比昨天高么?'
- def setUp(self) -> None:
- # switch to False if downloading everytime is not desired
- purge_cache = True
- if purge_cache:
- shutil.rmtree(
- get_model_cache_dir(self.model_id), ignore_errors=True)
-
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run(self):
cache_path = snapshot_download(self.model_id)
diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py
index 8b5c9468..f1369a2f 100644
--- a/tests/pipelines/test_speech_signal_process.py
+++ b/tests/pipelines/test_speech_signal_process.py
@@ -5,7 +5,6 @@ import unittest
from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
-from modelscope.utils.hub import get_model_cache_dir
NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav'
FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav'
@@ -30,11 +29,6 @@ class SpeechSignalProcessTest(unittest.TestCase):
def setUp(self) -> None:
self.model_id = 'damo/speech_dfsmn_aec_psm_16k'
- # switch to False if downloading everytime is not desired
- purge_cache = True
- if purge_cache:
- shutil.rmtree(
- get_model_cache_dir(self.model_id), ignore_errors=True)
# A temporary hack to provide c++ lib. Download it first.
download(AEC_LIB_URL, AEC_LIB_FILE)
diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py
index bb24fece..8ecd9ed4 100644
--- a/tests/pipelines/test_text_classification.py
+++ b/tests/pipelines/test_text_classification.py
@@ -11,7 +11,6 @@ from modelscope.pipelines import SequenceClassificationPipeline, pipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import Hubs, Tasks
-from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level
@@ -19,11 +18,6 @@ class SequenceClassificationTest(unittest.TestCase):
def setUp(self) -> None:
self.model_id = 'damo/bert-base-sst2'
- # switch to False if downloading everytime is not desired
- purge_cache = True
- if purge_cache:
- shutil.rmtree(
- get_model_cache_dir(self.model_id), ignore_errors=True)
def predict(self, pipeline_ins: SequenceClassificationPipeline):
from easynlp.appzoo import load_dataset
diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py
index fbdd165f..cb5194c2 100644
--- a/tests/pipelines/test_text_generation.py
+++ b/tests/pipelines/test_text_generation.py
@@ -1,8 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
-from maas_hub.snapshot_download import snapshot_download
-
+from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import PalmForTextGeneration
from modelscope.pipelines import TextGenerationPipeline, pipeline
diff --git a/tests/pipelines/test_word_segmentation.py b/tests/pipelines/test_word_segmentation.py
index 4ec2bf29..7c57d9ad 100644
--- a/tests/pipelines/test_word_segmentation.py
+++ b/tests/pipelines/test_word_segmentation.py
@@ -2,14 +2,12 @@
import shutil
import unittest
-from maas_hub.snapshot_download import snapshot_download
-
+from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.pipelines import WordSegmentationPipeline, pipeline
from modelscope.preprocessors import TokenClassifcationPreprocessor
from modelscope.utils.constant import Tasks
-from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level
@@ -17,13 +15,6 @@ class WordSegmentationTest(unittest.TestCase):
model_id = 'damo/nlp_structbert_word-segmentation_chinese-base'
sentence = '今天天气不错,适合出去游玩'
- def setUp(self) -> None:
- # switch to False if downloading everytime is not desired
- purge_cache = True
- if purge_cache:
- shutil.rmtree(
- get_model_cache_dir(self.model_id), ignore_errors=True)
-
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
diff --git a/tests/run.py b/tests/run.py
index a904ba8e..38c5a897 100644
--- a/tests/run.py
+++ b/tests/run.py
@@ -61,7 +61,7 @@ if __name__ == '__main__':
parser.add_argument(
'--test_dir', default='tests', help='directory to be tested')
parser.add_argument(
- '--level', default=0, help='2 -- all, 1 -- p1, 0 -- p0')
+ '--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0')
args = parser.parse_args()
set_test_level(args.level)
logger.info(f'TEST LEVEL: {test_level()}')
diff --git a/tests/utils/test_hub_operation.py b/tests/utils/test_hub_operation.py
deleted file mode 100644
index f432a60c..00000000
--- a/tests/utils/test_hub_operation.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import os.path as osp
-import unittest
-
-from maas_hub.maas_api import MaasApi
-from maas_hub.repository import Repository
-
-USER_NAME = 'maasadmin'
-PASSWORD = '12345678'
-
-
-class HubOperationTest(unittest.TestCase):
-
- def setUp(self):
- self.api = MaasApi()
- # note this is temporary before official account management is ready
- self.api.login(USER_NAME, PASSWORD)
-
- @unittest.skip('to be used for local test only')
- def test_model_repo_creation(self):
- # change to proper model names before use
- model_name = 'cv_unet_person-image-cartoon_compound-models'
- model_chinese_name = '达摩卡通化模型'
- model_org = 'damo'
- try:
- self.api.create_model(
- owner=model_org,
- name=model_name,
- chinese_name=model_chinese_name,
- visibility=5, # 1-private, 5-public
- license='apache-2.0')
- # TODO: support proper name duplication checking
- except KeyError as ke:
- if ke.args[0] == 'name':
- print(f'model {self.model_name} already exists, ignore')
- else:
- raise
-
- # Note that this can be done via git operation once model repo
- # has been created. Git-Op is the RECOMMENDED model upload approach
- @unittest.skip('to be used for local test only')
- def test_model_upload(self):
- local_path = '/path/to/local/model/directory'
- assert osp.exists(local_path), 'Local model directory not exist.'
- repo = Repository(local_dir=local_path)
- repo.push_to_hub(commit_message='Upload model files')
-
-
-if __name__ == '__main__':
- unittest.main()