Browse Source

Merge remote-tracking branch 'origin/master' into feat/zero_shot_classification

master
智丞 3 years ago
parent
commit
88ee08c5dd
29 changed files with 1487 additions and 119 deletions
  1. +0
    -0
      modelscope/hub/__init__.py
  2. +265
    -0
      modelscope/hub/api.py
  3. +8
    -0
      modelscope/hub/constants.py
  4. +30
    -0
      modelscope/hub/errors.py
  5. +254
    -0
      modelscope/hub/file_download.py
  6. +82
    -0
      modelscope/hub/git.py
  7. +173
    -0
      modelscope/hub/repository.py
  8. +125
    -0
      modelscope/hub/snapshot_download.py
  9. +0
    -0
      modelscope/hub/utils/__init__.py
  10. +40
    -0
      modelscope/hub/utils/_subprocess.py
  11. +294
    -0
      modelscope/hub/utils/caching.py
  12. +39
    -0
      modelscope/hub/utils/utils.py
  13. +2
    -6
      modelscope/models/base.py
  14. +2
    -6
      modelscope/pipelines/base.py
  15. +1
    -2
      modelscope/pipelines/util.py
  16. +2
    -5
      modelscope/preprocessors/multi_model.py
  17. +4
    -7
      modelscope/utils/hub.py
  18. +4
    -1
      requirements/runtime.txt
  19. +0
    -0
      tests/hub/__init__.py
  20. +157
    -0
      tests/hub/test_hub_operation.py
  21. +0
    -6
      tests/pipelines/test_image_matting.py
  22. +1
    -1
      tests/pipelines/test_ocr_detection.py
  23. +1
    -10
      tests/pipelines/test_sentence_similarity.py
  24. +0
    -6
      tests/pipelines/test_speech_signal_process.py
  25. +0
    -6
      tests/pipelines/test_text_classification.py
  26. +1
    -2
      tests/pipelines/test_text_generation.py
  27. +1
    -10
      tests/pipelines/test_word_segmentation.py
  28. +1
    -1
      tests/run.py
  29. +0
    -50
      tests/utils/test_hub_operation.py

+ 0
- 0
modelscope/hub/__init__.py View File


+ 265
- 0
modelscope/hub/api.py View File

@@ -0,0 +1,265 @@
import imp
import os
import pickle
import subprocess
from http.cookiejar import CookieJar
from os.path import expanduser
from typing import List, Optional, Tuple, Union

import requests

from modelscope.utils.logger import get_logger
from .constants import LOGGER_NAME
from .errors import NotExistError, is_ok, raise_on_error
from .utils.utils import get_endpoint, model_id_to_group_owner_name

logger = get_logger()


