@@ -0,0 +1,265 @@ | |||
import os | |||
import pickle | |||
import subprocess | |||
from http.cookiejar import CookieJar | |||
from os.path import expanduser | |||
from typing import List, Optional, Tuple, Union | |||
import requests | |||
from modelscope.utils.logger import get_logger | |||
from .constants import LOGGER_NAME | |||
from .errors import NotExistError, is_ok, raise_on_error | |||
from .utils.utils import get_endpoint, model_id_to_group_owner_name | |||
logger = get_logger() | |||
class HubApi: | |||
def __init__(self, endpoint=None): | |||
self.endpoint = endpoint if endpoint is not None else get_endpoint() | |||
def login( | |||
self, | |||
user_name: str, | |||
password: str, | |||
    ) -> Tuple[str, CookieJar]:
""" | |||
        Log in with username and password.
        Args:
            user_name(`str`): user name on ModelScope
            password(`str`): password
        Returns:
            cookies: to authenticate yourself to ModelScope open-api
            gitlab token: to access private repos
        <Tip>
        You only need to log in once every 30 days.
        </Tip>
        TODO: handle cookie expiration
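        Example (illustrative; the credentials below are placeholders):
            >>> api = HubApi()
            >>> token, cookies = api.login('my_user', 'my_password')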
""" | |||
path = f'{self.endpoint}/api/v1/login' | |||
r = requests.post( | |||
path, json={ | |||
'username': user_name, | |||
'password': password | |||
}) | |||
r.raise_for_status() | |||
d = r.json() | |||
raise_on_error(d) | |||
token = d['Data']['AccessToken'] | |||
cookies = r.cookies | |||
# save token and cookie | |||
ModelScopeConfig.save_token(token) | |||
ModelScopeConfig.save_cookies(cookies) | |||
ModelScopeConfig.write_to_git_credential(user_name, password) | |||
        return token, cookies
def create_model(self, model_id: str, chinese_name: str, visibility: int, | |||
license: str) -> str: | |||
""" | |||
Create model repo at ModelScopeHub | |||
Args: | |||
            model_id(`str`): The model id
            chinese_name(`str`): Chinese name of the model
visibility(`int`): visibility of the model(1-private, 3-internal, 5-public) | |||
license(`str`): license of the model, candidates can be found at: TBA | |||
Returns: | |||
name of the model created | |||
<Tip> | |||
model_id = {owner}/{name} | |||
</Tip> | |||
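        Example (illustrative; owner, names and credentials are placeholders):
            >>> api = HubApi()
            >>> api.login('my_user', 'my_password')
            >>> api.create_model('my_owner/my_model', 'my_model_cn',
            ...                  visibility=5, license='apache-2.0')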
""" | |||
cookies = ModelScopeConfig.get_cookies() | |||
if cookies is None: | |||
raise ValueError('Token does not exist, please login first.') | |||
path = f'{self.endpoint}/api/v1/models' | |||
owner_or_group, name = model_id_to_group_owner_name(model_id) | |||
r = requests.post( | |||
path, | |||
json={ | |||
'Path': owner_or_group, | |||
'Name': name, | |||
'ChineseName': chinese_name, | |||
'Visibility': visibility, | |||
'License': license | |||
}, | |||
cookies=cookies) | |||
r.raise_for_status() | |||
raise_on_error(r.json()) | |||
d = r.json() | |||
return d['Data']['Name'] | |||
def delete_model(self, model_id): | |||
"""_summary_ | |||
Args: | |||
model_id (str): The model id. | |||
<Tip> | |||
model_id = {owner}/{name} | |||
</Tip> | |||
""" | |||
        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')
path = f'{self.endpoint}/api/v1/models/{model_id}' | |||
r = requests.delete(path, cookies=cookies) | |||
r.raise_for_status() | |||
raise_on_error(r.json()) | |||
def get_model_url(self, model_id): | |||
return f'{self.endpoint}/api/v1/models/{model_id}.git' | |||
def get_model( | |||
self, | |||
model_id: str, | |||
revision: str = 'master', | |||
) -> str: | |||
""" | |||
        Get model information from the ModelScope Hub
Args: | |||
model_id(`str`): The model id. | |||
revision(`str`): revision of model | |||
Returns: | |||
The model details information. | |||
        Raises:
            NotExistError: If the model does not exist.
<Tip> | |||
model_id = {owner}/{name} | |||
</Tip> | |||
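        Example (illustrative; the model id is a placeholder):
            >>> api = HubApi()
            >>> info = api.get_model('damo/my_model')
            >>> print(info['Name'])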
""" | |||
cookies = ModelScopeConfig.get_cookies() | |||
owner_or_group, name = model_id_to_group_owner_name(model_id) | |||
        path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'
r = requests.get(path, cookies=cookies) | |||
if r.status_code == 200: | |||
if is_ok(r.json()): | |||
return r.json()['Data'] | |||
else: | |||
raise NotExistError(r.json()['Message']) | |||
else: | |||
r.raise_for_status() | |||
def get_model_branches_and_tags( | |||
self, | |||
model_id: str, | |||
) -> Tuple[List[str], List[str]]: | |||
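        """Get all branches and tags of the given model repo.
        Args:
            model_id(`str`): The model id ({owner}/{name}).
        Returns:
            A tuple of (branch name list, tag name list).
        """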
cookies = ModelScopeConfig.get_cookies() | |||
path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' | |||
r = requests.get(path, cookies=cookies) | |||
r.raise_for_status() | |||
d = r.json() | |||
raise_on_error(d) | |||
info = d['Data'] | |||
branches = [x['Revision'] for x in info['RevisionMap']['Branches'] | |||
] if info['RevisionMap']['Branches'] else [] | |||
tags = [x['Revision'] for x in info['RevisionMap']['Tags'] | |||
] if info['RevisionMap']['Tags'] else [] | |||
return branches, tags | |||
def get_model_files( | |||
self, | |||
model_id: str, | |||
revision: Optional[str] = 'master', | |||
root: Optional[str] = None, | |||
            recursive: Optional[bool] = False,
use_cookies: Union[bool, CookieJar] = False) -> List[dict]: | |||
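        """List the files of the given model repo.
        Args:
            model_id(`str`): The model id ({owner}/{name}).
            revision(`Optional[str]`): Branch, tag or commit hash, defaults to `master`.
            root(`Optional[str]`): Directory to list, defaults to the repo root.
            recursive(`Optional[bool]`): Whether to list files recursively.
            use_cookies(`Union[bool, CookieJar]`): A CookieJar to use directly, or
                `True` to load the saved cookies (required for private repos).
        Returns:
            A list of file info dicts, excluding `.gitignore` and `.gitattributes`.
        """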
cookies = None | |||
if isinstance(use_cookies, CookieJar): | |||
cookies = use_cookies | |||
elif use_cookies: | |||
cookies = ModelScopeConfig.get_cookies() | |||
if cookies is None: | |||
raise ValueError('Token does not exist, please login first.') | |||
path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}' | |||
if root is not None: | |||
path = path + f'&Root={root}' | |||
r = requests.get(path, cookies=cookies) | |||
r.raise_for_status() | |||
d = r.json() | |||
raise_on_error(d) | |||
files = [] | |||
for file in d['Data']['Files']: | |||
if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes': | |||
continue | |||
files.append(file) | |||
return files | |||
class ModelScopeConfig: | |||
path_credential = expanduser('~/.modelscope/credentials') | |||
os.makedirs(path_credential, exist_ok=True) | |||
@classmethod | |||
def save_cookies(cls, cookies: CookieJar): | |||
with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f: | |||
pickle.dump(cookies, f) | |||
@classmethod | |||
def get_cookies(cls): | |||
try: | |||
with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f: | |||
return pickle.load(f) | |||
except FileNotFoundError: | |||
logger.warn("Auth token does not exist, you'll get authentication \ | |||
error when downloading private model files. Please login first" | |||
) | |||
@classmethod | |||
def save_token(cls, token: str): | |||
with open(os.path.join(cls.path_credential, 'token'), 'w+') as f: | |||
f.write(token) | |||
@classmethod | |||
def get_token(cls) -> Optional[str]: | |||
""" | |||
Get token or None if not existent. | |||
Returns: | |||
`str` or `None`: The token, `None` if it doesn't exist. | |||
""" | |||
token = None | |||
try: | |||
with open(os.path.join(cls.path_credential, 'token'), 'r') as f: | |||
token = f.read() | |||
except FileNotFoundError: | |||
pass | |||
return token | |||
@staticmethod | |||
def write_to_git_credential(username: str, password: str): | |||
with subprocess.Popen( | |||
'git credential-store store'.split(), | |||
stdin=subprocess.PIPE, | |||
stdout=subprocess.PIPE, | |||
stderr=subprocess.STDOUT, | |||
) as process: | |||
input_username = f'username={username.lower()}' | |||
input_password = f'password={password}' | |||
process.stdin.write( | |||
f'url={get_endpoint()}\n{input_username}\n{input_password}\n\n' | |||
.encode('utf-8')) | |||
process.stdin.flush() |
@@ -0,0 +1,8 @@ | |||
MODELSCOPE_URL_SCHEME = 'http://' | |||
DEFAULT_MODELSCOPE_DOMAIN = '101.201.119.157:32330' | |||
DEFAULT_MODELSCOPE_GITLAB_DOMAIN = '101.201.119.157:31102' | |||
DEFAULT_MODELSCOPE_GROUP = 'damo' | |||
MODEL_ID_SEPARATOR = '/' | |||
LOGGER_NAME = 'ModelScopeHub' |
@@ -0,0 +1,30 @@ | |||
class NotExistError(Exception): | |||
pass | |||
class RequestError(Exception): | |||
pass | |||
def is_ok(rsp): | |||
""" Check the request is ok | |||
Args: | |||
rsp (_type_): The request response body | |||
Failed: {'Code': 10010101004, 'Message': 'get model info failed, err: unauthorized permission', | |||
'RequestId': '', 'Success': False} | |||
Success: {'Code': 200, 'Data': {}, 'Message': 'success', 'RequestId': '', 'Success': True} | |||
""" | |||
return rsp['Code'] == 200 and rsp['Success'] | |||
def raise_on_error(rsp): | |||
"""If response error, raise exception | |||
Args: | |||
rsp (_type_): The server response | |||
""" | |||
if rsp['Code'] == 200 and rsp['Success']: | |||
return True | |||
else: | |||
raise RequestError(rsp['Message']) |
@@ -0,0 +1,254 @@ | |||
import copy
import os
import sys
import tempfile
from functools import partial
from pathlib import Path
from typing import Dict, Optional, Union
from uuid import uuid4
import requests
from filelock import FileLock
from tqdm import tqdm
from modelscope import __version__ | |||
from modelscope.utils.logger import get_logger | |||
from .api import HubApi, ModelScopeConfig | |||
from .constants import (DEFAULT_MODELSCOPE_GROUP, LOGGER_NAME, | |||
MODEL_ID_SEPARATOR) | |||
from .errors import NotExistError, RequestError, raise_on_error | |||
from .utils.caching import ModelFileSystemCache | |||
from .utils.utils import (get_cache_dir, get_endpoint, | |||
model_id_to_group_owner_name) | |||
SESSION_ID = uuid4().hex | |||
logger = get_logger() | |||
def model_file_download( | |||
model_id: str, | |||
file_path: str, | |||
revision: Optional[str] = 'master', | |||
cache_dir: Optional[str] = None, | |||
user_agent: Union[Dict, str, None] = None, | |||
local_files_only: Optional[bool] = False, | |||
) -> Optional[str]: # pragma: no cover | |||
""" | |||
Download from a given URL and cache it if it's not already present in the | |||
local cache. | |||
Given a URL, this function looks for the corresponding file in the local | |||
cache. If it's not there, download it. Then return the path to the cached | |||
file. | |||
Args: | |||
model_id (`str`): | |||
            The model to which the file to be downloaded belongs.
file_path(`str`): | |||
Path of the file to be downloaded, relative to the root of model repo | |||
revision(`str`, *optional*): | |||
revision of the model file to be downloaded. | |||
Can be any of a branch, tag or commit hash, default to `master` | |||
cache_dir (`str`, `Path`, *optional*): | |||
Path to the folder where cached files are stored. | |||
user_agent (`dict`, `str`, *optional*): | |||
The user-agent info in the form of a dictionary or a string. | |||
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
            If `False`, the file is downloaded unless an up-to-date copy is cached.
Returns: | |||
Local path (string) of file or if networking is off, last version of | |||
file cached on disk. | |||
<Tip> | |||
Raises the following errors: | |||
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) | |||
if `use_auth_token=True` and the token cannot be found. | |||
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) | |||
if ETag cannot be determined. | |||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) | |||
if some parameter value is invalid | |||
</Tip> | |||
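    Example (illustrative; the model id and file path are placeholders):
        >>> from modelscope.hub.file_download import model_file_download
        >>> local_path = model_file_download('damo/my_model', 'README.md')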
""" | |||
if cache_dir is None: | |||
cache_dir = get_cache_dir() | |||
if isinstance(cache_dir, Path): | |||
cache_dir = str(cache_dir) | |||
group_or_owner, name = model_id_to_group_owner_name(model_id) | |||
cache = ModelFileSystemCache(cache_dir, group_or_owner, name) | |||
# if local_files_only is `True` and the file already exists in cached_path | |||
# return the cached path | |||
if local_files_only: | |||
cached_file_path = cache.get_file_by_path(file_path) | |||
if cached_file_path is not None: | |||
logger.warning( | |||
"File exists in local cache, but we're not sure it's up to date" | |||
) | |||
return cached_file_path | |||
else: | |||
raise ValueError( | |||
'Cannot find the requested files in the cached path and outgoing' | |||
' traffic has been disabled. To enable model look-ups and downloads' | |||
" online, set 'local_files_only' to False.") | |||
_api = HubApi() | |||
headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||
branches, tags = _api.get_model_branches_and_tags(model_id) | |||
file_to_download_info = None | |||
is_commit_id = False | |||
    if revision in branches or revision in tags:
        # The revision is a branch or tag name, so we need to confirm the
        # version is up to date: get the file list and check whether the
        # latest version is cached; if so return it, otherwise download.
model_files = _api.get_model_files( | |||
model_id=model_id, | |||
revision=revision, | |||
recursive=True, | |||
) | |||
for model_file in model_files: | |||
if model_file['Type'] == 'tree': | |||
continue | |||
if model_file['Path'] == file_path: | |||
model_file['Branch'] = revision | |||
if cache.exists(model_file): | |||
return cache.get_file_by_info(model_file) | |||
else: | |||
file_to_download_info = model_file | |||
if file_to_download_info is None: | |||
            raise NotExistError('The file path: %s does not exist in: %s' %
                                (file_path, model_id))
else: # the revision is commit id. | |||
cached_file_path = cache.get_file_by_path_and_commit_id( | |||
file_path, revision) | |||
if cached_file_path is not None: | |||
logger.info('The specified file is in cache, skip downloading!') | |||
return cached_file_path # the file is in cache. | |||
is_commit_id = True | |||
# we need to download again | |||
# TODO: skip using JWT for authorization, use cookie instead | |||
cookies = ModelScopeConfig.get_cookies() | |||
url_to_download = get_file_download_url(model_id, file_path, revision) | |||
file_to_download_info = { | |||
'Path': file_path, | |||
'Revision': | |||
revision if is_commit_id else file_to_download_info['Revision'] | |||
} | |||
# Prevent parallel downloads of the same file with a lock. | |||
lock_path = cache.get_root_location() + '.lock' | |||
with FileLock(lock_path): | |||
        temp_file_name = uuid4().hex  # random name; avoids the private tempfile._get_candidate_names()
http_get_file( | |||
url_to_download, | |||
cache_dir, | |||
temp_file_name, | |||
headers=headers, | |||
cookies=None if cookies is None else cookies.get_dict()) | |||
return cache.put_file(file_to_download_info, | |||
os.path.join(cache_dir, temp_file_name)) | |||
def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | |||
"""Formats a user-agent string with basic info about a request. | |||
Args: | |||
user_agent (`str`, `dict`, *optional*): | |||
The user agent info in the form of a dictionary or a single string. | |||
Returns: | |||
The formatted user-agent string. | |||
""" | |||
ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}' | |||
if isinstance(user_agent, dict): | |||
ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) | |||
elif isinstance(user_agent, str): | |||
ua = user_agent | |||
return ua | |||
def get_file_download_url(model_id: str, file_path: str, revision: str): | |||
""" | |||
Format file download url according to `model_id`, `revision` and `file_path`. | |||
    e.g., given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
    the resulting download url is: https://maas.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
""" | |||
download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}' | |||
return download_url_template.format( | |||
endpoint=get_endpoint(), | |||
model_id=model_id, | |||
revision=revision, | |||
file_path=file_path, | |||
) | |||
def http_get_file( | |||
url: str, | |||
local_dir: str, | |||
file_name: str, | |||
        cookies: Optional[Dict[str, str]],
headers: Optional[Dict[str, str]] = None, | |||
): | |||
""" | |||
    Download a remote file. Do not swallow errors.
    This method is used by both model_file_download and snapshot_download.
    TODO: consolidate the duplicated download logic to avoid duplicate code
Args: | |||
url(`str`): | |||
actual download url of the file | |||
local_dir(`str`): | |||
local directory where the downloaded file stores | |||
file_name(`str`): | |||
name of the file stored in `local_dir` | |||
        cookies(`Optional[Dict[str, str]]`):
            cookies used to authenticate the user, required for downloading private repos
headers(`Optional[Dict[str, str]] = None`): | |||
http headers to carry necessary info when requesting the remote file | |||
""" | |||
temp_file_manager = partial( | |||
tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False) | |||
with temp_file_manager() as temp_file: | |||
logger.info('downloading %s to %s', url, temp_file.name) | |||
headers = copy.deepcopy(headers) | |||
r = requests.get(url, stream=True, headers=headers, cookies=cookies) | |||
r.raise_for_status() | |||
content_length = r.headers.get('Content-Length') | |||
total = int(content_length) if content_length is not None else None | |||
progress = tqdm( | |||
unit='B', | |||
unit_scale=True, | |||
unit_divisor=1024, | |||
total=total, | |||
initial=0, | |||
desc='Downloading', | |||
) | |||
for chunk in r.iter_content(chunk_size=1024): | |||
if chunk: # filter out keep-alive new chunks | |||
progress.update(len(chunk)) | |||
temp_file.write(chunk) | |||
progress.close() | |||
logger.info('storing %s in cache at %s', url, local_dir) | |||
os.replace(temp_file.name, os.path.join(local_dir, file_name)) |
@@ -0,0 +1,82 @@ | |||
from typing import Union
from modelscope.utils.logger import get_logger
from .constants import LOGGER_NAME
from .utils._subprocess import run_subprocess
logger = get_logger()
def git_clone( | |||
local_dir: str, | |||
repo_url: str, | |||
): | |||
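    """Clone the repo at `repo_url` into `local_dir`; a thin wrapper over `git clone`."""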
# TODO: use "git clone" or "git lfs clone" according to git version | |||
# TODO: print stderr when subprocess fails | |||
run_subprocess( | |||
f'git clone {repo_url}'.split(), | |||
local_dir, | |||
True, | |||
) | |||
def git_checkout(
        local_dir: str,
        revision: str,
):
    run_subprocess(f'git checkout {revision}'.split(), local_dir)
def git_add(local_dir: str):
run_subprocess( | |||
'git add .'.split(), | |||
local_dir, | |||
True, | |||
) | |||
def git_commit(local_dir: str, commit_message: str): | |||
run_subprocess( | |||
'git commit -v -m'.split() + [commit_message], | |||
local_dir, | |||
True, | |||
) | |||
def git_push(local_dir: str, branch: str): | |||
# check current branch | |||
cur_branch = git_current_branch(local_dir) | |||
if cur_branch != branch: | |||
logger.error( | |||
"You're trying to push to a different branch, please double check") | |||
return | |||
run_subprocess( | |||
f'git push origin {branch}'.split(), | |||
local_dir, | |||
True, | |||
) | |||
def git_current_branch(local_dir: str) -> Union[str, None]:
    """
    Get the current branch name
    Args:
        local_dir(`str`): local model repo directory
    Returns:
        branch name you're currently on
    """
    process = run_subprocess(
        'git rev-parse --abbrev-ref HEAD'.split(),
        local_dir,
        True,
    )
    return str(process.stdout).strip()
@@ -0,0 +1,173 @@ | |||
import os | |||
import subprocess | |||
from pathlib import Path | |||
from typing import Optional, Union | |||
from modelscope.utils.logger import get_logger | |||
from .api import ModelScopeConfig | |||
from .constants import MODELSCOPE_URL_SCHEME | |||
from .git import git_add, git_checkout, git_clone, git_commit, git_push | |||
from .utils._subprocess import run_subprocess | |||
from .utils.utils import get_gitlab_domain | |||
logger = get_logger() | |||
class Repository: | |||
def __init__( | |||
self, | |||
local_dir: str, | |||
clone_from: Optional[str] = None, | |||
auth_token: Optional[str] = None, | |||
private: Optional[bool] = False, | |||
revision: Optional[str] = 'master', | |||
): | |||
""" | |||
Instantiate a Repository object by cloning the remote ModelScopeHub repo | |||
Args: | |||
local_dir(`str`): | |||
local directory to store the model files | |||
clone_from(`Optional[str] = None`): | |||
model id in ModelScope-hub from which git clone | |||
You should ignore this parameter when `local_dir` is already a git repo | |||
auth_token(`Optional[str]`): | |||
                token obtained when calling `HubApi.login()`. Usually you can safely ignore this
                parameter, as the token is saved when you first log in
private(`Optional[bool]`): | |||
whether the model is private, default to False | |||
revision(`Optional[str]`): | |||
revision of the model you want to clone from. Can be any of a branch, tag or commit hash | |||
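        Example (illustrative; the directory and model id are placeholders):
            >>> repo = Repository('my_model_dir', clone_from='my_owner/my_model')
            >>> repo.push_to_hub(commit_message='add model files')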
""" | |||
logger.info('Instantiating Repository object...') | |||
# Create local directory if not exist | |||
os.makedirs(local_dir, exist_ok=True) | |||
self.local_dir = os.path.join(os.getcwd(), local_dir) | |||
self.private = private | |||
# Check git and git-lfs installation | |||
self.check_git_versions() | |||
# Retrieve auth token | |||
if not private and isinstance(auth_token, str): | |||
logger.warning( | |||
'cloning a public repo with a token, which will be ignored') | |||
self.token = None | |||
else: | |||
if isinstance(auth_token, str): | |||
self.token = auth_token | |||
else: | |||
self.token = ModelScopeConfig.get_token() | |||
            if self.token is None and self.private:
                raise EnvironmentError(
                    'Token does not exist, the clone will fail for private repo.'
                    ' Please login first.')
# git clone | |||
if clone_from is not None: | |||
self.model_id = clone_from | |||
logger.info('cloning model repo to %s ...', self.local_dir) | |||
git_clone(self.local_dir, self.get_repo_url()) | |||
else: | |||
if is_git_repo(self.local_dir): | |||
logger.debug('[Repository] is a valid git repo') | |||
else: | |||
raise ValueError( | |||
'If not specifying `clone_from`, you need to pass Repository a' | |||
' valid git clone.') | |||
# git checkout | |||
        if isinstance(revision, str) and revision != 'master':
            git_checkout(self.local_dir, revision)
def push_to_hub(self, | |||
commit_message: str, | |||
revision: Optional[str] = 'master'): | |||
""" | |||
        Push local changes to the hub
Args: | |||
commit_message(`str`): | |||
commit message describing the changes, it's mandatory | |||
revision(`Optional[str]`): | |||
remote branch you want to push to, default to `master` | |||
        <Tip>
        The function refuses to push when the local and target branches differ,
        so double-check the branch you are on.
        </Tip>
""" | |||
git_add(self.local_dir) | |||
git_commit(self.local_dir, commit_message) | |||
logger.info('Pushing changes to repo...') | |||
git_push(self.local_dir, revision) | |||
# TODO: if git push fails, how to retry? | |||
def check_git_versions(self): | |||
""" | |||
Checks that `git` and `git-lfs` can be run. | |||
Raises: | |||
`EnvironmentError`: if `git` or `git-lfs` are not installed. | |||
""" | |||
try: | |||
git_version = run_subprocess('git --version'.split(), | |||
self.local_dir).stdout.strip() | |||
except FileNotFoundError: | |||
raise EnvironmentError( | |||
'Looks like you do not have git installed, please install.') | |||
try: | |||
lfs_version = run_subprocess('git-lfs --version'.split(), | |||
self.local_dir).stdout.strip() | |||
except FileNotFoundError: | |||
raise EnvironmentError( | |||
'Looks like you do not have git-lfs installed, please install.' | |||
' You can install from https://git-lfs.github.com/.' | |||
' Then run `git lfs install` (you only have to do this once).') | |||
logger.info(git_version + '\n' + lfs_version) | |||
def get_repo_url(self) -> str: | |||
""" | |||
        Get the repo url to clone, according to whether the repo is private or not
""" | |||
url = None | |||
if self.private: | |||
url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}' | |||
else: | |||
url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}' | |||
if not url: | |||
raise ValueError( | |||
'Empty repo url, please check clone_from parameter') | |||
logger.debug('url to clone: %s', str(url)) | |||
return url | |||
def is_git_repo(folder: Union[str, Path]) -> bool: | |||
""" | |||
Check if the folder is the root or part of a git repository | |||
Args: | |||
folder (`str`): | |||
The folder in which to run the command. | |||
    Returns:
        `bool`: `True` if the folder is part of a git repository, `False`
        otherwise.
""" | |||
folder_exists = os.path.exists(os.path.join(folder, '.git')) | |||
git_branch = subprocess.run( | |||
'git branch'.split(), | |||
cwd=folder, | |||
stdout=subprocess.PIPE, | |||
stderr=subprocess.PIPE) | |||
return folder_exists and git_branch.returncode == 0 |
@@ -0,0 +1,125 @@ | |||
import os | |||
import tempfile | |||
from pathlib import Path | |||
from typing import Dict, Optional, Union | |||
from modelscope.utils.logger import get_logger | |||
from .api import HubApi, ModelScopeConfig | |||
from .constants import DEFAULT_MODELSCOPE_GROUP, MODEL_ID_SEPARATOR | |||
from .errors import NotExistError, RequestError, raise_on_error | |||
from .file_download import (get_file_download_url, http_get_file, | |||
http_user_agent) | |||
from .utils.caching import ModelFileSystemCache | |||
from .utils.utils import get_cache_dir, model_id_to_group_owner_name | |||
logger = get_logger() | |||
def snapshot_download(model_id: str, | |||
revision: Optional[str] = 'master', | |||
cache_dir: Union[str, Path, None] = None, | |||
user_agent: Optional[Union[Dict, str]] = None, | |||
local_files_only: Optional[bool] = False, | |||
private: Optional[bool] = False) -> str: | |||
"""Download all files of a repo. | |||
Downloads a whole snapshot of a repo's files at the specified revision. This | |||
is useful when you want all files from a repo, because you don't know which | |||
ones you will need a priori. All files are nested inside a folder in order | |||
to keep their actual filename relative to that folder. | |||
An alternative would be to just clone a repo but this would require that the | |||
user always has git and git-lfs installed, and properly configured. | |||
Args: | |||
model_id (`str`): | |||
A user or an organization name and a repo name separated by a `/`. | |||
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. NOTE: currently only branch and tag names are supported
cache_dir (`str`, `Path`, *optional*): | |||
Path to the folder where cached files are stored. | |||
user_agent (`str`, `dict`, *optional*): | |||
The user-agent info in the form of a dictionary or a string. | |||
local_files_only (`bool`, *optional*, defaults to `False`): | |||
If `True`, avoid downloading the file and return the path to the | |||
local cached file if it exists. | |||
Returns: | |||
Local folder path (string) of repo snapshot | |||
<Tip> | |||
Raises the following errors: | |||
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) | |||
if `use_auth_token=True` and the token cannot be found. | |||
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if | |||
ETag cannot be determined. | |||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) | |||
if some parameter value is invalid | |||
</Tip> | |||
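    Example (illustrative; the model id is a placeholder):
        >>> from modelscope.hub.snapshot_download import snapshot_download
        >>> model_dir = snapshot_download('damo/my_model')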
""" | |||
if cache_dir is None: | |||
cache_dir = get_cache_dir() | |||
if isinstance(cache_dir, Path): | |||
cache_dir = str(cache_dir) | |||
group_or_owner, name = model_id_to_group_owner_name(model_id) | |||
cache = ModelFileSystemCache(cache_dir, group_or_owner, name) | |||
if local_files_only: | |||
if len(cache.cached_files) == 0: | |||
raise ValueError( | |||
'Cannot find the requested files in the cached path and outgoing' | |||
' traffic has been disabled. To enable model look-ups and downloads' | |||
" online, set 'local_files_only' to False.") | |||
        logger.warning('We cannot confirm that the cached files are for revision: %s'
                       % revision)
return cache.get_root_location( | |||
) # we can not confirm the cached file is for snapshot 'revision' | |||
else: | |||
# make headers | |||
headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||
_api = HubApi() | |||
# get file list from model repo | |||
branches, tags = _api.get_model_branches_and_tags(model_id) | |||
if revision not in branches and revision not in tags: | |||
            raise NotExistError('The specified branch or tag: %s does not exist!'
                                % revision)
model_files = _api.get_model_files( | |||
model_id=model_id, | |||
revision=revision, | |||
recursive=True, | |||
use_cookies=private) | |||
cookies = None | |||
if private: | |||
cookies = ModelScopeConfig.get_cookies() | |||
for model_file in model_files: | |||
if model_file['Type'] == 'tree': | |||
continue | |||
            # if the file already exists in the cache, skip the download
if cache.exists(model_file): | |||
logger.info( | |||
'The specified file is in cache, skip downloading!') | |||
continue | |||
# get download url | |||
url = get_file_download_url( | |||
model_id=model_id, | |||
file_path=model_file['Path'], | |||
revision=revision) | |||
# First download to /tmp | |||
http_get_file( | |||
url=url, | |||
local_dir=tempfile.gettempdir(), | |||
file_name=model_file['Name'], | |||
headers=headers, | |||
cookies=None if cookies is None else cookies.get_dict()) | |||
# put file to cache | |||
cache.put_file( | |||
model_file, | |||
os.path.join(tempfile.gettempdir(), model_file['Name'])) | |||
        return cache.get_root_location()
@@ -0,0 +1,40 @@ | |||
import subprocess | |||
from typing import List | |||
def run_subprocess(command: List[str], | |||
folder: str, | |||
check=True, | |||
**kwargs) -> subprocess.CompletedProcess: | |||
""" | |||
    Method to run subprocesses. Calling this will capture `stderr` and `stdout`;
    call `subprocess.run` directly if you do not want them captured.
Args: | |||
command (`List[str]`): | |||
The command to execute as a list of strings. | |||
folder (`str`): | |||
The folder in which to run the command. | |||
check (`bool`, *optional*, defaults to `True`): | |||
Setting `check` to `True` will raise a `subprocess.CalledProcessError` | |||
when the subprocess has a non-zero exit code. | |||
kwargs (`Dict[str]`): | |||
Keyword arguments to be passed to the `subprocess.run` underlying command. | |||
Returns: | |||
`subprocess.CompletedProcess`: The completed process. | |||
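    Example (illustrative):
        >>> result = run_subprocess('git status'.split(), '.')
        >>> print(result.stdout)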
""" | |||
if isinstance(command, str): | |||
raise ValueError( | |||
'`run_subprocess` should be called with a list of strings.') | |||
return subprocess.run( | |||
command, | |||
stderr=subprocess.PIPE, | |||
stdout=subprocess.PIPE, | |||
check=check, | |||
encoding='utf-8', | |||
cwd=folder, | |||
**kwargs, | |||
) |
@@ -0,0 +1,294 @@ | |||
import hashlib
import os
import pickle
import tempfile
from shutil import move, rmtree
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
class FileSystemCache(object):
    """Local file cache."""
    KEY_FILE_NAME = '.msc'
def __init__( | |||
self, | |||
cache_root_location: str, | |||
**kwargs, | |||
): | |||
""" | |||
Parameters | |||
---------- | |||
cache_location: str | |||
The root location to store files. | |||
""" | |||
os.makedirs(cache_root_location, exist_ok=True) | |||
self.cache_root_location = cache_root_location | |||
self.load_cache() | |||
def get_root_location(self): | |||
return self.cache_root_location | |||
    def load_cache(self):
        """Load the cache index (the set of cached file records) from the
        `.msc` key file, if it exists.
        """
self.cached_files = [] | |||
cache_keys_file_path = os.path.join(self.cache_root_location, | |||
FileSystemCache.KEY_FILE_NAME) | |||
if os.path.exists(cache_keys_file_path): | |||
with open(cache_keys_file_path, 'rb') as f: | |||
self.cached_files = pickle.load(f) | |||
def save_cached_files(self): | |||
"""Save cache metadata.""" | |||
# save new meta to tmp and move to KEY_FILE_NAME | |||
cache_keys_file_path = os.path.join(self.cache_root_location, | |||
FileSystemCache.KEY_FILE_NAME) | |||
# TODO: Sync file write | |||
fd, fn = tempfile.mkstemp() | |||
with open(fd, 'wb') as f: | |||
pickle.dump(self.cached_files, f) | |||
move(fn, cache_keys_file_path) | |||
def get_file(self, key): | |||
"""Check the key is in the cache, if exist, return the file, otherwise return None. | |||
Args: | |||
key(`str`): The cache key. | |||
Returns: | |||
If file exist, return the cached file location, otherwise None. | |||
Raises: | |||
None | |||
<Tip> | |||
model_id = {owner}/{name} | |||
</Tip> | |||
""" | |||
pass | |||
def put_file(self, key, location): | |||
"""Put file to the cache, | |||
Args: | |||
key(`str`): The cache key | |||
location(`str`): Location of the file, we will move the file to cache. | |||
Returns: | |||
The cached file path of the file. | |||
Raises: | |||
None | |||
<Tip> | |||
model_id = {owner}/{name} | |||
</Tip> | |||
""" | |||
pass | |||
def remove_key(self, key): | |||
"""Remove cache key in index, The file is removed manually | |||
Args: | |||
key (dict): The cache key. | |||
""" | |||
self.cached_files.remove(key) | |||
self.save_cached_files() | |||
def exists(self, key): | |||
for cache_file in self.cached_files: | |||
if cache_file == key: | |||
return True | |||
return False | |||
def clear_cache(self): | |||
"""Remove all files and metadat from the cache | |||
In the case of multiple cache locations, this clears only the last one, | |||
which is assumed to be the read/write one. | |||
""" | |||
rmtree(self.cache_root_location) | |||
self.load_cache() | |||
def hash_name(self, key): | |||
return hashlib.sha256(key.encode()).hexdigest() | |||
class ModelFileSystemCache(FileSystemCache): | |||
"""Local cache file layout | |||
cache_root/owner/model_name/|individual cached files | |||
|.mk: file, The cache index file | |||
Save only one version for each file. | |||
""" | |||
def __init__(self, cache_root, owner, name): | |||
"""Put file to the cache | |||
Args: | |||
cache_root(`str`): The modelscope local cache root(default: ~/.modelscope/cache/models/) | |||
owner(`str`): The model owner. | |||
name('str'): The name of the model | |||
branch('str'): The branch of model | |||
tag('str'): The tag of model | |||
Returns: | |||
Raises: | |||
None | |||
<Tip> | |||
model_id = {owner}/{name} | |||
</Tip> | |||
""" | |||
super().__init__(os.path.join(cache_root, owner, name)) | |||
def get_file_by_path(self, file_path): | |||
"""Retrieve the cache if there is file match the path. | |||
Args: | |||
file_path (str): The file path in the model. | |||
Returns: | |||
path: the full path of the file. | |||
""" | |||
        for cached_file in list(self.cached_files):  # copy, since remove_key mutates the list
if file_path == cached_file['Path']: | |||
cached_file_path = os.path.join(self.cache_root_location, | |||
cached_file['Path']) | |||
if os.path.exists(cached_file_path): | |||
return cached_file_path | |||
else: | |||
self.remove_key(cached_file) | |||
return None | |||
def get_file_by_path_and_commit_id(self, file_path, commit_id): | |||
"""Retrieve the cache if there is file match the path. | |||
Args: | |||
file_path (str): The file path in the model. | |||
commit_id (str): The commit id of the file | |||
Returns: | |||
path: the full path of the file. | |||
""" | |||
        for cached_file in list(self.cached_files):  # copy, since remove_key mutates the list
if file_path == cached_file['Path'] and \ | |||
(cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])): | |||
cached_file_path = os.path.join(self.cache_root_location, | |||
cached_file['Path']) | |||
if os.path.exists(cached_file_path): | |||
return cached_file_path | |||
else: | |||
self.remove_key(cached_file) | |||
return None | |||
def get_file_by_info(self, model_file_info): | |||
"""Check if exist cache file. | |||
Args: | |||
model_file_info (ModelFileInfo): The file information of the file. | |||
Returns: | |||
_type_: _description_ | |||
""" | |||
cache_key = self.__get_cache_key(model_file_info) | |||
        for cached_file in list(self.cached_files):  # copy, since remove_key mutates the list
if cached_file == cache_key: | |||
orig_path = os.path.join(self.cache_root_location, | |||
cached_file['Path']) | |||
if os.path.exists(orig_path): | |||
return orig_path | |||
else: | |||
self.remove_key(cached_file) | |||
return None | |||
def __get_cache_key(self, model_file_info): | |||
cache_key = { | |||
'Path': model_file_info['Path'], | |||
'Revision': model_file_info['Revision'], # commit id | |||
} | |||
return cache_key | |||
    def exists(self, model_file_info):
        """Check whether the file is cached or not.
        Args:
            model_file_info (dict): The file info returned by the server.
        Returns:
            bool: True if it exists, otherwise False
        """
        key = self.__get_cache_key(model_file_info)
        matched_key = None
        for cached_key in self.cached_files:
            if cached_key['Path'] == key['Path'] and (
                    cached_key['Revision'].startswith(key['Revision'])
                    or key['Revision'].startswith(cached_key['Revision'])):
                matched_key = cached_key
                break
        file_path = os.path.join(self.cache_root_location,
                                 model_file_info['Path'])
        if matched_key is not None:
            if os.path.exists(file_path):
                return True
            # someone may have manually deleted the file; drop the stale key
            # (removing matched_key, not model_file_info, which is not in the index)
            self.remove_key(matched_key)
        return False
def remove_if_exists(self, model_file_info): | |||
"""We in cache, remove it. | |||
Args: | |||
model_file_info (ModelFileInfo): The model file information from server. | |||
""" | |||
        for cached_file in list(self.cached_files):  # copy, since remove_key mutates the list
if cached_file['Path'] == model_file_info['Path']: | |||
self.remove_key(cached_file) | |||
file_path = os.path.join(self.cache_root_location, | |||
cached_file['Path']) | |||
if os.path.exists(file_path): | |||
os.remove(file_path) | |||
def put_file(self, model_file_info, model_file_location): | |||
"""Put model on model_file_location to cache, the model first download to /tmp, and move to cache. | |||
Args: | |||
model_file_info (str): The file description returned by get_model_files | |||
sample: | |||
{ | |||
"CommitMessage": "add model\n", | |||
"CommittedDate": 1654857567, | |||
"CommitterName": "mulin.lyh", | |||
"IsLFS": false, | |||
"Mode": "100644", | |||
"Name": "resnet18.pth", | |||
"Path": "resnet18.pth", | |||
"Revision": "09b68012b27de0048ba74003690a890af7aff192", | |||
"Size": 46827520, | |||
"Type": "blob" | |||
} | |||
model_file_location (str): The location of the temporary file. | |||
Returns: | |||
str: The location of the cached file. | |||
""" | |||
self.remove_if_exists(model_file_info) # backup old revision | |||
cache_key = self.__get_cache_key(model_file_info) | |||
cache_full_path = os.path.join( | |||
self.cache_root_location, | |||
cache_key['Path']) # Branch and Tag do not have same name. | |||
cache_file_dir = os.path.dirname(cache_full_path) | |||
if not os.path.exists(cache_file_dir): | |||
os.makedirs(cache_file_dir, exist_ok=True) | |||
# We can't make operation transaction | |||
move(model_file_location, cache_full_path) | |||
self.cached_files.append(cache_key) | |||
self.save_cached_files() | |||
return cache_full_path |
@@ -0,0 +1,39 @@ | |||
import os | |||
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | |||
DEFAULT_MODELSCOPE_GITLAB_DOMAIN, | |||
DEFAULT_MODELSCOPE_GROUP, | |||
MODEL_ID_SEPARATOR, | |||
MODELSCOPE_URL_SCHEME) | |||
def model_id_to_group_owner_name(model_id): | |||
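    """Split a model id of the form `{owner}/{name}`; a bare name falls back to the default group."""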
if MODEL_ID_SEPARATOR in model_id: | |||
group_or_owner = model_id.split(MODEL_ID_SEPARATOR)[0] | |||
name = model_id.split(MODEL_ID_SEPARATOR)[1] | |||
else: | |||
group_or_owner = DEFAULT_MODELSCOPE_GROUP | |||
name = model_id | |||
return group_or_owner, name | |||
def get_cache_dir(): | |||
""" | |||
cache dir precedence: | |||
        function parameter > environment variable > ~/.cache/modelscope/hub
""" | |||
default_cache_dir = os.path.expanduser( | |||
os.path.join('~/.cache', 'modelscope')) | |||
return os.getenv('MODELSCOPE_CACHE', os.path.join(default_cache_dir, | |||
'hub')) | |||
def get_endpoint(): | |||
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | |||
DEFAULT_MODELSCOPE_DOMAIN) | |||
return MODELSCOPE_URL_SCHEME + modelscope_domain | |||
def get_gitlab_domain(): | |||
return os.getenv('MODELSCOPE_GITLAB_DOMAIN', | |||
DEFAULT_MODELSCOPE_GITLAB_DOMAIN) |
@@ -4,12 +4,10 @@ import os.path as osp | |||
from abc import ABC, abstractmethod | |||
from typing import Dict, Union | |||
from maas_hub.snapshot_download import snapshot_download | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models.builder import build_model | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.hub import get_model_cache_dir | |||
Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
@@ -47,9 +45,7 @@ class Model(ABC): | |||
if osp.exists(model_name_or_path): | |||
local_model_dir = model_name_or_path | |||
else: | |||
cache_path = get_model_cache_dir(model_name_or_path) | |||
local_model_dir = cache_path if osp.exists( | |||
cache_path) else snapshot_download(model_name_or_path) | |||
local_model_dir = snapshot_download(model_name_or_path) | |||
# else: | |||
# raise ValueError( | |||
# 'Remote model repo {model_name_or_path} does not exists') | |||
@@ -4,13 +4,11 @@ import os.path as osp | |||
from abc import ABC, abstractmethod | |||
from typing import Any, Dict, Generator, List, Union | |||
from maas_hub.snapshot_download import snapshot_download | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models.base import Model | |||
from modelscope.preprocessors import Preprocessor | |||
from modelscope.pydatasets import PyDataset | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.hub import get_model_cache_dir | |||
from modelscope.utils.logger import get_logger | |||
from .outputs import TASK_OUTPUTS | |||
from .util import is_model_name | |||
@@ -32,9 +30,7 @@ class Pipeline(ABC): | |||
# TODO @wenmeng.zwm replace model.startswith('damo/') with get_model | |||
if isinstance(model, str) and model.startswith('damo/'): | |||
if not osp.exists(model): | |||
cache_path = get_model_cache_dir(model) | |||
model = cache_path if osp.exists( | |||
cache_path) else snapshot_download(model) | |||
model = snapshot_download(model) | |||
return Model.from_pretrained(model) if is_model_name( | |||
model) else model | |||
elif isinstance(model, Model): | |||
@@ -2,8 +2,7 @@ | |||
import os.path as osp | |||
from typing import List, Union | |||
from maas_hub.file_download import model_file_download | |||
from modelscope.hub.file_download import model_file_download | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger | |||
@@ -4,11 +4,10 @@ from typing import Any, Dict, Union | |||
import numpy as np | |||
import torch | |||
from maas_hub.snapshot_download import snapshot_download | |||
from PIL import Image | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.utils.constant import Fields, ModelFile | |||
from modelscope.utils.hub import get_model_cache_dir | |||
from modelscope.utils.type_assert import type_assert | |||
from .base import Preprocessor | |||
from .builder import PREPROCESSORS | |||
@@ -34,9 +33,7 @@ class OfaImageCaptionPreprocessor(Preprocessor): | |||
if osp.exists(model_dir): | |||
local_model_dir = model_dir | |||
else: | |||
cache_path = get_model_cache_dir(model_dir) | |||
local_model_dir = cache_path if osp.exists( | |||
cache_path) else snapshot_download(model_dir) | |||
local_model_dir = snapshot_download(model_dir) | |||
local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE) | |||
bpe_dir = local_model_dir | |||
@@ -2,13 +2,10 @@ | |||
import os | |||
from maas_hub.constants import MODEL_ID_SEPARATOR | |||
from modelscope.hub.constants import MODEL_ID_SEPARATOR | |||
from modelscope.hub.utils.utils import get_cache_dir | |||
# temp solution before the hub-cache is in place | |||
def get_model_cache_dir(model_id: str, branch: str = 'master'): | |||
model_id_expanded = model_id.replace('/', | |||
MODEL_ID_SEPARATOR) + '.' + branch | |||
default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas')) | |||
return os.getenv('MAAS_CACHE', | |||
os.path.join(default_cache_dir, 'hub', model_id_expanded)) | |||
def get_model_cache_dir(model_id: str): | |||
return os.path.join(get_cache_dir(), model_id) |
@@ -1,13 +1,16 @@ | |||
addict | |||
datasets | |||
easydict | |||
https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.4.dev0-py3-none-any.whl | |||
filelock>=3.3.0 | |||
numpy | |||
opencv-python-headless | |||
Pillow>=6.2.0 | |||
pyyaml | |||
requests | |||
requests==2.27.1 | |||
scipy | |||
setuptools==58.0.4 | |||
tokenizers<=0.10.3 | |||
tqdm>=4.64.0 | |||
transformers<=4.16.2 | |||
yapf |
@@ -0,0 +1,157 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
import subprocess | |||
import tempfile | |||
import unittest | |||
import uuid | |||
from modelscope.hub.api import HubApi, ModelScopeConfig | |||
from modelscope.hub.file_download import model_file_download | |||
from modelscope.hub.repository import Repository | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.hub.utils.utils import get_gitlab_domain | |||
USER_NAME = 'maasadmin' | |||
PASSWORD = '12345678' | |||
model_chinese_name = '达摩卡通化模型' | |||
model_org = 'unittest' | |||
DEFAULT_GIT_PATH = 'git' | |||
class GitError(Exception): | |||
pass | |||
# TODO: migrate these git operations to the git library after the code is merged.
def run_git_command(git_path, *args) -> str:
response = subprocess.run([git_path, *args], capture_output=True) | |||
try: | |||
response.check_returncode() | |||
return response.stdout.decode('utf8') | |||
except subprocess.CalledProcessError as error: | |||
raise GitError(error.stderr.decode('utf8')) | |||
# For a public repo the token may be None; for a private repo a token is required.
def clone(local_dir: str, token: str, url: str): | |||
url = url.replace('//', '//oauth2:%s@' % token) | |||
clone_args = '-C %s clone %s' % (local_dir, url) | |||
clone_args = clone_args.split(' ') | |||
stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args) | |||
print('stdout: %s' % stdout) | |||
def push(local_dir: str, token: str, url: str): | |||
url = url.replace('//', '//oauth2:%s@' % token) | |||
push_args = '-C %s push %s' % (local_dir, url) | |||
push_args = push_args.split(' ') | |||
stdout = run_git_command(DEFAULT_GIT_PATH, *push_args) | |||
print('stdout: %s' % stdout) | |||
sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||
download_model_file_name = 'mnist-12.onnx' | |||
class HubOperationTest(unittest.TestCase): | |||
def setUp(self): | |||
self.old_cwd = os.getcwd() | |||
self.api = HubApi() | |||
# note this is temporary before official account management is ready | |||
self.api.login(USER_NAME, PASSWORD) | |||
self.model_name = uuid.uuid4().hex | |||
self.model_id = '%s/%s' % (model_org, self.model_name) | |||
self.api.create_model( | |||
model_id=self.model_id, | |||
chinese_name=model_chinese_name, | |||
visibility=5, # 1-private, 5-public | |||
license='apache-2.0') | |||
def tearDown(self): | |||
os.chdir(self.old_cwd) | |||
self.api.delete_model(model_id=self.model_id) | |||
def test_model_repo_creation(self): | |||
# change to proper model names before use | |||
try: | |||
info = self.api.get_model(model_id=self.model_id) | |||
assert info['Name'] == self.model_name | |||
except KeyError as ke: | |||
if ke.args[0] == 'name': | |||
print(f'model {self.model_name} already exists, ignore') | |||
else: | |||
raise | |||
    # Note that this can be done via git operations once the model repo
    # has been created. Git operations are the RECOMMENDED model upload approach.
def test_model_upload(self): | |||
url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||
print(url) | |||
temporary_dir = tempfile.mkdtemp() | |||
os.chdir(temporary_dir) | |||
cmd_args = 'clone %s' % url | |||
cmd_args = cmd_args.split(' ') | |||
out = run_git_command('git', *cmd_args) | |||
print(out) | |||
repo_dir = os.path.join(temporary_dir, self.model_name) | |||
os.chdir(repo_dir) | |||
os.system('touch file1') | |||
os.system('git add file1') | |||
os.system("git commit -m 'Test'") | |||
token = ModelScopeConfig.get_token() | |||
push(repo_dir, token, url) | |||
def test_download_single_file(self): | |||
url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||
print(url) | |||
temporary_dir = tempfile.mkdtemp() | |||
os.chdir(temporary_dir) | |||
os.system('git clone %s' % url) | |||
repo_dir = os.path.join(temporary_dir, self.model_name) | |||
os.chdir(repo_dir) | |||
os.system('wget %s' % sample_model_url) | |||
os.system('git add .') | |||
os.system("git commit -m 'Add file'") | |||
token = ModelScopeConfig.get_token() | |||
push(repo_dir, token, url) | |||
assert os.path.exists( | |||
os.path.join(temporary_dir, self.model_name, | |||
download_model_file_name)) | |||
downloaded_file = model_file_download( | |||
model_id=self.model_id, file_path=download_model_file_name) | |||
mdtime1 = os.path.getmtime(downloaded_file) | |||
# download again | |||
downloaded_file = model_file_download( | |||
model_id=self.model_id, file_path=download_model_file_name) | |||
mdtime2 = os.path.getmtime(downloaded_file) | |||
assert mdtime1 == mdtime2 | |||
def test_snapshot_download(self): | |||
url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||
print(url) | |||
temporary_dir = tempfile.mkdtemp() | |||
os.chdir(temporary_dir) | |||
os.system('git clone %s' % url) | |||
repo_dir = os.path.join(temporary_dir, self.model_name) | |||
os.chdir(repo_dir) | |||
os.system('wget %s' % sample_model_url) | |||
os.system('git add .') | |||
os.system("git commit -m 'Add file'") | |||
token = ModelScopeConfig.get_token() | |||
push(repo_dir, token, url) | |||
snapshot_path = snapshot_download(model_id=self.model_id) | |||
downloaded_file_path = os.path.join(snapshot_path, | |||
download_model_file_name) | |||
assert os.path.exists(downloaded_file_path) | |||
mdtime1 = os.path.getmtime(downloaded_file_path) | |||
# download again | |||
snapshot_path = snapshot_download(model_id=self.model_id) | |||
mdtime2 = os.path.getmtime(downloaded_file_path) | |||
assert mdtime1 == mdtime2 | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -10,7 +10,6 @@ from modelscope.fileio import File | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pydatasets import PyDataset | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.hub import get_model_cache_dir | |||
from modelscope.utils.test_utils import test_level | |||
@@ -18,11 +17,6 @@ class ImageMattingTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/cv_unet_image-matting' | |||
# switch to False if downloading everytime is not desired | |||
purge_cache = True | |||
if purge_cache: | |||
shutil.rmtree( | |||
get_model_cache_dir(self.model_id), ignore_errors=True) | |||
@unittest.skip('deprecated, download model from model hub instead') | |||
def test_run_with_direct_file_download(self): | |||
@@ -27,7 +27,7 @@ class OCRDetectionTest(unittest.TestCase): | |||
print('ocr detection results: ') | |||
print(result) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_modelhub_default_model(self): | |||
ocr_detection = pipeline(Tasks.ocr_detection) | |||
self.pipeline_inference(ocr_detection, self.test_image) | |||
@@ -2,14 +2,12 @@ | |||
import shutil | |||
import unittest | |||
from maas_hub.snapshot_download import snapshot_download | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models import Model | |||
from modelscope.models.nlp import SbertForSentenceSimilarity | |||
from modelscope.pipelines import SentenceSimilarityPipeline, pipeline | |||
from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.hub import get_model_cache_dir | |||
from modelscope.utils.test_utils import test_level | |||
@@ -18,13 +16,6 @@ class SentenceSimilarityTest(unittest.TestCase): | |||
sentence1 = '今天气温比昨天高么?' | |||
sentence2 = '今天湿度比昨天高么?' | |||
def setUp(self) -> None: | |||
# switch to False if downloading everytime is not desired | |||
purge_cache = True | |||
if purge_cache: | |||
shutil.rmtree( | |||
get_model_cache_dir(self.model_id), ignore_errors=True) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run(self): | |||
cache_path = snapshot_download(self.model_id) | |||
@@ -5,7 +5,6 @@ import unittest | |||
from modelscope.fileio import File | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.hub import get_model_cache_dir | |||
NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav' | |||
FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav' | |||
@@ -30,11 +29,6 @@ class SpeechSignalProcessTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/speech_dfsmn_aec_psm_16k' | |||
# switch to False if downloading everytime is not desired | |||
purge_cache = True | |||
if purge_cache: | |||
shutil.rmtree( | |||
get_model_cache_dir(self.model_id), ignore_errors=True) | |||
# A temporary hack to provide c++ lib. Download it first. | |||
download(AEC_LIB_URL, AEC_LIB_FILE) | |||
@@ -11,7 +11,6 @@ from modelscope.pipelines import SequenceClassificationPipeline, pipeline | |||
from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
from modelscope.pydatasets import PyDataset | |||
from modelscope.utils.constant import Hubs, Tasks | |||
from modelscope.utils.hub import get_model_cache_dir | |||
from modelscope.utils.test_utils import test_level | |||
@@ -19,11 +18,6 @@ class SequenceClassificationTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/bert-base-sst2' | |||
# switch to False if downloading everytime is not desired | |||
purge_cache = True | |||
if purge_cache: | |||
shutil.rmtree( | |||
get_model_cache_dir(self.model_id), ignore_errors=True) | |||
def predict(self, pipeline_ins: SequenceClassificationPipeline): | |||
from easynlp.appzoo import load_dataset | |||
@@ -1,8 +1,7 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from maas_hub.snapshot_download import snapshot_download | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models import Model | |||
from modelscope.models.nlp import PalmForTextGeneration | |||
from modelscope.pipelines import TextGenerationPipeline, pipeline | |||
@@ -2,14 +2,12 @@ | |||
import shutil | |||
import unittest | |||
from maas_hub.snapshot_download import snapshot_download | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models import Model | |||
from modelscope.models.nlp import StructBertForTokenClassification | |||
from modelscope.pipelines import WordSegmentationPipeline, pipeline | |||
from modelscope.preprocessors import TokenClassifcationPreprocessor | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.hub import get_model_cache_dir | |||
from modelscope.utils.test_utils import test_level | |||
@@ -17,13 +15,6 @@ class WordSegmentationTest(unittest.TestCase): | |||
model_id = 'damo/nlp_structbert_word-segmentation_chinese-base' | |||
sentence = '今天天气不错,适合出去游玩' | |||
def setUp(self) -> None: | |||
# switch to False if downloading everytime is not desired | |||
purge_cache = True | |||
if purge_cache: | |||
shutil.rmtree( | |||
get_model_cache_dir(self.model_id), ignore_errors=True) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_by_direct_model_download(self): | |||
cache_path = snapshot_download(self.model_id) | |||
@@ -61,7 +61,7 @@ if __name__ == '__main__': | |||
parser.add_argument( | |||
'--test_dir', default='tests', help='directory to be tested') | |||
parser.add_argument( | |||
'--level', default=0, help='2 -- all, 1 -- p1, 0 -- p0') | |||
'--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0') | |||
args = parser.parse_args() | |||
set_test_level(args.level) | |||
logger.info(f'TEST LEVEL: {test_level()}') | |||
@@ -1,50 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
import unittest | |||
from maas_hub.maas_api import MaasApi | |||
from maas_hub.repository import Repository | |||
USER_NAME = 'maasadmin' | |||
PASSWORD = '12345678' | |||
class HubOperationTest(unittest.TestCase): | |||
def setUp(self): | |||
self.api = MaasApi() | |||
# note this is temporary before official account management is ready | |||
self.api.login(USER_NAME, PASSWORD) | |||
@unittest.skip('to be used for local test only') | |||
def test_model_repo_creation(self): | |||
# change to proper model names before use | |||
model_name = 'cv_unet_person-image-cartoon_compound-models' | |||
model_chinese_name = '达摩卡通化模型' | |||
model_org = 'damo' | |||
try: | |||
self.api.create_model( | |||
owner=model_org, | |||
name=model_name, | |||
chinese_name=model_chinese_name, | |||
visibility=5, # 1-private, 5-public | |||
license='apache-2.0') | |||
# TODO: support proper name duplication checking | |||
except KeyError as ke: | |||
if ke.args[0] == 'name': | |||
print(f'model {self.model_name} already exists, ignore') | |||
else: | |||
raise | |||
# Note that this can be done via git operation once model repo | |||
# has been created. Git-Op is the RECOMMENDED model upload approach | |||
@unittest.skip('to be used for local test only') | |||
def test_model_upload(self): | |||
local_path = '/path/to/local/model/directory' | |||
assert osp.exists(local_path), 'Local model directory not exist.' | |||
repo = Repository(local_dir=local_path) | |||
repo.push_to_hub(commit_message='Upload model files') | |||
if __name__ == '__main__': | |||
unittest.main() |