@@ -9,9 +9,10 @@ from typing import List, Optional, Tuple, Union | |||||
import requests | import requests | ||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from .constants import LOGGER_NAME | |||||
from .constants import MODELSCOPE_URL_SCHEME | |||||
from .errors import NotExistError, is_ok, raise_on_error | from .errors import NotExistError, is_ok, raise_on_error | ||||
from .utils.utils import get_endpoint, model_id_to_group_owner_name | |||||
from .utils.utils import (get_endpoint, get_gitlab_domain, | |||||
model_id_to_group_owner_name) | |||||
logger = get_logger() | logger = get_logger() | ||||
@@ -40,9 +41,6 @@ class HubApi: | |||||
<Tip> | <Tip> | ||||
You only have to login once within 30 days. | You only have to login once within 30 days. | ||||
</Tip> | </Tip> | ||||
TODO: handle cookies expire | |||||
""" | """ | ||||
path = f'{self.endpoint}/api/v1/login' | path = f'{self.endpoint}/api/v1/login' | ||||
r = requests.post( | r = requests.post( | ||||
@@ -94,14 +92,14 @@ class HubApi: | |||||
'Path': owner_or_group, | 'Path': owner_or_group, | ||||
'Name': name, | 'Name': name, | ||||
'ChineseName': chinese_name, | 'ChineseName': chinese_name, | ||||
'Visibility': visibility, | |||||
'Visibility': visibility, # server check | |||||
'License': license | 'License': license | ||||
}, | }, | ||||
cookies=cookies) | cookies=cookies) | ||||
r.raise_for_status() | r.raise_for_status() | ||||
raise_on_error(r.json()) | raise_on_error(r.json()) | ||||
d = r.json() | |||||
return d['Data']['Name'] | |||||
model_repo_url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}' | |||||
return model_repo_url | |||||
def delete_model(self, model_id): | def delete_model(self, model_id): | ||||
"""_summary_ | """_summary_ | ||||
@@ -209,25 +207,37 @@ class HubApi: | |||||
class ModelScopeConfig: | class ModelScopeConfig: | ||||
path_credential = expanduser('~/.modelscope/credentials') | path_credential = expanduser('~/.modelscope/credentials') | ||||
os.makedirs(path_credential, exist_ok=True) | |||||
@classmethod | |||||
def make_sure_credential_path_exist(cls): | |||||
os.makedirs(cls.path_credential, exist_ok=True) | |||||
@classmethod | @classmethod | ||||
def save_cookies(cls, cookies: CookieJar): | def save_cookies(cls, cookies: CookieJar): | ||||
cls.make_sure_credential_path_exist() | |||||
with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f: | with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f: | ||||
pickle.dump(cookies, f) | pickle.dump(cookies, f) | ||||
@classmethod | @classmethod | ||||
def get_cookies(cls): | def get_cookies(cls): | ||||
try: | try: | ||||
with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f: | |||||
return pickle.load(f) | |||||
cookies_path = os.path.join(cls.path_credential, 'cookies') | |||||
with open(cookies_path, 'rb') as f: | |||||
cookies = pickle.load(f) | |||||
for cookie in cookies: | |||||
if cookie.is_expired(): | |||||
logger.warn('Auth is expored, please re-login') | |||||
return None | |||||
return cookies | |||||
except FileNotFoundError: | except FileNotFoundError: | ||||
logger.warn("Auth token does not exist, you'll get authentication \ | |||||
error when downloading private model files. Please login first" | |||||
) | |||||
logger.warn( | |||||
"Auth token does not exist, you'll get authentication error when downloading \ | |||||
private model files. Please login first") | |||||
return None | |||||
@classmethod | @classmethod | ||||
def save_token(cls, token: str): | def save_token(cls, token: str): | ||||
cls.make_sure_credential_path_exist() | |||||
with open(os.path.join(cls.path_credential, 'token'), 'w+') as f: | with open(os.path.join(cls.path_credential, 'token'), 'w+') as f: | ||||
f.write(token) | f.write(token) | ||||
@@ -6,6 +6,10 @@ class RequestError(Exception): | |||||
pass | pass | ||||
class GitError(Exception): | |||||
pass | |||||
def is_ok(rsp): | def is_ok(rsp): | ||||
""" Check the request is ok | """ Check the request is ok | ||||
@@ -1,82 +1,161 @@ | |||||
from threading import local | |||||
from tkinter.messagebox import NO | |||||
from typing import Union | |||||
import subprocess | |||||
from typing import List | |||||
from xmlrpc.client import Boolean | |||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from .constants import LOGGER_NAME | |||||
from .utils._subprocess import run_subprocess | |||||
from .errors import GitError | |||||
logger = get_logger | |||||
logger = get_logger() | |||||
def git_clone( | |||||
local_dir: str, | |||||
repo_url: str, | |||||
): | |||||
# TODO: use "git clone" or "git lfs clone" according to git version | |||||
# TODO: print stderr when subprocess fails | |||||
run_subprocess( | |||||
f'git clone {repo_url}'.split(), | |||||
local_dir, | |||||
True, | |||||
) | |||||
class Singleton(type): | |||||
_instances = {} | |||||
def __call__(cls, *args, **kwargs): | |||||
if cls not in cls._instances: | |||||
cls._instances[cls] = super(Singleton, | |||||
cls).__call__(*args, **kwargs) | |||||
return cls._instances[cls] | |||||
def git_checkout( | |||||
local_dir: str, | |||||
revsion: str, | |||||
): | |||||
run_subprocess(f'git checkout {revsion}'.split(), local_dir) | |||||
def git_add(local_dir: str, ): | |||||
run_subprocess( | |||||
'git add .'.split(), | |||||
local_dir, | |||||
True, | |||||
) | |||||
def git_commit(local_dir: str, commit_message: str): | |||||
run_subprocess( | |||||
'git commit -v -m'.split() + [commit_message], | |||||
local_dir, | |||||
True, | |||||
) | |||||
def git_push(local_dir: str, branch: str): | |||||
# check current branch | |||||
cur_branch = git_current_branch(local_dir) | |||||
if cur_branch != branch: | |||||
logger.error( | |||||
"You're trying to push to a different branch, please double check") | |||||
return | |||||
run_subprocess( | |||||
f'git push origin {branch}'.split(), | |||||
local_dir, | |||||
True, | |||||
) | |||||
def git_current_branch(local_dir: str) -> Union[str, None]: | |||||
""" | |||||
Get current branch name | |||||
Args: | |||||
local_dir(`str`): local model repo directory | |||||
Returns | |||||
branch name you're currently on | |||||
class GitCommandWrapper(metaclass=Singleton): | |||||
"""Some git operation wrapper | |||||
""" | """ | ||||
try: | |||||
process = run_subprocess( | |||||
'git rev-parse --abbrev-ref HEAD'.split(), | |||||
local_dir, | |||||
True, | |||||
) | |||||
return str(process.stdout).strip() | |||||
except Exception as e: | |||||
raise e | |||||
default_git_path = 'git' # The default git command line | |||||
def __init__(self, path: str = None): | |||||
self.git_path = path or self.default_git_path | |||||
def _run_git_command(self, *args) -> subprocess.CompletedProcess: | |||||
"""Run git command, if command return 0, return subprocess.response | |||||
otherwise raise GitError, message is stdout and stderr. | |||||
Raises: | |||||
GitError: Exception with stdout and stderr. | |||||
Returns: | |||||
subprocess.CompletedProcess: the command response | |||||
""" | |||||
logger.info(' '.join(args)) | |||||
response = subprocess.run( | |||||
[self.git_path, *args], | |||||
stdout=subprocess.PIPE, | |||||
stderr=subprocess.PIPE) # compatible for python3.6 | |||||
try: | |||||
response.check_returncode() | |||||
return response | |||||
except subprocess.CalledProcessError as error: | |||||
raise GitError( | |||||
'stdout: %s, stderr: %s' % | |||||
(response.stdout.decode('utf8'), error.stderr.decode('utf8'))) | |||||
def _add_token(self, token: str, url: str): | |||||
if token: | |||||
if '//oauth2' not in url: | |||||
url = url.replace('//', '//oauth2:%s@' % token) | |||||
return url | |||||
def remove_token_from_url(self, url: str): | |||||
if url and '//oauth2' in url: | |||||
start_index = url.find('oauth2') | |||||
end_index = url.find('@') | |||||
url = url[:start_index] + url[end_index + 1:] | |||||
return url | |||||
def is_lfs_installed(self): | |||||
cmd = ['lfs', 'env'] | |||||
try: | |||||
self._run_git_command(*cmd) | |||||
return True | |||||
except GitError: | |||||
return False | |||||
def clone(self, | |||||
repo_base_dir: str, | |||||
token: str, | |||||
url: str, | |||||
repo_name: str, | |||||
branch: str = None): | |||||
""" git clone command wrapper. | |||||
For public project, token can None, private repo, there must token. | |||||
Args: | |||||
repo_base_dir (str): The local base dir, the repository will be clone to local_dir/repo_name | |||||
token (str): The git token, must be provided for private project. | |||||
url (str): The remote url | |||||
repo_name (str): The local repository path name. | |||||
branch (str, optional): _description_. Defaults to None. | |||||
""" | |||||
url = self._add_token(token, url) | |||||
if branch: | |||||
clone_args = '-C %s clone %s %s --branch %s' % (repo_base_dir, url, | |||||
repo_name, branch) | |||||
else: | |||||
clone_args = '-C %s clone %s' % (repo_base_dir, url) | |||||
logger.debug(clone_args) | |||||
clone_args = clone_args.split(' ') | |||||
response = self._run_git_command(*clone_args) | |||||
logger.info(response.stdout.decode('utf8')) | |||||
return response | |||||
def add(self, | |||||
repo_dir: str, | |||||
files: List[str] = list(), | |||||
all_files: bool = False): | |||||
if all_files: | |||||
add_args = '-C %s add -A' % repo_dir | |||||
elif len(files) > 0: | |||||
files_str = ' '.join(files) | |||||
add_args = '-C %s add %s' % (repo_dir, files_str) | |||||
add_args = add_args.split(' ') | |||||
rsp = self._run_git_command(*add_args) | |||||
logger.info(rsp.stdout.decode('utf8')) | |||||
return rsp | |||||
def commit(self, repo_dir: str, message: str): | |||||
"""Run git commit command | |||||
Args: | |||||
message (str): commit message. | |||||
""" | |||||
commit_args = ['-C', '%s' % repo_dir, 'commit', '-m', "'%s'" % message] | |||||
rsp = self._run_git_command(*commit_args) | |||||
logger.info(rsp.stdout.decode('utf8')) | |||||
return rsp | |||||
def checkout(self, repo_dir: str, revision: str): | |||||
cmds = ['-C', '%s' % repo_dir, 'checkout', '%s' % revision] | |||||
return self._run_git_command(*cmds) | |||||
def new_branch(self, repo_dir: str, revision: str): | |||||
cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision] | |||||
return self._run_git_command(*cmds) | |||||
def pull(self, repo_dir: str): | |||||
cmds = ['-C', repo_dir, 'pull'] | |||||
return self._run_git_command(*cmds) | |||||
def push(self, | |||||
repo_dir: str, | |||||
token: str, | |||||
url: str, | |||||
local_branch: str, | |||||
remote_branch: str, | |||||
force: bool = False): | |||||
url = self._add_token(token, url) | |||||
push_args = '-C %s push %s %s:%s' % (repo_dir, url, local_branch, | |||||
remote_branch) | |||||
if force: | |||||
push_args += ' -f' | |||||
push_args = push_args.split(' ') | |||||
rsp = self._run_git_command(*push_args) | |||||
logger.info(rsp.stdout.decode('utf8')) | |||||
return rsp | |||||
def get_repo_remote_url(self, repo_dir: str): | |||||
cmd_args = '-C %s config --get remote.origin.url' % repo_dir | |||||
cmd_args = cmd_args.split(' ') | |||||
rsp = self._run_git_command(*cmd_args) | |||||
url = rsp.stdout.decode('utf8') | |||||
return url.strip() |
@@ -1,173 +1,97 @@ | |||||
import os | import os | ||||
import subprocess | |||||
from pathlib import Path | |||||
from typing import Optional, Union | |||||
from typing import List, Optional | |||||
from modelscope.hub.errors import GitError | |||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from .api import ModelScopeConfig | from .api import ModelScopeConfig | ||||
from .constants import MODELSCOPE_URL_SCHEME | from .constants import MODELSCOPE_URL_SCHEME | ||||
from .git import git_add, git_checkout, git_clone, git_commit, git_push | |||||
from .utils._subprocess import run_subprocess | |||||
from .git import GitCommandWrapper | |||||
from .utils.utils import get_gitlab_domain | from .utils.utils import get_gitlab_domain | ||||
logger = get_logger() | logger = get_logger() | ||||
class Repository: | class Repository: | ||||
"""Representation local model git repository. | |||||
""" | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
local_dir: str, | |||||
clone_from: Optional[str] = None, | |||||
auth_token: Optional[str] = None, | |||||
private: Optional[bool] = False, | |||||
model_dir: str, | |||||
clone_from: str, | |||||
revision: Optional[str] = 'master', | revision: Optional[str] = 'master', | ||||
auth_token: Optional[str] = None, | |||||
git_path: Optional[str] = None, | |||||
): | ): | ||||
""" | """ | ||||
Instantiate a Repository object by cloning the remote ModelScopeHub repo | Instantiate a Repository object by cloning the remote ModelScopeHub repo | ||||
Args: | Args: | ||||
local_dir(`str`): | |||||
local directory to store the model files | |||||
clone_from(`Optional[str] = None`): | |||||
model_dir(`str`): | |||||
The model root directory. | |||||
clone_from: | |||||
model id in ModelScope-hub from which git clone | model id in ModelScope-hub from which git clone | ||||
You should ignore this parameter when `local_dir` is already a git repo | |||||
auth_token(`Optional[str]`): | |||||
token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter | |||||
as the token is already saved when you login the first time | |||||
private(`Optional[bool]`): | |||||
whether the model is private, default to False | |||||
revision(`Optional[str]`): | revision(`Optional[str]`): | ||||
revision of the model you want to clone from. Can be any of a branch, tag or commit hash | revision of the model you want to clone from. Can be any of a branch, tag or commit hash | ||||
auth_token(`Optional[str]`): | |||||
token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter | |||||
as the token is already saved when you login the first time, if None, we will use saved token. | |||||
git_path:(`Optional[str]`): | |||||
The git command line path, if None, we use 'git' | |||||
""" | """ | ||||
logger.info('Instantiating Repository object...') | |||||
# Create local directory if not exist | |||||
os.makedirs(local_dir, exist_ok=True) | |||||
self.local_dir = os.path.join(os.getcwd(), local_dir) | |||||
self.private = private | |||||
# Check git and git-lfs installation | |||||
self.check_git_versions() | |||||
# Retrieve auth token | |||||
if not private and isinstance(auth_token, str): | |||||
logger.warning( | |||||
'cloning a public repo with a token, which will be ignored') | |||||
self.token = None | |||||
self.model_dir = model_dir | |||||
self.model_base_dir = os.path.dirname(model_dir) | |||||
self.model_repo_name = os.path.basename(model_dir) | |||||
if auth_token: | |||||
self.auth_token = auth_token | |||||
else: | else: | ||||
if isinstance(auth_token, str): | |||||
self.token = auth_token | |||||
else: | |||||
self.token = ModelScopeConfig.get_token() | |||||
if self.token is None: | |||||
raise EnvironmentError( | |||||
'Token does not exist, the clone will fail for private repo.' | |||||
'Please login first.') | |||||
# git clone | |||||
if clone_from is not None: | |||||
self.model_id = clone_from | |||||
logger.info('cloning model repo to %s ...', self.local_dir) | |||||
git_clone(self.local_dir, self.get_repo_url()) | |||||
else: | |||||
if is_git_repo(self.local_dir): | |||||
logger.debug('[Repository] is a valid git repo') | |||||
else: | |||||
raise ValueError( | |||||
'If not specifying `clone_from`, you need to pass Repository a' | |||||
' valid git clone.') | |||||
# git checkout | |||||
if isinstance(revision, str) and revision != 'master': | |||||
git_checkout(revision) | |||||
def push_to_hub(self, | |||||
commit_message: str, | |||||
revision: Optional[str] = 'master'): | |||||
""" | |||||
Push changes changes to hub | |||||
Args: | |||||
commit_message(`str`): | |||||
commit message describing the changes, it's mandatory | |||||
revision(`Optional[str]`): | |||||
remote branch you want to push to, default to `master` | |||||
<Tip> | |||||
The function complains when local and remote branch are different, please be careful | |||||
</Tip> | |||||
""" | |||||
git_add(self.local_dir) | |||||
git_commit(self.local_dir, commit_message) | |||||
logger.info('Pushing changes to repo...') | |||||
git_push(self.local_dir, revision) | |||||
# TODO: if git push fails, how to retry? | |||||
def check_git_versions(self): | |||||
""" | |||||
Checks that `git` and `git-lfs` can be run. | |||||
Raises: | |||||
`EnvironmentError`: if `git` or `git-lfs` are not installed. | |||||
""" | |||||
try: | |||||
git_version = run_subprocess('git --version'.split(), | |||||
self.local_dir).stdout.strip() | |||||
except FileNotFoundError: | |||||
raise EnvironmentError( | |||||
'Looks like you do not have git installed, please install.') | |||||
self.auth_token = ModelScopeConfig.get_token() | |||||
git_wrapper = GitCommandWrapper() | |||||
if not git_wrapper.is_lfs_installed(): | |||||
logger.error('git lfs is not installed, please install.') | |||||
self.git_wrapper = GitCommandWrapper(git_path) | |||||
os.makedirs(self.model_dir, exist_ok=True) | |||||
url = self._get_model_id_url(clone_from) | |||||
if os.listdir(self.model_dir): # directory not empty. | |||||
remote_url = self._get_remote_url() | |||||
remote_url = self.git_wrapper.remove_token_from_url(remote_url) | |||||
if remote_url and remote_url == url: # need not clone again | |||||
return | |||||
self.git_wrapper.clone(self.model_base_dir, self.auth_token, url, | |||||
self.model_repo_name, revision) | |||||
def _get_model_id_url(self, model_id): | |||||
url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}' | |||||
return url | |||||
def _get_remote_url(self): | |||||
try: | try: | ||||
lfs_version = run_subprocess('git-lfs --version'.split(), | |||||
self.local_dir).stdout.strip() | |||||
except FileNotFoundError: | |||||
raise EnvironmentError( | |||||
'Looks like you do not have git-lfs installed, please install.' | |||||
' You can install from https://git-lfs.github.com/.' | |||||
' Then run `git lfs install` (you only have to do this once).') | |||||
logger.info(git_version + '\n' + lfs_version) | |||||
def get_repo_url(self) -> str: | |||||
""" | |||||
Get repo url to clone, according whether the repo is private or not | |||||
remote = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||||
except GitError: | |||||
remote = None | |||||
return remote | |||||
def push(self, | |||||
commit_message: str, | |||||
files: List[str] = list(), | |||||
all_files: bool = False, | |||||
branch: Optional[str] = 'master', | |||||
force: bool = False): | |||||
"""Push local to remote, this method will do. | |||||
git add | |||||
git commit | |||||
git push | |||||
Args: | |||||
commit_message (str): commit message | |||||
revision (Optional[str], optional): which branch to push. Defaults to 'master'. | |||||
""" | """ | ||||
url = None | |||||
if self.private: | |||||
url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}' | |||||
else: | |||||
url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}' | |||||
if not url: | |||||
raise ValueError( | |||||
'Empty repo url, please check clone_from parameter') | |||||
logger.debug('url to clone: %s', str(url)) | |||||
return url | |||||
def is_git_repo(folder: Union[str, Path]) -> bool: | |||||
""" | |||||
Check if the folder is the root or part of a git repository | |||||
Args: | |||||
folder (`str`): | |||||
The folder in which to run the command. | |||||
Returns: | |||||
`bool`: `True` if the repository is part of a repository, `False` | |||||
otherwise. | |||||
""" | |||||
folder_exists = os.path.exists(os.path.join(folder, '.git')) | |||||
git_branch = subprocess.run( | |||||
'git branch'.split(), | |||||
cwd=folder, | |||||
stdout=subprocess.PIPE, | |||||
stderr=subprocess.PIPE) | |||||
return folder_exists and git_branch.returncode == 0 | |||||
url = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||||
self.git_wrapper.add(self.model_dir, files, all_files) | |||||
self.git_wrapper.commit(self.model_dir, commit_message) | |||||
self.git_wrapper.push( | |||||
repo_dir=self.model_dir, | |||||
token=self.auth_token, | |||||
url=url, | |||||
local_branch=branch, | |||||
remote_branch=branch) |
@@ -1,40 +0,0 @@ | |||||
import subprocess | |||||
from typing import List | |||||
def run_subprocess(command: List[str], | |||||
folder: str, | |||||
check=True, | |||||
**kwargs) -> subprocess.CompletedProcess: | |||||
""" | |||||
Method to run subprocesses. Calling this will capture the `stderr` and `stdout`, | |||||
please call `subprocess.run` manually in case you would like for them not to | |||||
be captured. | |||||
Args: | |||||
command (`List[str]`): | |||||
The command to execute as a list of strings. | |||||
folder (`str`): | |||||
The folder in which to run the command. | |||||
check (`bool`, *optional*, defaults to `True`): | |||||
Setting `check` to `True` will raise a `subprocess.CalledProcessError` | |||||
when the subprocess has a non-zero exit code. | |||||
kwargs (`Dict[str]`): | |||||
Keyword arguments to be passed to the `subprocess.run` underlying command. | |||||
Returns: | |||||
`subprocess.CompletedProcess`: The completed process. | |||||
""" | |||||
if isinstance(command, str): | |||||
raise ValueError( | |||||
'`run_subprocess` should be called with a list of strings.') | |||||
return subprocess.run( | |||||
command, | |||||
stderr=subprocess.PIPE, | |||||
stdout=subprocess.PIPE, | |||||
check=check, | |||||
encoding='utf-8', | |||||
cwd=folder, | |||||
**kwargs, | |||||
) |
@@ -1,14 +1,13 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import os | import os | ||||
import subprocess | |||||
import tempfile | import tempfile | ||||
import unittest | import unittest | ||||
import uuid | import uuid | ||||
from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
from modelscope.hub.api import HubApi | |||||
from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
from modelscope.hub.repository import Repository | |||||
from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
from modelscope.hub.utils.utils import get_gitlab_domain | |||||
USER_NAME = 'maasadmin' | USER_NAME = 'maasadmin' | ||||
PASSWORD = '12345678' | PASSWORD = '12345678' | ||||
@@ -17,40 +16,7 @@ model_chinese_name = '达摩卡通化模型' | |||||
model_org = 'unittest' | model_org = 'unittest' | ||||
DEFAULT_GIT_PATH = 'git' | DEFAULT_GIT_PATH = 'git' | ||||
class GitError(Exception): | |||||
pass | |||||
# TODO make thest git operation to git library after merge code. | |||||
def run_git_command(git_path, *args) -> subprocess.CompletedProcess: | |||||
response = subprocess.run([git_path, *args], capture_output=True) | |||||
try: | |||||
response.check_returncode() | |||||
return response.stdout.decode('utf8') | |||||
except subprocess.CalledProcessError as error: | |||||
raise GitError(error.stderr.decode('utf8')) | |||||
# for public project, token can None, private repo, there must token. | |||||
def clone(local_dir: str, token: str, url: str): | |||||
url = url.replace('//', '//oauth2:%s@' % token) | |||||
clone_args = '-C %s clone %s' % (local_dir, url) | |||||
clone_args = clone_args.split(' ') | |||||
stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args) | |||||
print('stdout: %s' % stdout) | |||||
def push(local_dir: str, token: str, url: str): | |||||
url = url.replace('//', '//oauth2:%s@' % token) | |||||
push_args = '-C %s push %s' % (local_dir, url) | |||||
push_args = push_args.split(' ') | |||||
stdout = run_git_command(DEFAULT_GIT_PATH, *push_args) | |||||
print('stdout: %s' % stdout) | |||||
sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||||
download_model_file_name = 'mnist-12.onnx' | |||||
download_model_file_name = 'test.bin' | |||||
class HubOperationTest(unittest.TestCase): | class HubOperationTest(unittest.TestCase): | ||||
@@ -67,6 +33,13 @@ class HubOperationTest(unittest.TestCase): | |||||
chinese_name=model_chinese_name, | chinese_name=model_chinese_name, | ||||
visibility=5, # 1-private, 5-public | visibility=5, # 1-private, 5-public | ||||
license='apache-2.0') | license='apache-2.0') | ||||
temporary_dir = tempfile.mkdtemp() | |||||
self.model_dir = os.path.join(temporary_dir, self.model_name) | |||||
repo = Repository(self.model_dir, clone_from=self.model_id) | |||||
os.chdir(self.model_dir) | |||||
os.system("echo 'testtest'>%s" | |||||
% os.path.join(self.model_dir, 'test.bin')) | |||||
repo.push('add model', all_files=True) | |||||
def tearDown(self): | def tearDown(self): | ||||
os.chdir(self.old_cwd) | os.chdir(self.old_cwd) | ||||
@@ -83,43 +56,10 @@ class HubOperationTest(unittest.TestCase): | |||||
else: | else: | ||||
raise | raise | ||||
# Note that this can be done via git operation once model repo | |||||
# has been created. Git-Op is the RECOMMENDED model upload approach | |||||
def test_model_upload(self): | |||||
url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||||
print(url) | |||||
temporary_dir = tempfile.mkdtemp() | |||||
os.chdir(temporary_dir) | |||||
cmd_args = 'clone %s' % url | |||||
cmd_args = cmd_args.split(' ') | |||||
out = run_git_command('git', *cmd_args) | |||||
print(out) | |||||
repo_dir = os.path.join(temporary_dir, self.model_name) | |||||
os.chdir(repo_dir) | |||||
os.system('touch file1') | |||||
os.system('git add file1') | |||||
os.system("git commit -m 'Test'") | |||||
token = ModelScopeConfig.get_token() | |||||
push(repo_dir, token, url) | |||||
def test_download_single_file(self): | def test_download_single_file(self): | ||||
url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||||
print(url) | |||||
temporary_dir = tempfile.mkdtemp() | |||||
os.chdir(temporary_dir) | |||||
os.system('git clone %s' % url) | |||||
repo_dir = os.path.join(temporary_dir, self.model_name) | |||||
os.chdir(repo_dir) | |||||
os.system('wget %s' % sample_model_url) | |||||
os.system('git add .') | |||||
os.system("git commit -m 'Add file'") | |||||
token = ModelScopeConfig.get_token() | |||||
push(repo_dir, token, url) | |||||
assert os.path.exists( | |||||
os.path.join(temporary_dir, self.model_name, | |||||
download_model_file_name)) | |||||
downloaded_file = model_file_download( | downloaded_file = model_file_download( | ||||
model_id=self.model_id, file_path=download_model_file_name) | model_id=self.model_id, file_path=download_model_file_name) | ||||
assert os.path.exists(downloaded_file) | |||||
mdtime1 = os.path.getmtime(downloaded_file) | mdtime1 = os.path.getmtime(downloaded_file) | ||||
# download again | # download again | ||||
downloaded_file = model_file_download( | downloaded_file = model_file_download( | ||||
@@ -128,18 +68,6 @@ class HubOperationTest(unittest.TestCase): | |||||
assert mdtime1 == mdtime2 | assert mdtime1 == mdtime2 | ||||
def test_snapshot_download(self): | def test_snapshot_download(self): | ||||
url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||||
print(url) | |||||
temporary_dir = tempfile.mkdtemp() | |||||
os.chdir(temporary_dir) | |||||
os.system('git clone %s' % url) | |||||
repo_dir = os.path.join(temporary_dir, self.model_name) | |||||
os.chdir(repo_dir) | |||||
os.system('wget %s' % sample_model_url) | |||||
os.system('git add .') | |||||
os.system("git commit -m 'Add file'") | |||||
token = ModelScopeConfig.get_token() | |||||
push(repo_dir, token, url) | |||||
snapshot_path = snapshot_download(model_id=self.model_id) | snapshot_path = snapshot_download(model_id=self.model_id) | ||||
downloaded_file_path = os.path.join(snapshot_path, | downloaded_file_path = os.path.join(snapshot_path, | ||||
download_model_file_name) | download_model_file_name) | ||||
@@ -0,0 +1,76 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | |||||
import tempfile | |||||
import unittest | |||||
import uuid | |||||
from modelscope.hub.api import HubApi | |||||
from modelscope.hub.errors import GitError | |||||
from modelscope.hub.repository import Repository | |||||
USER_NAME = 'maasadmin' | |||||
PASSWORD = '12345678' | |||||
USER_NAME2 = 'sdkdev' | |||||
model_chinese_name = '达摩卡通化模型' | |||||
model_org = 'unittest' | |||||
DEFAULT_GIT_PATH = 'git' | |||||
sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||||
download_model_file_name = 'mnist-12.onnx' | |||||
class HubPrivateRepositoryTest(unittest.TestCase): | |||||
def setUp(self): | |||||
self.old_cwd = os.getcwd() | |||||
self.api = HubApi() | |||||
# note this is temporary before official account management is ready | |||||
self.token, _ = self.api.login(USER_NAME, PASSWORD) | |||||
self.model_name = uuid.uuid4().hex | |||||
self.model_id = '%s/%s' % (model_org, self.model_name) | |||||
self.api.create_model( | |||||
model_id=self.model_id, | |||||
chinese_name=model_chinese_name, | |||||
visibility=1, # 1-private, 5-public | |||||
license='apache-2.0') | |||||
def tearDown(self): | |||||
self.api.login(USER_NAME, PASSWORD) | |||||
os.chdir(self.old_cwd) | |||||
self.api.delete_model(model_id=self.model_id) | |||||
def test_clone_private_repo_no_permission(self): | |||||
token, _ = self.api.login(USER_NAME2, PASSWORD) | |||||
temporary_dir = tempfile.mkdtemp() | |||||
local_dir = os.path.join(temporary_dir, self.model_name) | |||||
with self.assertRaises(GitError) as cm: | |||||
Repository(local_dir, clone_from=self.model_id, auth_token=token) | |||||
print(cm.exception) | |||||
assert not os.path.exists(os.path.join(local_dir, 'README.md')) | |||||
def test_clone_private_repo_has_permission(self): | |||||
temporary_dir = tempfile.mkdtemp() | |||||
local_dir = os.path.join(temporary_dir, self.model_name) | |||||
repo1 = Repository( | |||||
local_dir, clone_from=self.model_id, auth_token=self.token) | |||||
print(repo1.model_dir) | |||||
assert os.path.exists(os.path.join(local_dir, 'README.md')) | |||||
def test_initlize_repo_multiple_times(self): | |||||
temporary_dir = tempfile.mkdtemp() | |||||
local_dir = os.path.join(temporary_dir, self.model_name) | |||||
repo1 = Repository( | |||||
local_dir, clone_from=self.model_id, auth_token=self.token) | |||||
print(repo1.model_dir) | |||||
assert os.path.exists(os.path.join(local_dir, 'README.md')) | |||||
repo2 = Repository( | |||||
local_dir, clone_from=self.model_id, | |||||
auth_token=self.token) # skip clone | |||||
print(repo2.model_dir) | |||||
assert repo1.model_dir == repo2.model_dir | |||||
if __name__ == '__main__': | |||||
unittest.main() |
@@ -0,0 +1,107 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | |||||
import shutil | |||||
import tempfile | |||||
import time | |||||
import unittest | |||||
import uuid | |||||
from os.path import expanduser | |||||
from requests import delete | |||||
from modelscope.hub.api import HubApi | |||||
from modelscope.hub.errors import NotExistError | |||||
from modelscope.hub.file_download import model_file_download | |||||
from modelscope.hub.repository import Repository | |||||
from modelscope.utils.logger import get_logger | |||||
logger = get_logger() | |||||
logger.setLevel('DEBUG') | |||||
USER_NAME = 'maasadmin' | |||||
PASSWORD = '12345678' | |||||
model_chinese_name = '达摩卡通化模型' | |||||
model_org = 'unittest' | |||||
DEFAULT_GIT_PATH = 'git' | |||||
download_model_file_name = 'mnist-12.onnx' | |||||
def delete_credential(): | |||||
path_credential = expanduser('~/.modelscope/credentials') | |||||
shutil.rmtree(path_credential) | |||||
def delete_stored_git_credential(user): | |||||
credential_path = expanduser('~/.git-credentials') | |||||
if os.path.exists(credential_path): | |||||
with open(credential_path, 'r+') as f: | |||||
lines = f.readlines() | |||||
for line in lines: | |||||
if user in line: | |||||
lines.remove(line) | |||||
f.seek(0) | |||||
f.write(''.join(lines)) | |||||
f.truncate() | |||||
class HubRepositoryTest(unittest.TestCase): | |||||
def setUp(self): | |||||
self.api = HubApi() | |||||
# note this is temporary before official account management is ready | |||||
self.api.login(USER_NAME, PASSWORD) | |||||
self.model_name = uuid.uuid4().hex | |||||
self.model_id = '%s/%s' % (model_org, self.model_name) | |||||
self.api.create_model( | |||||
model_id=self.model_id, | |||||
chinese_name=model_chinese_name, | |||||
visibility=5, # 1-private, 5-public | |||||
license='apache-2.0') | |||||
temporary_dir = tempfile.mkdtemp() | |||||
self.model_dir = os.path.join(temporary_dir, self.model_name) | |||||
def tearDown(self): | |||||
self.api.delete_model(model_id=self.model_id) | |||||
def test_clone_repo(self): | |||||
Repository(self.model_dir, clone_from=self.model_id) | |||||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||||
def test_clone_public_model_without_token(self): | |||||
delete_credential() | |||||
delete_stored_git_credential(USER_NAME) | |||||
Repository(self.model_dir, clone_from=self.model_id) | |||||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||||
self.api.login(USER_NAME, PASSWORD) # re-login for delete | |||||
def test_push_all(self): | |||||
repo = Repository(self.model_dir, clone_from=self.model_id) | |||||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||||
os.chdir(self.model_dir) | |||||
os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | |||||
os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | |||||
repo.push('test', all_files=True) | |||||
add1 = model_file_download(self.model_id, 'add1.py') | |||||
assert os.path.exists(add1) | |||||
add2 = model_file_download(self.model_id, 'add2.py') | |||||
assert os.path.exists(add2) | |||||
def test_push_files(self): | |||||
repo = Repository(self.model_dir, clone_from=self.model_id) | |||||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||||
os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | |||||
os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | |||||
os.system("echo '333'>%s" % os.path.join(self.model_dir, 'add3.py')) | |||||
repo.push('test', files=['add1.py', 'add2.py'], all_files=False) | |||||
add1 = model_file_download(self.model_id, 'add1.py') | |||||
assert os.path.exists(add1) | |||||
add2 = model_file_download(self.model_id, 'add2.py') | |||||
assert os.path.exists(add2) | |||||
with self.assertRaises(NotExistError) as cm: | |||||
model_file_download(self.model_id, 'add3.py') | |||||
print(cm.exception) | |||||
if __name__ == '__main__': | |||||
unittest.main() |