class HubApi:
    """Thin HTTP client over the ModelScope hub open-api.

    Covers login, model creation/deletion, model info and repo file listing.
    """

    def __init__(self, endpoint=None):
        """
        Args:
            endpoint(`str`, *optional*): hub endpoint base URL; defaults to
                the value returned by `get_endpoint()`.
        """
        self.endpoint = endpoint if endpoint is not None else get_endpoint()

    def login(
        self,
        user_name: str,
        password: str,
    ) -> Tuple[str, CookieJar]:
        """
        Login with username and password

        Args:
            user_name(`str`): user name on modelscope
            password(`str`): password

        Returns:
            A (token, cookies) tuple:
            cookies: to authenticate yourself to ModelScope open-api
            gitlab token: to access private repos

        <Tip>
        You only have to login once within 30 days.
        </Tip>

        TODO: handle cookies expire
        """
        path = f'{self.endpoint}/api/v1/login'
        r = requests.post(
            path, json={
                'username': user_name,
                'password': password
            })
        r.raise_for_status()
        d = r.json()
        raise_on_error(d)

        token = d['Data']['AccessToken']
        cookies = r.cookies

        # Persist credentials so later API calls and git pushes can reuse them.
        ModelScopeConfig.save_token(token)
        ModelScopeConfig.save_cookies(cookies)
        ModelScopeConfig.write_to_git_credential(user_name, password)

        return token, cookies

    def create_model(self, model_id: str, chinese_name: str, visibility: int,
                     license: str) -> str:
        """
        Create model repo at ModelScopeHub

        Args:
            model_id(`str`): The model id
            chinese_name(`str`): chinese name of the model
            visibility(`int`): visibility of the model(1-private, 3-internal, 5-public)
            license(`str`): license of the model, candidates can be found at: TBA

        Returns:
            name of the model created

        Raises:
            ValueError: when no saved credentials are available.

        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')

        path = f'{self.endpoint}/api/v1/models'
        owner_or_group, name = model_id_to_group_owner_name(model_id)
        r = requests.post(
            path,
            json={
                'Path': owner_or_group,
                'Name': name,
                'ChineseName': chinese_name,
                'Visibility': visibility,
                'License': license
            },
            cookies=cookies)
        r.raise_for_status()
        d = r.json()
        raise_on_error(d)
        return d['Data']['Name']

    def delete_model(self, model_id):
        """Delete a model repo on the hub.

        Args:
            model_id (str): The model id.
        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        # Fail fast like create_model instead of sending an unauthenticated
        # delete request that the server will reject anyway.
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')
        path = f'{self.endpoint}/api/v1/models/{model_id}'

        r = requests.delete(path, cookies=cookies)
        r.raise_for_status()
        raise_on_error(r.json())

    def get_model_url(self, model_id):
        """Return the git clone URL of the model repo."""
        return f'{self.endpoint}/api/v1/models/{model_id}.git'

    def get_model(
        self,
        model_id: str,
        revision: str = 'master',
    ) -> str:
        """
        Get model information at modelscope_hub

        Args:
            model_id(`str`): The model id.
            revision(`str`): revision of model
        Returns:
            The model details information.
        Raises:
            NotExistError: If the model is not exist, will throw NotExistError
        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        owner_or_group, name = model_id_to_group_owner_name(model_id)
        # BUGFIX: the revision used to be appended as a bare query string
        # ('?master'), which carries no parameter name; send it as the
        # 'Revision' parameter like the other endpoints do.
        path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'

        r = requests.get(path, cookies=cookies)
        if r.status_code == 200:
            if is_ok(r.json()):
                return r.json()['Data']
            else:
                raise NotExistError(r.json()['Message'])
        else:
            r.raise_for_status()

    def get_model_branches_and_tags(
        self,
        model_id: str,
    ) -> Tuple[List[str], List[str]]:
        """Return ([branch names], [tag names]) of the model repo."""
        cookies = ModelScopeConfig.get_cookies()

        path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
        r = requests.get(path, cookies=cookies)
        r.raise_for_status()
        d = r.json()
        raise_on_error(d)
        info = d['Data']
        branches = [x['Revision'] for x in info['RevisionMap']['Branches']
                    ] if info['RevisionMap']['Branches'] else []
        tags = [x['Revision'] for x in info['RevisionMap']['Tags']
                ] if info['RevisionMap']['Tags'] else []
        return branches, tags

    def get_model_files(
            self,
            model_id: str,
            revision: Optional[str] = 'master',
            root: Optional[str] = None,
            recursive: Optional[bool] = False,
            use_cookies: Union[bool, CookieJar] = False) -> List[dict]:
        """List the files of a model repo at a given revision.

        Args:
            model_id(`str`): The model id.
            revision(`str`, *optional*): branch/tag/commit, default 'master'.
            root(`str`, *optional*): only list files under this sub-path.
            recursive(`bool`, *optional*): whether to recurse into directories.
            use_cookies: `True` to use the saved login cookies, or a
                `CookieJar` to use directly; `False` sends an anonymous request.

        Returns:
            The repo file entries; git bookkeeping files are filtered out.
        """
        cookies = None
        if isinstance(use_cookies, CookieJar):
            cookies = use_cookies
        elif use_cookies:
            cookies = ModelScopeConfig.get_cookies()
            if cookies is None:
                raise ValueError('Token does not exist, please login first.')

        path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}'
        if root is not None:
            path = path + f'&Root={root}'

        r = requests.get(path, cookies=cookies)

        r.raise_for_status()
        d = r.json()
        raise_on_error(d)

        # Callers only care about content files, not git bookkeeping files.
        return [
            file for file in d['Data']['Files']
            if file['Name'] not in ('.gitignore', '.gitattributes')
        ]


class ModelScopeConfig:
    """Persists hub credentials (cookies + access token) under ~/.modelscope."""

    path_credential = expanduser('~/.modelscope/credentials')
    # NOTE(review): runs at import time — a read-only home directory makes
    # importing this module fail; confirm this is acceptable.
    os.makedirs(path_credential, exist_ok=True)

    @classmethod
    def save_cookies(cls, cookies: CookieJar):
        """Pickle the authentication cookies into the credential directory."""
        with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f:
            pickle.dump(cookies, f)

    @classmethod
    def get_cookies(cls):
        """Return the saved cookies, or None (with a warning) when absent."""
        try:
            with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            # BUGFIX: logger.warn is deprecated, use logger.warning.
            logger.warning(
                "Auth token does not exist, you'll get authentication "
                'error when downloading private model files. '
                'Please login first')

    @classmethod
    def save_token(cls, token: str):
        """Write the access token into the credential directory."""
        with open(os.path.join(cls.path_credential, 'token'), 'w+') as f:
            f.write(token)

    @classmethod
    def get_token(cls) -> Optional[str]:
        """
        Get token or None if not existent.

        Returns:
            `str` or `None`: The token, `None` if it doesn't exist.
        """
        token = None
        try:
            with open(os.path.join(cls.path_credential, 'token'), 'r') as f:
                token = f.read()
        except FileNotFoundError:
            # No token saved yet: not an error, callers handle None.
            pass
        return token

    @staticmethod
    def write_to_git_credential(username: str, password: str):
        """Store hub credentials via `git credential-store` so git can push."""
        with subprocess.Popen(
                'git credential-store store'.split(),
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
        ) as process:
            input_username = f'username={username.lower()}'
            input_password = f'password={password}'

            process.stdin.write(
                f'url={get_endpoint()}\n{input_username}\n{input_password}\n\n'
                .encode('utf-8'))
            process.stdin.flush()

+ 8
- 0
modelscope/hub/constants.py View File

@@ -0,0 +1,8 @@
# URL scheme prefixed to every hub endpoint (clone URLs included).
MODELSCOPE_URL_SCHEME = 'http://'
# NOTE(review): hard-coded IP:port pairs look like a test/staging deployment —
# confirm these are replaced by real domains before release.
DEFAULT_MODELSCOPE_DOMAIN = '101.201.119.157:32330'
DEFAULT_MODELSCOPE_GITLAB_DOMAIN = '101.201.119.157:31102'

# Owner/group assumed when a bare model name (no '{owner}/' prefix) is given.
DEFAULT_MODELSCOPE_GROUP = 'damo'
# Separator between owner/group and model name inside a model id.
MODEL_ID_SEPARATOR = '/'

# Name used when registering the hub logger.
LOGGER_NAME = 'ModelScopeHub'

+ 30
- 0
modelscope/hub/errors.py View File

@@ -0,0 +1,30 @@
class NotExistError(Exception):
    """Raised when a requested model, revision or file does not exist on the hub."""
    pass


class RequestError(Exception):
    """Raised when a hub API response carries a non-success code/flag."""
    pass


def is_ok(rsp):
    """Tell whether a hub response body reports success.

    Args:
        rsp: parsed JSON response body, e.g.
            Failed: {'Code': 10010101004, 'Message': 'get model info failed, err: unauthorized permission',
            'RequestId': '', 'Success': False}
            Success: {'Code': 200, 'Data': {}, 'Message': 'success', 'RequestId': '', 'Success': True}
    """
    code_is_ok = rsp['Code'] == 200
    return code_is_ok and rsp['Success']


def raise_on_error(rsp):
    """Raise `RequestError` unless the hub response reports success.

    Args:
        rsp: parsed JSON response body from the hub.

    Returns:
        True when the response is successful.

    Raises:
        RequestError: carrying the server's 'Message' on failure.
    """
    if rsp['Code'] != 200 or not rsp['Success']:
        raise RequestError(rsp['Message'])
    return True

+ 254
- 0
modelscope/hub/file_download.py View File

@@ -0,0 +1,254 @@
import copy
import fnmatch
import logging
import os
import sys
import tempfile
import time
from functools import partial
from hashlib import sha256
from pathlib import Path
from typing import BinaryIO, Dict, Optional, Union
from uuid import uuid4

import json
import requests
from filelock import FileLock
from requests.exceptions import HTTPError
from tqdm import tqdm

from modelscope import __version__
from modelscope.utils.logger import get_logger
from .api import HubApi, ModelScopeConfig
from .constants import (DEFAULT_MODELSCOPE_GROUP, LOGGER_NAME,
MODEL_ID_SEPARATOR)
from .errors import NotExistError, RequestError, raise_on_error
from .utils.caching import ModelFileSystemCache
from .utils.utils import (get_cache_dir, get_endpoint,
model_id_to_group_owner_name)

SESSION_ID = uuid4().hex
logger = get_logger()


def model_file_download(
    model_id: str,
    file_path: str,
    revision: Optional[str] = 'master',
    cache_dir: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    local_files_only: Optional[bool] = False,
) -> Optional[str]:  # pragma: no cover
    """
    Download from a given URL and cache it if it's not already present in the
    local cache.

    Given a URL, this function looks for the corresponding file in the local
    cache. If it's not there, download it. Then return the path to the cached
    file.

    Args:
        model_id (`str`):
            The model to whom the file to be downloaded belongs.
        file_path(`str`):
            Path of the file to be downloaded, relative to the root of model repo
        revision(`str`, *optional*):
            revision of the model file to be downloaded.
            Can be any of a branch, tag or commit hash, default to `master`
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        user_agent (`dict`, `str`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
            if `False`, download the file anyway even it exists

    Returns:
        Local path (string) of file or if networking is off, last version of
        file cached on disk.

    Raises:
        ValueError: if `local_files_only` is True and the file is not cached.
        NotExistError: if `file_path` does not exist at `revision`.
    """
    if cache_dir is None:
        cache_dir = get_cache_dir()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    group_or_owner, name = model_id_to_group_owner_name(model_id)

    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)

    # Offline mode: serve whatever is cached, or fail loudly.
    if local_files_only:
        cached_file_path = cache.get_file_by_path(file_path)
        if cached_file_path is not None:
            logger.warning(
                "File exists in local cache, but we're not sure it's up to date"
            )
            return cached_file_path
        else:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable model look-ups and downloads'
                " online, set 'local_files_only' to False.")

    _api = HubApi()
    headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
    branches, tags = _api.get_model_branches_and_tags(model_id)
    file_to_download_info = None
    is_commit_id = False
    if revision in branches or revision in tags:
        # The revision is a branch or tag: list the remote files so we can
        # tell whether the latest version is already cached.
        model_files = _api.get_model_files(
            model_id=model_id,
            revision=revision,
            recursive=True,
        )

        for model_file in model_files:
            if model_file['Type'] == 'tree':
                continue

            if model_file['Path'] == file_path:
                model_file['Branch'] = revision
                if cache.exists(model_file):
                    return cache.get_file_by_info(model_file)
                else:
                    file_to_download_info = model_file

        if file_to_download_info is None:
            raise NotExistError('The file path: %s not exist in: %s' %
                                (file_path, model_id))
    else:  # the revision is a commit id.
        cached_file_path = cache.get_file_by_path_and_commit_id(
            file_path, revision)
        if cached_file_path is not None:
            logger.info('The specified file is in cache, skip downloading!')
            return cached_file_path  # the file is in cache.
        is_commit_id = True

    # we need to download again
    # TODO: skip using JWT for authorization, use cookie instead
    cookies = ModelScopeConfig.get_cookies()
    url_to_download = get_file_download_url(model_id, file_path, revision)
    file_to_download_info = {
        'Path': file_path,
        'Revision':
        revision if is_commit_id else file_to_download_info['Revision']
    }
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache.get_root_location() + '.lock'

    with FileLock(lock_path):
        # BUGFIX: tempfile._get_candidate_names() is a private CPython API
        # that may disappear between versions; a uuid gives the same
        # collision-free scratch name with public API only.
        temp_file_name = uuid4().hex
        http_get_file(
            url_to_download,
            cache_dir,
            temp_file_name,
            headers=headers,
            cookies=None if cookies is None else cookies.get_dict())
        return cache.put_file(file_to_download_info,
                              os.path.join(cache_dir, temp_file_name))


def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
    """Formats a user-agent string with basic info about a request.

    Args:
        user_agent (`str`, `dict`, *optional*):
            The user agent info in the form of a dictionary or a single string.

    Returns:
        The formatted user-agent string.
    """
    ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}'

    # BUGFIX: caller-supplied info is appended to the base UA instead of
    # replacing it, so the library version and session id are never lost.
    if isinstance(user_agent, dict):
        ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += '; ' + user_agent
    return ua


def get_file_download_url(model_id: str, file_path: str, revision: str):
    """
    Format file download url according to `model_id`, `revision` and `file_path`.
    e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
    the resulted download url is: https://maas.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
    """
    endpoint = get_endpoint()
    return (f'{endpoint}/api/v1/models/{model_id}/repo'
            f'?Revision={revision}&FilePath={file_path}')


def http_get_file(
    url: str,
    local_dir: str,
    file_name: str,
    cookies: Dict[str, str],
    headers: Optional[Dict[str, str]] = None,
):
    """
    Download remote file. Do not gobble up errors.
    This method is only used by snapshot_download, since the behavior is quite different with single file download
    TODO: consolidate with http_get_file() to avoild duplicate code

    Args:
        url(`str`):
            actual download url of the file
        local_dir(`str`):
            local directory where the downloaded file stores
        file_name(`str`):
            name of the file stored in `local_dir`
        cookies(`Dict[str, str]`):
            cookies used to authentication the user, which is used for downloading private repos
        headers(`Optional[Dict[str, str]] = None`):
            http headers to carry necessary info when requesting the remote file

    """
    temp_file_manager = partial(
        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)

    with temp_file_manager() as temp_file:
        logger.info('downloading %s to %s', url, temp_file.name)
        # Copy so the caller's headers dict is never mutated downstream.
        headers = copy.deepcopy(headers)

        r = requests.get(url, stream=True, headers=headers, cookies=cookies)
        r.raise_for_status()

        content_length = r.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None

        progress = tqdm(
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
            total=total,
            initial=0,
            desc='Downloading',
        )
        # BUGFIX: close the progress bar even when the download fails midway,
        # otherwise the bar leaks and corrupts subsequent terminal output.
        try:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
        finally:
            progress.close()

    logger.info('storing %s in cache at %s', url, local_dir)
    # Atomic move of the fully-downloaded temp file into its final name.
    os.replace(temp_file.name, os.path.join(local_dir, file_name))

+ 82
- 0
modelscope/hub/git.py View File

@@ -0,0 +1,82 @@
from threading import local
# FIXME(review): stray auto-added import — `NO` is never used and tkinter is
# unavailable on headless servers; remove once confirmed unused.
from tkinter.messagebox import NO
from typing import Union

from modelscope.utils.logger import get_logger
from .constants import LOGGER_NAME
from .utils._subprocess import run_subprocess

# BUGFIX: call the factory. The module previously bound the function object
# itself, so logger.error(...) in git_push() would raise AttributeError.
logger = get_logger()


def git_clone(
    local_dir: str,
    repo_url: str,
):
    """Clone ``repo_url`` via the git CLI, running inside ``local_dir``."""
    # TODO: use "git clone" or "git lfs clone" according to git version
    # TODO: print stderr when subprocess fails
    command = ['git', 'clone', repo_url]
    run_subprocess(command, local_dir, True)


def git_checkout(
    local_dir: str,
    revision: str,
):
    """Check out ``revision`` (branch/tag/commit) in the repo at ``local_dir``."""
    # BUGFIX: parameter renamed from the misspelled 'revsion'; existing
    # positional callers are unaffected.
    run_subprocess(f'git checkout {revision}'.split(), local_dir)


def git_add(local_dir: str, ):
    """Stage every change in the repo at ``local_dir``."""
    command = ['git', 'add', '.']
    run_subprocess(command, local_dir, True)


def git_commit(local_dir: str, commit_message: str):
    """Create a commit in ``local_dir`` carrying ``commit_message``."""
    command = 'git commit -v -m'.split()
    command.append(commit_message)
    run_subprocess(command, local_dir, True)


def git_push(local_dir: str, branch: str):
    """Push the checked-out branch of ``local_dir`` to ``origin/<branch>``.

    Refuses (with an error log) when the current branch differs from
    ``branch``, to avoid pushing the wrong work.
    """
    cur_branch = git_current_branch(local_dir)
    if cur_branch != branch:
        logger.error(
            "You're trying to push to a different branch, please double check")
        return

    command = ['git', 'push', 'origin', branch]
    run_subprocess(command, local_dir, True)


def git_current_branch(local_dir: str) -> Union[str, None]:
    """
    Get current branch name

    Args:
        local_dir(`str`): local model repo directory

    Returns:
        branch name you're currently on

    Raises:
        subprocess.CalledProcessError: when git fails (e.g. not a repo).
    """
    # NOTE: the previous try/except that only re-raised was removed — it had
    # no effect; run_subprocess already raises on failure (check=True).
    process = run_subprocess(
        'git rev-parse --abbrev-ref HEAD'.split(),
        local_dir,
        True,
    )
    # stdout is already decoded text (run_subprocess passes encoding='utf-8'),
    # so the former str() wrapper was redundant.
    return process.stdout.strip()

+ 173
- 0
modelscope/hub/repository.py View File

@@ -0,0 +1,173 @@
import os
import subprocess
from pathlib import Path
from typing import Optional, Union

from modelscope.utils.logger import get_logger
from .api import ModelScopeConfig
from .constants import MODELSCOPE_URL_SCHEME
from .git import git_add, git_checkout, git_clone, git_commit, git_push
from .utils._subprocess import run_subprocess
from .utils.utils import get_gitlab_domain

logger = get_logger()


class Repository:
    """A local working copy of a ModelScope hub model repo, backed by git."""

    def __init__(
        self,
        local_dir: str,
        clone_from: Optional[str] = None,
        auth_token: Optional[str] = None,
        private: Optional[bool] = False,
        revision: Optional[str] = 'master',
    ):
        """
        Instantiate a Repository object by cloning the remote ModelScopeHub repo
        Args:
            local_dir(`str`):
                local directory to store the model files
            clone_from(`Optional[str] = None`):
                model id in ModelScope-hub from which git clone
                You should ignore this parameter when `local_dir` is already a git repo
            auth_token(`Optional[str]`):
                token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter
                as the token is already saved when you login the first time
            private(`Optional[bool]`):
                whether the model is private, default to False
            revision(`Optional[str]`):
                revision of the model you want to clone from. Can be any of a branch, tag or commit hash
        """
        logger.info('Instantiating Repository object...')

        # Create local directory if not exist
        os.makedirs(local_dir, exist_ok=True)
        self.local_dir = os.path.join(os.getcwd(), local_dir)

        self.private = private

        # Check git and git-lfs installation
        self.check_git_versions()

        # Retrieve auth token
        if not private and isinstance(auth_token, str):
            logger.warning(
                'cloning a public repo with a token, which will be ignored')
            self.token = None
        else:
            if isinstance(auth_token, str):
                self.token = auth_token
            else:
                self.token = ModelScopeConfig.get_token()

            # BUGFIX: a missing token is only fatal for private repos; public
            # clones without a saved token used to raise here as well.
            if private and self.token is None:
                raise EnvironmentError(
                    'Token does not exist, the clone will fail for private repo.'
                    'Please login first.')

        # git clone
        if clone_from is not None:
            self.model_id = clone_from
            logger.info('cloning model repo to %s ...', self.local_dir)
            git_clone(self.local_dir, self.get_repo_url())
        else:
            if is_git_repo(self.local_dir):
                logger.debug('[Repository] is a valid git repo')
            else:
                raise ValueError(
                    'If not specifying `clone_from`, you need to pass Repository a'
                    ' valid git clone.')

        # git checkout
        if isinstance(revision, str) and revision != 'master':
            # BUGFIX: git_checkout requires the repo directory as its first
            # argument; it was previously called with the revision only.
            git_checkout(self.local_dir, revision)

    def push_to_hub(self,
                    commit_message: str,
                    revision: Optional[str] = 'master'):
        """
        Push changes changes to hub

        Args:
            commit_message(`str`):
                commit message describing the changes, it's mandatory
            revision(`Optional[str]`):
                remote branch you want to push to, default to `master`

        <Tip>
        The function complains when local and remote branch are different, please be careful
        </Tip>

        """
        git_add(self.local_dir)
        git_commit(self.local_dir, commit_message)

        logger.info('Pushing changes to repo...')
        git_push(self.local_dir, revision)

        # TODO: if git push fails, how to retry?

    def check_git_versions(self):
        """
        Checks that `git` and `git-lfs` can be run.

        Raises:
            `EnvironmentError`: if `git` or `git-lfs` are not installed.
        """
        try:
            git_version = run_subprocess('git --version'.split(),
                                         self.local_dir).stdout.strip()
        except FileNotFoundError:
            raise EnvironmentError(
                'Looks like you do not have git installed, please install.')

        try:
            lfs_version = run_subprocess('git-lfs --version'.split(),
                                         self.local_dir).stdout.strip()
        except FileNotFoundError:
            raise EnvironmentError(
                'Looks like you do not have git-lfs installed, please install.'
                ' You can install from https://git-lfs.github.com/.'
                ' Then run `git lfs install` (you only have to do this once).')
        logger.info(git_version + '\n' + lfs_version)

    def get_repo_url(self) -> str:
        """
        Get repo url to clone, according whether the repo is private or not
        """
        # Private repos embed the oauth token into the clone URL so git can
        # authenticate without prompting.
        if self.private:
            url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}'
        else:
            url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}'

        if not url:
            raise ValueError(
                'Empty repo url, please check clone_from parameter')

        logger.debug('url to clone: %s', str(url))

        return url


def is_git_repo(folder: Union[str, Path]) -> bool:
    """
    Check if the folder is the root or part of a git repository

    Args:
        folder (`str`):
            The folder in which to run the command.

    Returns:
        `bool`: `True` if the folder contains a `.git` entry and `git branch`
        succeeds inside it, `False` otherwise.
    """
    if not os.path.exists(os.path.join(folder, '.git')):
        # Short-circuit: without a .git entry the overall result was always
        # False anyway, and this avoids spawning git (which may not even be
        # installed) needlessly.
        return False
    git_branch = subprocess.run(
        ['git', 'branch'],
        cwd=folder,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    return git_branch.returncode == 0

+ 125
- 0
modelscope/hub/snapshot_download.py View File

@@ -0,0 +1,125 @@
import os
import tempfile
from glob import glob
from pathlib import Path
from typing import Dict, Optional, Union

from modelscope.utils.logger import get_logger
from .api import HubApi, ModelScopeConfig
from .constants import DEFAULT_MODELSCOPE_GROUP, MODEL_ID_SEPARATOR
from .errors import NotExistError, RequestError, raise_on_error
from .file_download import (get_file_download_url, http_get_file,
http_user_agent)
from .utils.caching import ModelFileSystemCache
from .utils.utils import get_cache_dir, model_id_to_group_owner_name

logger = get_logger()


def snapshot_download(model_id: str,
                      revision: Optional[str] = 'master',
                      cache_dir: Union[str, Path, None] = None,
                      user_agent: Optional[Union[Dict, str]] = None,
                      local_files_only: Optional[bool] = False,
                      private: Optional[bool] = False) -> str:
    """Download all files of a repo.
    Downloads a whole snapshot of a repo's files at the specified revision. This
    is useful when you want all files from a repo, because you don't know which
    ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder.

    An alternative would be to just clone a repo but this would require that the
    user always has git and git-lfs installed, and properly configured.
    Args:
        model_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. NOTE: currently only branch and tag name is supported
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        user_agent (`str`, `dict`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        private (`bool`, *optional*, defaults to `False`):
            If `True`, send saved login cookies so private repos can be read.
    Returns:
        Local folder path (string) of repo snapshot

    Raises:
        ValueError: if `local_files_only` is True and nothing is cached.
        NotExistError: if `revision` is neither a known branch nor tag.
    """

    if cache_dir is None:
        cache_dir = get_cache_dir()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    group_or_owner, name = model_id_to_group_owner_name(model_id)

    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
    if local_files_only:
        if len(cache.cached_files) == 0:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable model look-ups and downloads'
                " online, set 'local_files_only' to False.")
        # BUGFIX: logger.warn is deprecated; use logger.warning.
        # We cannot confirm the cached snapshot matches 'revision'.
        logger.warning('We can not confirm the cached file is for revision: %s'
                       % revision)
        return cache.get_root_location()
    else:
        # make headers
        headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
        _api = HubApi()
        # get file list from model repo
        branches, tags = _api.get_model_branches_and_tags(model_id)
        if revision not in branches and revision not in tags:
            raise NotExistError('The specified branch or tag : %s not exist!'
                                % revision)

        model_files = _api.get_model_files(
            model_id=model_id,
            revision=revision,
            recursive=True,
            use_cookies=private)

        cookies = None
        if private:
            cookies = ModelScopeConfig.get_cookies()

        for model_file in model_files:
            if model_file['Type'] == 'tree':
                continue
            # check model_file is exist in cache, if exist, skip download, otherwise download
            if cache.exists(model_file):
                logger.info(
                    'The specified file is in cache, skip downloading!')
                continue

            # get download url
            url = get_file_download_url(
                model_id=model_id,
                file_path=model_file['Path'],
                revision=revision)

            # First download to /tmp
            http_get_file(
                url=url,
                local_dir=tempfile.gettempdir(),
                file_name=model_file['Name'],
                headers=headers,
                cookies=None if cookies is None else cookies.get_dict())
            # put file to cache
            cache.put_file(
                model_file,
                os.path.join(tempfile.gettempdir(), model_file['Name']))

        # BUGFIX: os.path.join() with a single argument was a no-op.
        return cache.get_root_location()

+ 0
- 0
modelscope/hub/utils/__init__.py View File


+ 40
- 0
modelscope/hub/utils/_subprocess.py View File

@@ -0,0 +1,40 @@
import subprocess
from typing import List


def run_subprocess(command: List[str],
                   folder: str,
                   check=True,
                   **kwargs) -> subprocess.CompletedProcess:
    """
    Run *command* inside *folder*, capturing stdout and stderr as text.

    Call `subprocess.run` directly instead when you do not want the output
    streams captured.

    Args:
        command (`List[str]`):
            The command to execute as a list of strings.
        folder (`str`):
            The working directory for the command.
        check (`bool`, *optional*, defaults to `True`):
            When `True`, a non-zero exit code raises
            `subprocess.CalledProcessError`.
        kwargs (`Dict[str]`):
            Extra keyword arguments forwarded to `subprocess.run`.

    Returns:
        `subprocess.CompletedProcess`: The completed process.

    Raises:
        ValueError: when `command` is a string instead of a list.
    """
    # Guard: a bare string would be split character-by-character by exec.
    if isinstance(command, str):
        raise ValueError(
            '`run_subprocess` should be called with a list of strings.')

    return subprocess.run(
        command,
        capture_output=True,  # equivalent to stdout=PIPE, stderr=PIPE
        check=check,
        encoding='utf-8',
        cwd=folder,
        **kwargs,
    )

+ 294
- 0
modelscope/hub/utils/caching.py View File

@@ -0,0 +1,294 @@
import hashlib
import logging
import os
import pickle
import tempfile
import time
from shutil import move, rmtree

from modelscope.utils.logger import get_logger

logger = get_logger()


class FileSystemCache(object):
    """Local file cache with a pickled metadata index.

    The index (a list of cache keys) is stored as ``KEY_FILE_NAME`` inside the
    cache root. Subclasses decide what a key looks like and how files are
    stored and retrieved.
    """

    # Name of the pickled index file kept in the cache root.
    # BUGFIX: this assignment used to sit *above* the class docstring, which
    # turned the docstring into a plain (ignored) string expression.
    KEY_FILE_NAME = '.msc'

    def __init__(
        self,
        cache_root_location: str,
        **kwargs,
    ):
        """
        Args:
            cache_root_location (str): The root location to store files.
        """
        os.makedirs(cache_root_location, exist_ok=True)
        self.cache_root_location = cache_root_location
        self.load_cache()

    def get_root_location(self):
        """Return the root directory of this cache."""
        return self.cache_root_location

    def load_cache(self):
        """Load the list of cached-file keys from the index file, if present."""
        self.cached_files = []
        cache_keys_file_path = os.path.join(self.cache_root_location,
                                            FileSystemCache.KEY_FILE_NAME)
        if os.path.exists(cache_keys_file_path):
            # NOTE(review): the index is pickled — only load caches written by
            # this library; pickle is unsafe on untrusted data.
            with open(cache_keys_file_path, 'rb') as f:
                self.cached_files = pickle.load(f)

    def save_cached_files(self):
        """Persist the cache index atomically (temp file, then move)."""
        cache_keys_file_path = os.path.join(self.cache_root_location,
                                            FileSystemCache.KEY_FILE_NAME)
        # TODO: Sync file write
        fd, fn = tempfile.mkstemp()
        with open(fd, 'wb') as f:
            pickle.dump(self.cached_files, f)
        move(fn, cache_keys_file_path)

    def get_file(self, key):
        """Return the cached file location for ``key``, or None.

        Subclass hook; the base implementation does nothing.
        """
        pass

    def put_file(self, key, location):
        """Move the file at ``location`` into the cache under ``key``.

        Subclass hook; the base implementation does nothing.
        """
        pass

    def remove_key(self, key):
        """Remove ``key`` from the index (the file itself is removed by callers).

        Args:
            key (dict): The cache key.
        """
        self.cached_files.remove(key)
        self.save_cached_files()

    def exists(self, key):
        """Return True when ``key`` is present in the index."""
        return key in self.cached_files

    def clear_cache(self):
        """Remove all files and metadata from the cache, then reinitialize."""
        rmtree(self.cache_root_location)
        # BUGFIX: recreate the root so a later save_cached_files() does not
        # fail on the missing directory.
        os.makedirs(self.cache_root_location, exist_ok=True)
        self.load_cache()

    def hash_name(self, key):
        """Return the sha256 hex digest of ``key`` (a str)."""
        return hashlib.sha256(key.encode()).hexdigest()


class ModelFileSystemCache(FileSystemCache):
"""Local cache file layout
cache_root/owner/model_name/|individual cached files
|.mk: file, The cache index file
Save only one version for each file.
"""

def __init__(self, cache_root, owner, name):
    """Create a per-model cache rooted at ``cache_root/owner/name``.

    Args:
        cache_root(`str`): The modelscope local cache root(default: ~/.modelscope/cache/models/)
        owner(`str`): The model owner (individual or group name).
        name('str'): The name of the model
    <Tip>
    model_id = {owner}/{name}
    </Tip>
    """
    super().__init__(os.path.join(cache_root, owner, name))

def get_file_by_path(self, file_path):
"""Retrieve the cache if there is file match the path.
Args:
file_path (str): The file path in the model.
Returns:
path: the full path of the file.
"""
for cached_file in self.cached_files:
if file_path == cached_file['Path']:
cached_file_path = os.path.join(self.cache_root_location,
cached_file['Path'])
if os.path.exists(cached_file_path):
return cached_file_path
else:
self.remove_key(cached_file)

return None

def get_file_by_path_and_commit_id(self, file_path, commit_id):
"""Retrieve the cache if there is file match the path.
Args:
file_path (str): The file path in the model.
commit_id (str): The commit id of the file
Returns:
path: the full path of the file.
"""
for cached_file in self.cached_files:
if file_path == cached_file['Path'] and \
(cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])):
cached_file_path = os.path.join(self.cache_root_location,
cached_file['Path'])
if os.path.exists(cached_file_path):
return cached_file_path
else:
self.remove_key(cached_file)

return None

def get_file_by_info(self, model_file_info):
"""Check if exist cache file.

Args:
model_file_info (ModelFileInfo): The file information of the file.

Returns:
_type_: _description_
"""
cache_key = self.__get_cache_key(model_file_info)
for cached_file in self.cached_files:
if cached_file == cache_key:
orig_path = os.path.join(self.cache_root_location,
cached_file['Path'])
if os.path.exists(orig_path):
return orig_path
else:
self.remove_key(cached_file)

return None

def __get_cache_key(self, model_file_info):
cache_key = {
'Path': model_file_info['Path'],
'Revision': model_file_info['Revision'], # commit id
}
return cache_key

def exists(self, model_file_info):
"""Check the file is cached or not.

Args:
model_file_info (CachedFileInfo): The cached file info

Returns:
bool: If exists return True otherwise False
"""
key = self.__get_cache_key(model_file_info)
is_exists = False
for cached_key in self.cached_files:
if cached_key['Path'] == key['Path'] and (
cached_key['Revision'].startswith(key['Revision'])
or key['Revision'].startswith(cached_key['Revision'])):
is_exists = True
file_path = os.path.join(self.cache_root_location,
model_file_info['Path'])
if is_exists:
if os.path.exists(file_path):
return True
else:
self.remove_key(
model_file_info) # sameone may manual delete the file
return False

def remove_if_exists(self, model_file_info):
"""We in cache, remove it.

Args:
model_file_info (ModelFileInfo): The model file information from server.
"""
for cached_file in self.cached_files:
if cached_file['Path'] == model_file_info['Path']:
self.remove_key(cached_file)
file_path = os.path.join(self.cache_root_location,
cached_file['Path'])
if os.path.exists(file_path):
os.remove(file_path)

def put_file(self, model_file_info, model_file_location):
"""Put model on model_file_location to cache, the model first download to /tmp, and move to cache.

Args:
model_file_info (str): The file description returned by get_model_files
sample:
{
"CommitMessage": "add model\n",
"CommittedDate": 1654857567,
"CommitterName": "mulin.lyh",
"IsLFS": false,
"Mode": "100644",
"Name": "resnet18.pth",
"Path": "resnet18.pth",
"Revision": "09b68012b27de0048ba74003690a890af7aff192",
"Size": 46827520,
"Type": "blob"
}
model_file_location (str): The location of the temporary file.
Raises:
NotImplementedError: _description_

Returns:
str: The location of the cached file.
"""
self.remove_if_exists(model_file_info) # backup old revision
cache_key = self.__get_cache_key(model_file_info)
cache_full_path = os.path.join(
self.cache_root_location,
cache_key['Path']) # Branch and Tag do not have same name.
cache_file_dir = os.path.dirname(cache_full_path)
if not os.path.exists(cache_file_dir):
os.makedirs(cache_file_dir, exist_ok=True)
# We can't make operation transaction
move(model_file_location, cache_full_path)
self.cached_files.append(cache_key)
self.save_cached_files()
return cache_full_path

+ 39
- 0
modelscope/hub/utils/utils.py View File

@@ -0,0 +1,39 @@
import os

from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
DEFAULT_MODELSCOPE_GITLAB_DOMAIN,
DEFAULT_MODELSCOPE_GROUP,
MODEL_ID_SEPARATOR,
MODELSCOPE_URL_SCHEME)


def model_id_to_group_owner_name(model_id):
if MODEL_ID_SEPARATOR in model_id:
group_or_owner = model_id.split(MODEL_ID_SEPARATOR)[0]
name = model_id.split(MODEL_ID_SEPARATOR)[1]
else:
group_or_owner = DEFAULT_MODELSCOPE_GROUP
name = model_id
return group_or_owner, name


def get_cache_dir():
"""
cache dir precedence:
function parameter > enviroment > ~/.cache/modelscope/hub
"""
default_cache_dir = os.path.expanduser(
os.path.join('~/.cache', 'modelscope'))
return os.getenv('MODELSCOPE_CACHE', os.path.join(default_cache_dir,
'hub'))


def get_endpoint():
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN',
DEFAULT_MODELSCOPE_DOMAIN)
return MODELSCOPE_URL_SCHEME + modelscope_domain


def get_gitlab_domain():
return os.getenv('MODELSCOPE_GITLAB_DOMAIN',
DEFAULT_MODELSCOPE_GITLAB_DOMAIN)

+ 2
- 6
modelscope/models/base.py View File

@@ -4,12 +4,10 @@ import os.path as osp
from abc import ABC, abstractmethod
from typing import Dict, Union

from maas_hub.snapshot_download import snapshot_download

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.builder import build_model
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.hub import get_model_cache_dir

Tensor = Union['torch.Tensor', 'tf.Tensor']

@@ -47,9 +45,7 @@ class Model(ABC):
if osp.exists(model_name_or_path):
local_model_dir = model_name_or_path
else:
cache_path = get_model_cache_dir(model_name_or_path)
local_model_dir = cache_path if osp.exists(
cache_path) else snapshot_download(model_name_or_path)
local_model_dir = snapshot_download(model_name_or_path)
# else:
# raise ValueError(
# 'Remote model repo {model_name_or_path} does not exists')


+ 2
- 6
modelscope/pipelines/base.py View File

@@ -4,13 +4,11 @@ import os.path as osp
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Union

from maas_hub.snapshot_download import snapshot_download

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.base import Model
from modelscope.preprocessors import Preprocessor
from modelscope.pydatasets import PyDataset
from modelscope.utils.config import Config
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.logger import get_logger
from .outputs import TASK_OUTPUTS
from .util import is_model_name
@@ -32,9 +30,7 @@ class Pipeline(ABC):
# TODO @wenmeng.zwm replace model.startswith('damo/') with get_model
if isinstance(model, str) and model.startswith('damo/'):
if not osp.exists(model):
cache_path = get_model_cache_dir(model)
model = cache_path if osp.exists(
cache_path) else snapshot_download(model)
model = snapshot_download(model)
return Model.from_pretrained(model) if is_model_name(
model) else model
elif isinstance(model, Model):


+ 1
- 2
modelscope/pipelines/util.py View File

@@ -2,8 +2,7 @@
import os.path as osp
from typing import List, Union

from maas_hub.file_download import model_file_download

from modelscope.hub.file_download import model_file_download
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger


+ 2
- 5
modelscope/preprocessors/multi_model.py View File

@@ -4,11 +4,10 @@ from typing import Any, Dict, Union

import numpy as np
import torch
from maas_hub.snapshot_download import snapshot_download
from PIL import Image

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import Fields, ModelFile
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS
@@ -34,9 +33,7 @@ class OfaImageCaptionPreprocessor(Preprocessor):
if osp.exists(model_dir):
local_model_dir = model_dir
else:
cache_path = get_model_cache_dir(model_dir)
local_model_dir = cache_path if osp.exists(
cache_path) else snapshot_download(model_dir)
local_model_dir = snapshot_download(model_dir)
local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE)
bpe_dir = local_model_dir



+ 4
- 7
modelscope/utils/hub.py View File

@@ -2,13 +2,10 @@

import os

from maas_hub.constants import MODEL_ID_SEPARATOR
from modelscope.hub.constants import MODEL_ID_SEPARATOR
from modelscope.hub.utils.utils import get_cache_dir


# temp solution before the hub-cache is in place
def get_model_cache_dir(model_id: str, branch: str = 'master'):
model_id_expanded = model_id.replace('/',
MODEL_ID_SEPARATOR) + '.' + branch
default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
return os.getenv('MAAS_CACHE',
os.path.join(default_cache_dir, 'hub', model_id_expanded))
def get_model_cache_dir(model_id: str):
return os.path.join(get_cache_dir(), model_id)

+ 4
- 1
requirements/runtime.txt View File

@@ -1,13 +1,16 @@
addict
datasets
easydict
https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.4.dev0-py3-none-any.whl
filelock>=3.3.0
numpy
opencv-python-headless
Pillow>=6.2.0
pyyaml
requests
requests==2.27.1
scipy
setuptools==58.0.4
tokenizers<=0.10.3
tqdm>=4.64.0
transformers<=4.16.2
yapf

+ 0
- 0
tests/hub/__init__.py View File


+ 157
- 0
tests/hub/test_hub_operation.py View File

@@ -0,0 +1,157 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import subprocess
import tempfile
import unittest
import uuid

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.hub.utils.utils import get_gitlab_domain

USER_NAME = 'maasadmin'
PASSWORD = '12345678'

model_chinese_name = '达摩卡通化模型'
model_org = 'unittest'
DEFAULT_GIT_PATH = 'git'


class GitError(Exception):
pass


# TODO: move these git operations to the git library after the code is merged.
def run_git_command(git_path, *args) -> subprocess.CompletedProcess:
response = subprocess.run([git_path, *args], capture_output=True)
try:
response.check_returncode()
return response.stdout.decode('utf8')
except subprocess.CalledProcessError as error:
raise GitError(error.stderr.decode('utf8'))


# For a public project the token may be None; for a private repo a token is required.
def clone(local_dir: str, token: str, url: str):
url = url.replace('//', '//oauth2:%s@' % token)
clone_args = '-C %s clone %s' % (local_dir, url)
clone_args = clone_args.split(' ')
stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args)
print('stdout: %s' % stdout)


def push(local_dir: str, token: str, url: str):
url = url.replace('//', '//oauth2:%s@' % token)
push_args = '-C %s push %s' % (local_dir, url)
push_args = push_args.split(' ')
stdout = run_git_command(DEFAULT_GIT_PATH, *push_args)
print('stdout: %s' % stdout)


sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx'
download_model_file_name = 'mnist-12.onnx'


class HubOperationTest(unittest.TestCase):

def setUp(self):
self.old_cwd = os.getcwd()
self.api = HubApi()
# note this is temporary before official account management is ready
self.api.login(USER_NAME, PASSWORD)
self.model_name = uuid.uuid4().hex
self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model(
model_id=self.model_id,
chinese_name=model_chinese_name,
visibility=5, # 1-private, 5-public
license='apache-2.0')

def tearDown(self):
os.chdir(self.old_cwd)
self.api.delete_model(model_id=self.model_id)

def test_model_repo_creation(self):
# change to proper model names before use
try:
info = self.api.get_model(model_id=self.model_id)
assert info['Name'] == self.model_name
except KeyError as ke:
if ke.args[0] == 'name':
print(f'model {self.model_name} already exists, ignore')
else:
raise

# Note that this can be done via git operation once model repo
# has been created. Git-Op is the RECOMMENDED model upload approach
def test_model_upload(self):
url = f'http://{get_gitlab_domain()}/{self.model_id}'
print(url)
temporary_dir = tempfile.mkdtemp()
os.chdir(temporary_dir)
cmd_args = 'clone %s' % url
cmd_args = cmd_args.split(' ')
out = run_git_command('git', *cmd_args)
print(out)
repo_dir = os.path.join(temporary_dir, self.model_name)
os.chdir(repo_dir)
os.system('touch file1')
os.system('git add file1')
os.system("git commit -m 'Test'")
token = ModelScopeConfig.get_token()
push(repo_dir, token, url)

def test_download_single_file(self):
url = f'http://{get_gitlab_domain()}/{self.model_id}'
print(url)
temporary_dir = tempfile.mkdtemp()
os.chdir(temporary_dir)
os.system('git clone %s' % url)
repo_dir = os.path.join(temporary_dir, self.model_name)
os.chdir(repo_dir)
os.system('wget %s' % sample_model_url)
os.system('git add .')
os.system("git commit -m 'Add file'")
token = ModelScopeConfig.get_token()
push(repo_dir, token, url)
assert os.path.exists(
os.path.join(temporary_dir, self.model_name,
download_model_file_name))
downloaded_file = model_file_download(
model_id=self.model_id, file_path=download_model_file_name)
mdtime1 = os.path.getmtime(downloaded_file)
# download again
downloaded_file = model_file_download(
model_id=self.model_id, file_path=download_model_file_name)
mdtime2 = os.path.getmtime(downloaded_file)
assert mdtime1 == mdtime2

def test_snapshot_download(self):
url = f'http://{get_gitlab_domain()}/{self.model_id}'
print(url)
temporary_dir = tempfile.mkdtemp()
os.chdir(temporary_dir)
os.system('git clone %s' % url)
repo_dir = os.path.join(temporary_dir, self.model_name)
os.chdir(repo_dir)
os.system('wget %s' % sample_model_url)
os.system('git add .')
os.system("git commit -m 'Add file'")
token = ModelScopeConfig.get_token()
push(repo_dir, token, url)
snapshot_path = snapshot_download(model_id=self.model_id)
downloaded_file_path = os.path.join(snapshot_path,
download_model_file_name)
assert os.path.exists(downloaded_file_path)
mdtime1 = os.path.getmtime(downloaded_file_path)
# download again
snapshot_path = snapshot_download(model_id=self.model_id)
mdtime2 = os.path.getmtime(downloaded_file_path)
assert mdtime1 == mdtime2


if __name__ == '__main__':
unittest.main()

+ 0
- 6
tests/pipelines/test_image_matting.py View File

@@ -10,7 +10,6 @@ from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


@@ -18,11 +17,6 @@ class ImageMattingTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_unet_image-matting'
# switch to False if downloading everytime is not desired
purge_cache = True
if purge_cache:
shutil.rmtree(
get_model_cache_dir(self.model_id), ignore_errors=True)

@unittest.skip('deprecated, download model from model hub instead')
def test_run_with_direct_file_download(self):


+ 1
- 1
tests/pipelines/test_ocr_detection.py View File

@@ -27,7 +27,7 @@ class OCRDetectionTest(unittest.TestCase):
print('ocr detection results: ')
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
ocr_detection = pipeline(Tasks.ocr_detection)
self.pipeline_inference(ocr_detection, self.test_image)


+ 1
- 10
tests/pipelines/test_sentence_similarity.py View File

@@ -2,14 +2,12 @@
import shutil
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForSentenceSimilarity
from modelscope.pipelines import SentenceSimilarityPipeline, pipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


@@ -18,13 +16,6 @@ class SentenceSimilarityTest(unittest.TestCase):
sentence1 = '今天气温比昨天高么?'
sentence2 = '今天湿度比昨天高么?'

def setUp(self) -> None:
# switch to False if downloading everytime is not desired
purge_cache = True
if purge_cache:
shutil.rmtree(
get_model_cache_dir(self.model_id), ignore_errors=True)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run(self):
cache_path = snapshot_download(self.model_id)


+ 0
- 6
tests/pipelines/test_speech_signal_process.py View File

@@ -5,7 +5,6 @@ import unittest
from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir

NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav'
FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav'
@@ -30,11 +29,6 @@ class SpeechSignalProcessTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/speech_dfsmn_aec_psm_16k'
# switch to False if downloading everytime is not desired
purge_cache = True
if purge_cache:
shutil.rmtree(
get_model_cache_dir(self.model_id), ignore_errors=True)
# A temporary hack to provide c++ lib. Download it first.
download(AEC_LIB_URL, AEC_LIB_FILE)



+ 0
- 6
tests/pipelines/test_text_classification.py View File

@@ -11,7 +11,6 @@ from modelscope.pipelines import SequenceClassificationPipeline, pipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import Hubs, Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


@@ -19,11 +18,6 @@ class SequenceClassificationTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/bert-base-sst2'
# switch to False if downloading everytime is not desired
purge_cache = True
if purge_cache:
shutil.rmtree(
get_model_cache_dir(self.model_id), ignore_errors=True)

def predict(self, pipeline_ins: SequenceClassificationPipeline):
from easynlp.appzoo import load_dataset


+ 1
- 2
tests/pipelines/test_text_generation.py View File

@@ -1,8 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import PalmForTextGeneration
from modelscope.pipelines import TextGenerationPipeline, pipeline


+ 1
- 10
tests/pipelines/test_word_segmentation.py View File

@@ -2,14 +2,12 @@
import shutil
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.pipelines import WordSegmentationPipeline, pipeline
from modelscope.preprocessors import TokenClassifcationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


@@ -17,13 +15,6 @@ class WordSegmentationTest(unittest.TestCase):
model_id = 'damo/nlp_structbert_word-segmentation_chinese-base'
sentence = '今天天气不错,适合出去游玩'

def setUp(self) -> None:
# switch to False if downloading everytime is not desired
purge_cache = True
if purge_cache:
shutil.rmtree(
get_model_cache_dir(self.model_id), ignore_errors=True)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)


+ 1
- 1
tests/run.py View File

@@ -61,7 +61,7 @@ if __name__ == '__main__':
parser.add_argument(
'--test_dir', default='tests', help='directory to be tested')
parser.add_argument(
'--level', default=0, help='2 -- all, 1 -- p1, 0 -- p0')
'--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0')
args = parser.parse_args()
set_test_level(args.level)
logger.info(f'TEST LEVEL: {test_level()}')


+ 0
- 50
tests/utils/test_hub_operation.py View File

@@ -1,50 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

from maas_hub.maas_api import MaasApi
from maas_hub.repository import Repository

USER_NAME = 'maasadmin'
PASSWORD = '12345678'


class HubOperationTest(unittest.TestCase):

def setUp(self):
self.api = MaasApi()
# note this is temporary before official account management is ready
self.api.login(USER_NAME, PASSWORD)

@unittest.skip('to be used for local test only')
def test_model_repo_creation(self):
# change to proper model names before use
model_name = 'cv_unet_person-image-cartoon_compound-models'
model_chinese_name = '达摩卡通化模型'
model_org = 'damo'
try:
self.api.create_model(
owner=model_org,
name=model_name,
chinese_name=model_chinese_name,
visibility=5, # 1-private, 5-public
license='apache-2.0')
# TODO: support proper name duplication checking
except KeyError as ke:
if ke.args[0] == 'name':
print(f'model {self.model_name} already exists, ignore')
else:
raise

# Note that this can be done via git operation once model repo
# has been created. Git-Op is the RECOMMENDED model upload approach
@unittest.skip('to be used for local test only')
def test_model_upload(self):
local_path = '/path/to/local/model/directory'
assert osp.exists(local_path), 'Local model directory not exist.'
repo = Repository(local_dir=local_path)
repo.push_to_hub(commit_message='Upload model files')


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save