@@ -9,9 +9,10 @@ from typing import List, Optional, Tuple, Union | |||
import requests | |||
from modelscope.utils.logger import get_logger | |||
from .constants import LOGGER_NAME | |||
from .constants import MODELSCOPE_URL_SCHEME | |||
from .errors import NotExistError, is_ok, raise_on_error | |||
from .utils.utils import get_endpoint, model_id_to_group_owner_name | |||
from .utils.utils import (get_endpoint, get_gitlab_domain, | |||
model_id_to_group_owner_name) | |||
logger = get_logger() | |||
@@ -40,9 +41,6 @@ class HubApi: | |||
<Tip> | |||
You only have to login once within 30 days. | |||
</Tip> | |||
TODO: handle cookies expire | |||
""" | |||
path = f'{self.endpoint}/api/v1/login' | |||
r = requests.post( | |||
@@ -94,14 +92,14 @@ class HubApi: | |||
'Path': owner_or_group, | |||
'Name': name, | |||
'ChineseName': chinese_name, | |||
'Visibility': visibility, | |||
'Visibility': visibility, # server check | |||
'License': license | |||
}, | |||
cookies=cookies) | |||
r.raise_for_status() | |||
raise_on_error(r.json()) | |||
d = r.json() | |||
return d['Data']['Name'] | |||
model_repo_url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}' | |||
return model_repo_url | |||
def delete_model(self, model_id): | |||
"""_summary_ | |||
@@ -209,25 +207,37 @@ class HubApi: | |||
class ModelScopeConfig: | |||
path_credential = expanduser('~/.modelscope/credentials') | |||
os.makedirs(path_credential, exist_ok=True) | |||
@classmethod | |||
def make_sure_credential_path_exist(cls): | |||
os.makedirs(cls.path_credential, exist_ok=True) | |||
@classmethod | |||
def save_cookies(cls, cookies: CookieJar): | |||
cls.make_sure_credential_path_exist() | |||
with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f: | |||
pickle.dump(cookies, f) | |||
@classmethod | |||
def get_cookies(cls): | |||
try: | |||
with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f: | |||
return pickle.load(f) | |||
cookies_path = os.path.join(cls.path_credential, 'cookies') | |||
with open(cookies_path, 'rb') as f: | |||
cookies = pickle.load(f) | |||
for cookie in cookies: | |||
if cookie.is_expired(): | |||
logger.warn('Auth is expored, please re-login') | |||
return None | |||
return cookies | |||
except FileNotFoundError: | |||
logger.warn("Auth token does not exist, you'll get authentication \ | |||
error when downloading private model files. Please login first" | |||
) | |||
logger.warn( | |||
"Auth token does not exist, you'll get authentication error when downloading \ | |||
private model files. Please login first") | |||
return None | |||
@classmethod | |||
def save_token(cls, token: str): | |||
cls.make_sure_credential_path_exist() | |||
with open(os.path.join(cls.path_credential, 'token'), 'w+') as f: | |||
f.write(token) | |||
@@ -6,6 +6,10 @@ class RequestError(Exception): | |||
pass | |||
class GitError(Exception): | |||
pass | |||
def is_ok(rsp): | |||
""" Check the request is ok | |||
@@ -1,82 +1,161 @@ | |||
from threading import local | |||
from tkinter.messagebox import NO | |||
from typing import Union | |||
import subprocess | |||
from typing import List | |||
from xmlrpc.client import Boolean | |||
from modelscope.utils.logger import get_logger | |||
from .constants import LOGGER_NAME | |||
from .utils._subprocess import run_subprocess | |||
from .errors import GitError | |||
logger = get_logger | |||
logger = get_logger() | |||
def git_clone( | |||
local_dir: str, | |||
repo_url: str, | |||
): | |||
# TODO: use "git clone" or "git lfs clone" according to git version | |||
# TODO: print stderr when subprocess fails | |||
run_subprocess( | |||
f'git clone {repo_url}'.split(), | |||
local_dir, | |||
True, | |||
) | |||
class Singleton(type): | |||
_instances = {} | |||
def __call__(cls, *args, **kwargs): | |||
if cls not in cls._instances: | |||
cls._instances[cls] = super(Singleton, | |||
cls).__call__(*args, **kwargs) | |||
return cls._instances[cls] | |||
def git_checkout( | |||
local_dir: str, | |||
revsion: str, | |||
): | |||
run_subprocess(f'git checkout {revsion}'.split(), local_dir) | |||
def git_add(local_dir: str, ): | |||
run_subprocess( | |||
'git add .'.split(), | |||
local_dir, | |||
True, | |||
) | |||
def git_commit(local_dir: str, commit_message: str): | |||
run_subprocess( | |||
'git commit -v -m'.split() + [commit_message], | |||
local_dir, | |||
True, | |||
) | |||
def git_push(local_dir: str, branch: str): | |||
# check current branch | |||
cur_branch = git_current_branch(local_dir) | |||
if cur_branch != branch: | |||
logger.error( | |||
"You're trying to push to a different branch, please double check") | |||
return | |||
run_subprocess( | |||
f'git push origin {branch}'.split(), | |||
local_dir, | |||
True, | |||
) | |||
def git_current_branch(local_dir: str) -> Union[str, None]: | |||
""" | |||
Get current branch name | |||
Args: | |||
local_dir(`str`): local model repo directory | |||
Returns | |||
branch name you're currently on | |||
class GitCommandWrapper(metaclass=Singleton): | |||
"""Some git operation wrapper | |||
""" | |||
try: | |||
process = run_subprocess( | |||
'git rev-parse --abbrev-ref HEAD'.split(), | |||
local_dir, | |||
True, | |||
) | |||
return str(process.stdout).strip() | |||
except Exception as e: | |||
raise e | |||
default_git_path = 'git' # The default git command line | |||
def __init__(self, path: str = None): | |||
self.git_path = path or self.default_git_path | |||
def _run_git_command(self, *args) -> subprocess.CompletedProcess: | |||
"""Run git command, if command return 0, return subprocess.response | |||
otherwise raise GitError, message is stdout and stderr. | |||
Raises: | |||
GitError: Exception with stdout and stderr. | |||
Returns: | |||
subprocess.CompletedProcess: the command response | |||
""" | |||
logger.info(' '.join(args)) | |||
response = subprocess.run( | |||
[self.git_path, *args], | |||
stdout=subprocess.PIPE, | |||
stderr=subprocess.PIPE) # compatible for python3.6 | |||
try: | |||
response.check_returncode() | |||
return response | |||
except subprocess.CalledProcessError as error: | |||
raise GitError( | |||
'stdout: %s, stderr: %s' % | |||
(response.stdout.decode('utf8'), error.stderr.decode('utf8'))) | |||
def _add_token(self, token: str, url: str): | |||
if token: | |||
if '//oauth2' not in url: | |||
url = url.replace('//', '//oauth2:%s@' % token) | |||
return url | |||
def remove_token_from_url(self, url: str): | |||
if url and '//oauth2' in url: | |||
start_index = url.find('oauth2') | |||
end_index = url.find('@') | |||
url = url[:start_index] + url[end_index + 1:] | |||
return url | |||
def is_lfs_installed(self): | |||
cmd = ['lfs', 'env'] | |||
try: | |||
self._run_git_command(*cmd) | |||
return True | |||
except GitError: | |||
return False | |||
def clone(self, | |||
repo_base_dir: str, | |||
token: str, | |||
url: str, | |||
repo_name: str, | |||
branch: str = None): | |||
""" git clone command wrapper. | |||
For public project, token can None, private repo, there must token. | |||
Args: | |||
repo_base_dir (str): The local base dir, the repository will be clone to local_dir/repo_name | |||
token (str): The git token, must be provided for private project. | |||
url (str): The remote url | |||
repo_name (str): The local repository path name. | |||
branch (str, optional): _description_. Defaults to None. | |||
""" | |||
url = self._add_token(token, url) | |||
if branch: | |||
clone_args = '-C %s clone %s %s --branch %s' % (repo_base_dir, url, | |||
repo_name, branch) | |||
else: | |||
clone_args = '-C %s clone %s' % (repo_base_dir, url) | |||
logger.debug(clone_args) | |||
clone_args = clone_args.split(' ') | |||
response = self._run_git_command(*clone_args) | |||
logger.info(response.stdout.decode('utf8')) | |||
return response | |||
def add(self, | |||
repo_dir: str, | |||
files: List[str] = list(), | |||
all_files: bool = False): | |||
if all_files: | |||
add_args = '-C %s add -A' % repo_dir | |||
elif len(files) > 0: | |||
files_str = ' '.join(files) | |||
add_args = '-C %s add %s' % (repo_dir, files_str) | |||
add_args = add_args.split(' ') | |||
rsp = self._run_git_command(*add_args) | |||
logger.info(rsp.stdout.decode('utf8')) | |||
return rsp | |||
def commit(self, repo_dir: str, message: str): | |||
"""Run git commit command | |||
Args: | |||
message (str): commit message. | |||
""" | |||
commit_args = ['-C', '%s' % repo_dir, 'commit', '-m', "'%s'" % message] | |||
rsp = self._run_git_command(*commit_args) | |||
logger.info(rsp.stdout.decode('utf8')) | |||
return rsp | |||
def checkout(self, repo_dir: str, revision: str): | |||
cmds = ['-C', '%s' % repo_dir, 'checkout', '%s' % revision] | |||
return self._run_git_command(*cmds) | |||
def new_branch(self, repo_dir: str, revision: str): | |||
cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision] | |||
return self._run_git_command(*cmds) | |||
def pull(self, repo_dir: str): | |||
cmds = ['-C', repo_dir, 'pull'] | |||
return self._run_git_command(*cmds) | |||
def push(self, | |||
repo_dir: str, | |||
token: str, | |||
url: str, | |||
local_branch: str, | |||
remote_branch: str, | |||
force: bool = False): | |||
url = self._add_token(token, url) | |||
push_args = '-C %s push %s %s:%s' % (repo_dir, url, local_branch, | |||
remote_branch) | |||
if force: | |||
push_args += ' -f' | |||
push_args = push_args.split(' ') | |||
rsp = self._run_git_command(*push_args) | |||
logger.info(rsp.stdout.decode('utf8')) | |||
return rsp | |||
def get_repo_remote_url(self, repo_dir: str): | |||
cmd_args = '-C %s config --get remote.origin.url' % repo_dir | |||
cmd_args = cmd_args.split(' ') | |||
rsp = self._run_git_command(*cmd_args) | |||
url = rsp.stdout.decode('utf8') | |||
return url.strip() |
@@ -1,173 +1,97 @@ | |||
import os | |||
import subprocess | |||
from pathlib import Path | |||
from typing import Optional, Union | |||
from typing import List, Optional | |||
from modelscope.hub.errors import GitError | |||
from modelscope.utils.logger import get_logger | |||
from .api import ModelScopeConfig | |||
from .constants import MODELSCOPE_URL_SCHEME | |||
from .git import git_add, git_checkout, git_clone, git_commit, git_push | |||
from .utils._subprocess import run_subprocess | |||
from .git import GitCommandWrapper | |||
from .utils.utils import get_gitlab_domain | |||
logger = get_logger() | |||
class Repository: | |||
"""Representation local model git repository. | |||
""" | |||
def __init__( | |||
self, | |||
local_dir: str, | |||
clone_from: Optional[str] = None, | |||
auth_token: Optional[str] = None, | |||
private: Optional[bool] = False, | |||
model_dir: str, | |||
clone_from: str, | |||
revision: Optional[str] = 'master', | |||
auth_token: Optional[str] = None, | |||
git_path: Optional[str] = None, | |||
): | |||
""" | |||
Instantiate a Repository object by cloning the remote ModelScopeHub repo | |||
Args: | |||
local_dir(`str`): | |||
local directory to store the model files | |||
clone_from(`Optional[str] = None`): | |||
model_dir(`str`): | |||
The model root directory. | |||
clone_from: | |||
model id in ModelScope-hub from which git clone | |||
You should ignore this parameter when `local_dir` is already a git repo | |||
auth_token(`Optional[str]`): | |||
token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter | |||
as the token is already saved when you login the first time | |||
private(`Optional[bool]`): | |||
whether the model is private, default to False | |||
revision(`Optional[str]`): | |||
revision of the model you want to clone from. Can be any of a branch, tag or commit hash | |||
auth_token(`Optional[str]`): | |||
token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter | |||
as the token is already saved when you login the first time, if None, we will use saved token. | |||
git_path:(`Optional[str]`): | |||
The git command line path, if None, we use 'git' | |||
""" | |||
logger.info('Instantiating Repository object...') | |||
# Create local directory if not exist | |||
os.makedirs(local_dir, exist_ok=True) | |||
self.local_dir = os.path.join(os.getcwd(), local_dir) | |||
self.private = private | |||
# Check git and git-lfs installation | |||
self.check_git_versions() | |||
# Retrieve auth token | |||
if not private and isinstance(auth_token, str): | |||
logger.warning( | |||
'cloning a public repo with a token, which will be ignored') | |||
self.token = None | |||
self.model_dir = model_dir | |||
self.model_base_dir = os.path.dirname(model_dir) | |||
self.model_repo_name = os.path.basename(model_dir) | |||
if auth_token: | |||
self.auth_token = auth_token | |||
else: | |||
if isinstance(auth_token, str): | |||
self.token = auth_token | |||
else: | |||
self.token = ModelScopeConfig.get_token() | |||
if self.token is None: | |||
raise EnvironmentError( | |||
'Token does not exist, the clone will fail for private repo.' | |||
'Please login first.') | |||
# git clone | |||
if clone_from is not None: | |||
self.model_id = clone_from | |||
logger.info('cloning model repo to %s ...', self.local_dir) | |||
git_clone(self.local_dir, self.get_repo_url()) | |||
else: | |||
if is_git_repo(self.local_dir): | |||
logger.debug('[Repository] is a valid git repo') | |||
else: | |||
raise ValueError( | |||
'If not specifying `clone_from`, you need to pass Repository a' | |||
' valid git clone.') | |||
# git checkout | |||
if isinstance(revision, str) and revision != 'master': | |||
git_checkout(revision) | |||
def push_to_hub(self, | |||
commit_message: str, | |||
revision: Optional[str] = 'master'): | |||
""" | |||
Push changes changes to hub | |||
Args: | |||
commit_message(`str`): | |||
commit message describing the changes, it's mandatory | |||
revision(`Optional[str]`): | |||
remote branch you want to push to, default to `master` | |||
<Tip> | |||
The function complains when local and remote branch are different, please be careful | |||
</Tip> | |||
""" | |||
git_add(self.local_dir) | |||
git_commit(self.local_dir, commit_message) | |||
logger.info('Pushing changes to repo...') | |||
git_push(self.local_dir, revision) | |||
# TODO: if git push fails, how to retry? | |||
def check_git_versions(self): | |||
""" | |||
Checks that `git` and `git-lfs` can be run. | |||
Raises: | |||
`EnvironmentError`: if `git` or `git-lfs` are not installed. | |||
""" | |||
try: | |||
git_version = run_subprocess('git --version'.split(), | |||
self.local_dir).stdout.strip() | |||
except FileNotFoundError: | |||
raise EnvironmentError( | |||
'Looks like you do not have git installed, please install.') | |||
self.auth_token = ModelScopeConfig.get_token() | |||
git_wrapper = GitCommandWrapper() | |||
if not git_wrapper.is_lfs_installed(): | |||
logger.error('git lfs is not installed, please install.') | |||
self.git_wrapper = GitCommandWrapper(git_path) | |||
os.makedirs(self.model_dir, exist_ok=True) | |||
url = self._get_model_id_url(clone_from) | |||
if os.listdir(self.model_dir): # directory not empty. | |||
remote_url = self._get_remote_url() | |||
remote_url = self.git_wrapper.remove_token_from_url(remote_url) | |||
if remote_url and remote_url == url: # need not clone again | |||
return | |||
self.git_wrapper.clone(self.model_base_dir, self.auth_token, url, | |||
self.model_repo_name, revision) | |||
def _get_model_id_url(self, model_id): | |||
url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}' | |||
return url | |||
def _get_remote_url(self): | |||
try: | |||
lfs_version = run_subprocess('git-lfs --version'.split(), | |||
self.local_dir).stdout.strip() | |||
except FileNotFoundError: | |||
raise EnvironmentError( | |||
'Looks like you do not have git-lfs installed, please install.' | |||
' You can install from https://git-lfs.github.com/.' | |||
' Then run `git lfs install` (you only have to do this once).') | |||
logger.info(git_version + '\n' + lfs_version) | |||
def get_repo_url(self) -> str: | |||
""" | |||
Get repo url to clone, according whether the repo is private or not | |||
remote = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||
except GitError: | |||
remote = None | |||
return remote | |||
def push(self, | |||
commit_message: str, | |||
files: List[str] = list(), | |||
all_files: bool = False, | |||
branch: Optional[str] = 'master', | |||
force: bool = False): | |||
"""Push local to remote, this method will do. | |||
git add | |||
git commit | |||
git push | |||
Args: | |||
commit_message (str): commit message | |||
revision (Optional[str], optional): which branch to push. Defaults to 'master'. | |||
""" | |||
url = None | |||
if self.private: | |||
url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}' | |||
else: | |||
url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}' | |||
if not url: | |||
raise ValueError( | |||
'Empty repo url, please check clone_from parameter') | |||
logger.debug('url to clone: %s', str(url)) | |||
return url | |||
def is_git_repo(folder: Union[str, Path]) -> bool: | |||
""" | |||
Check if the folder is the root or part of a git repository | |||
Args: | |||
folder (`str`): | |||
The folder in which to run the command. | |||
Returns: | |||
`bool`: `True` if the repository is part of a repository, `False` | |||
otherwise. | |||
""" | |||
folder_exists = os.path.exists(os.path.join(folder, '.git')) | |||
git_branch = subprocess.run( | |||
'git branch'.split(), | |||
cwd=folder, | |||
stdout=subprocess.PIPE, | |||
stderr=subprocess.PIPE) | |||
return folder_exists and git_branch.returncode == 0 | |||
url = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||
self.git_wrapper.add(self.model_dir, files, all_files) | |||
self.git_wrapper.commit(self.model_dir, commit_message) | |||
self.git_wrapper.push( | |||
repo_dir=self.model_dir, | |||
token=self.auth_token, | |||
url=url, | |||
local_branch=branch, | |||
remote_branch=branch) |
@@ -1,40 +0,0 @@ | |||
import subprocess | |||
from typing import List | |||
def run_subprocess(command: List[str], | |||
folder: str, | |||
check=True, | |||
**kwargs) -> subprocess.CompletedProcess: | |||
""" | |||
Method to run subprocesses. Calling this will capture the `stderr` and `stdout`, | |||
please call `subprocess.run` manually in case you would like for them not to | |||
be captured. | |||
Args: | |||
command (`List[str]`): | |||
The command to execute as a list of strings. | |||
folder (`str`): | |||
The folder in which to run the command. | |||
check (`bool`, *optional*, defaults to `True`): | |||
Setting `check` to `True` will raise a `subprocess.CalledProcessError` | |||
when the subprocess has a non-zero exit code. | |||
kwargs (`Dict[str]`): | |||
Keyword arguments to be passed to the `subprocess.run` underlying command. | |||
Returns: | |||
`subprocess.CompletedProcess`: The completed process. | |||
""" | |||
if isinstance(command, str): | |||
raise ValueError( | |||
'`run_subprocess` should be called with a list of strings.') | |||
return subprocess.run( | |||
command, | |||
stderr=subprocess.PIPE, | |||
stdout=subprocess.PIPE, | |||
check=check, | |||
encoding='utf-8', | |||
cwd=folder, | |||
**kwargs, | |||
) |
@@ -1,14 +1,13 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import subprocess | |||
import tempfile | |||
import unittest | |||
import uuid | |||
from modelscope.hub.api import HubApi, ModelScopeConfig | |||
from modelscope.hub.api import HubApi | |||
from modelscope.hub.file_download import model_file_download | |||
from modelscope.hub.repository import Repository | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.hub.utils.utils import get_gitlab_domain | |||
USER_NAME = 'maasadmin' | |||
PASSWORD = '12345678' | |||
@@ -17,40 +16,7 @@ model_chinese_name = '达摩卡通化模型' | |||
model_org = 'unittest' | |||
DEFAULT_GIT_PATH = 'git' | |||
class GitError(Exception): | |||
pass | |||
# TODO make thest git operation to git library after merge code. | |||
def run_git_command(git_path, *args) -> subprocess.CompletedProcess: | |||
response = subprocess.run([git_path, *args], capture_output=True) | |||
try: | |||
response.check_returncode() | |||
return response.stdout.decode('utf8') | |||
except subprocess.CalledProcessError as error: | |||
raise GitError(error.stderr.decode('utf8')) | |||
# for public project, token can None, private repo, there must token. | |||
def clone(local_dir: str, token: str, url: str): | |||
url = url.replace('//', '//oauth2:%s@' % token) | |||
clone_args = '-C %s clone %s' % (local_dir, url) | |||
clone_args = clone_args.split(' ') | |||
stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args) | |||
print('stdout: %s' % stdout) | |||
def push(local_dir: str, token: str, url: str): | |||
url = url.replace('//', '//oauth2:%s@' % token) | |||
push_args = '-C %s push %s' % (local_dir, url) | |||
push_args = push_args.split(' ') | |||
stdout = run_git_command(DEFAULT_GIT_PATH, *push_args) | |||
print('stdout: %s' % stdout) | |||
sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||
download_model_file_name = 'mnist-12.onnx' | |||
download_model_file_name = 'test.bin' | |||
class HubOperationTest(unittest.TestCase): | |||
@@ -67,6 +33,13 @@ class HubOperationTest(unittest.TestCase): | |||
chinese_name=model_chinese_name, | |||
visibility=5, # 1-private, 5-public | |||
license='apache-2.0') | |||
temporary_dir = tempfile.mkdtemp() | |||
self.model_dir = os.path.join(temporary_dir, self.model_name) | |||
repo = Repository(self.model_dir, clone_from=self.model_id) | |||
os.chdir(self.model_dir) | |||
os.system("echo 'testtest'>%s" | |||
% os.path.join(self.model_dir, 'test.bin')) | |||
repo.push('add model', all_files=True) | |||
def tearDown(self): | |||
os.chdir(self.old_cwd) | |||
@@ -83,43 +56,10 @@ class HubOperationTest(unittest.TestCase): | |||
else: | |||
raise | |||
# Note that this can be done via git operation once model repo | |||
# has been created. Git-Op is the RECOMMENDED model upload approach | |||
def test_model_upload(self): | |||
url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||
print(url) | |||
temporary_dir = tempfile.mkdtemp() | |||
os.chdir(temporary_dir) | |||
cmd_args = 'clone %s' % url | |||
cmd_args = cmd_args.split(' ') | |||
out = run_git_command('git', *cmd_args) | |||
print(out) | |||
repo_dir = os.path.join(temporary_dir, self.model_name) | |||
os.chdir(repo_dir) | |||
os.system('touch file1') | |||
os.system('git add file1') | |||
os.system("git commit -m 'Test'") | |||
token = ModelScopeConfig.get_token() | |||
push(repo_dir, token, url) | |||
def test_download_single_file(self): | |||
url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||
print(url) | |||
temporary_dir = tempfile.mkdtemp() | |||
os.chdir(temporary_dir) | |||
os.system('git clone %s' % url) | |||
repo_dir = os.path.join(temporary_dir, self.model_name) | |||
os.chdir(repo_dir) | |||
os.system('wget %s' % sample_model_url) | |||
os.system('git add .') | |||
os.system("git commit -m 'Add file'") | |||
token = ModelScopeConfig.get_token() | |||
push(repo_dir, token, url) | |||
assert os.path.exists( | |||
os.path.join(temporary_dir, self.model_name, | |||
download_model_file_name)) | |||
downloaded_file = model_file_download( | |||
model_id=self.model_id, file_path=download_model_file_name) | |||
assert os.path.exists(downloaded_file) | |||
mdtime1 = os.path.getmtime(downloaded_file) | |||
# download again | |||
downloaded_file = model_file_download( | |||
@@ -128,18 +68,6 @@ class HubOperationTest(unittest.TestCase): | |||
assert mdtime1 == mdtime2 | |||
def test_snapshot_download(self): | |||
url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||
print(url) | |||
temporary_dir = tempfile.mkdtemp() | |||
os.chdir(temporary_dir) | |||
os.system('git clone %s' % url) | |||
repo_dir = os.path.join(temporary_dir, self.model_name) | |||
os.chdir(repo_dir) | |||
os.system('wget %s' % sample_model_url) | |||
os.system('git add .') | |||
os.system("git commit -m 'Add file'") | |||
token = ModelScopeConfig.get_token() | |||
push(repo_dir, token, url) | |||
snapshot_path = snapshot_download(model_id=self.model_id) | |||
downloaded_file_path = os.path.join(snapshot_path, | |||
download_model_file_name) | |||
@@ -0,0 +1,76 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import tempfile | |||
import unittest | |||
import uuid | |||
from modelscope.hub.api import HubApi | |||
from modelscope.hub.errors import GitError | |||
from modelscope.hub.repository import Repository | |||
USER_NAME = 'maasadmin' | |||
PASSWORD = '12345678' | |||
USER_NAME2 = 'sdkdev' | |||
model_chinese_name = '达摩卡通化模型' | |||
model_org = 'unittest' | |||
DEFAULT_GIT_PATH = 'git' | |||
sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||
download_model_file_name = 'mnist-12.onnx' | |||
class HubPrivateRepositoryTest(unittest.TestCase): | |||
def setUp(self): | |||
self.old_cwd = os.getcwd() | |||
self.api = HubApi() | |||
# note this is temporary before official account management is ready | |||
self.token, _ = self.api.login(USER_NAME, PASSWORD) | |||
self.model_name = uuid.uuid4().hex | |||
self.model_id = '%s/%s' % (model_org, self.model_name) | |||
self.api.create_model( | |||
model_id=self.model_id, | |||
chinese_name=model_chinese_name, | |||
visibility=1, # 1-private, 5-public | |||
license='apache-2.0') | |||
def tearDown(self): | |||
self.api.login(USER_NAME, PASSWORD) | |||
os.chdir(self.old_cwd) | |||
self.api.delete_model(model_id=self.model_id) | |||
def test_clone_private_repo_no_permission(self): | |||
token, _ = self.api.login(USER_NAME2, PASSWORD) | |||
temporary_dir = tempfile.mkdtemp() | |||
local_dir = os.path.join(temporary_dir, self.model_name) | |||
with self.assertRaises(GitError) as cm: | |||
Repository(local_dir, clone_from=self.model_id, auth_token=token) | |||
print(cm.exception) | |||
assert not os.path.exists(os.path.join(local_dir, 'README.md')) | |||
def test_clone_private_repo_has_permission(self): | |||
temporary_dir = tempfile.mkdtemp() | |||
local_dir = os.path.join(temporary_dir, self.model_name) | |||
repo1 = Repository( | |||
local_dir, clone_from=self.model_id, auth_token=self.token) | |||
print(repo1.model_dir) | |||
assert os.path.exists(os.path.join(local_dir, 'README.md')) | |||
def test_initlize_repo_multiple_times(self): | |||
temporary_dir = tempfile.mkdtemp() | |||
local_dir = os.path.join(temporary_dir, self.model_name) | |||
repo1 = Repository( | |||
local_dir, clone_from=self.model_id, auth_token=self.token) | |||
print(repo1.model_dir) | |||
assert os.path.exists(os.path.join(local_dir, 'README.md')) | |||
repo2 = Repository( | |||
local_dir, clone_from=self.model_id, | |||
auth_token=self.token) # skip clone | |||
print(repo2.model_dir) | |||
assert repo1.model_dir == repo2.model_dir | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,107 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import shutil | |||
import tempfile | |||
import time | |||
import unittest | |||
import uuid | |||
from os.path import expanduser | |||
from requests import delete | |||
from modelscope.hub.api import HubApi | |||
from modelscope.hub.errors import NotExistError | |||
from modelscope.hub.file_download import model_file_download | |||
from modelscope.hub.repository import Repository | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
logger.setLevel('DEBUG') | |||
USER_NAME = 'maasadmin' | |||
PASSWORD = '12345678' | |||
model_chinese_name = '达摩卡通化模型' | |||
model_org = 'unittest' | |||
DEFAULT_GIT_PATH = 'git' | |||
download_model_file_name = 'mnist-12.onnx' | |||
def delete_credential(): | |||
path_credential = expanduser('~/.modelscope/credentials') | |||
shutil.rmtree(path_credential) | |||
def delete_stored_git_credential(user): | |||
credential_path = expanduser('~/.git-credentials') | |||
if os.path.exists(credential_path): | |||
with open(credential_path, 'r+') as f: | |||
lines = f.readlines() | |||
for line in lines: | |||
if user in line: | |||
lines.remove(line) | |||
f.seek(0) | |||
f.write(''.join(lines)) | |||
f.truncate() | |||
class HubRepositoryTest(unittest.TestCase): | |||
def setUp(self): | |||
self.api = HubApi() | |||
# note this is temporary before official account management is ready | |||
self.api.login(USER_NAME, PASSWORD) | |||
self.model_name = uuid.uuid4().hex | |||
self.model_id = '%s/%s' % (model_org, self.model_name) | |||
self.api.create_model( | |||
model_id=self.model_id, | |||
chinese_name=model_chinese_name, | |||
visibility=5, # 1-private, 5-public | |||
license='apache-2.0') | |||
temporary_dir = tempfile.mkdtemp() | |||
self.model_dir = os.path.join(temporary_dir, self.model_name) | |||
def tearDown(self): | |||
self.api.delete_model(model_id=self.model_id) | |||
def test_clone_repo(self): | |||
Repository(self.model_dir, clone_from=self.model_id) | |||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
def test_clone_public_model_without_token(self): | |||
delete_credential() | |||
delete_stored_git_credential(USER_NAME) | |||
Repository(self.model_dir, clone_from=self.model_id) | |||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
self.api.login(USER_NAME, PASSWORD) # re-login for delete | |||
def test_push_all(self): | |||
repo = Repository(self.model_dir, clone_from=self.model_id) | |||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
os.chdir(self.model_dir) | |||
os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | |||
os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | |||
repo.push('test', all_files=True) | |||
add1 = model_file_download(self.model_id, 'add1.py') | |||
assert os.path.exists(add1) | |||
add2 = model_file_download(self.model_id, 'add2.py') | |||
assert os.path.exists(add2) | |||
def test_push_files(self): | |||
repo = Repository(self.model_dir, clone_from=self.model_id) | |||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | |||
os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | |||
os.system("echo '333'>%s" % os.path.join(self.model_dir, 'add3.py')) | |||
repo.push('test', files=['add1.py', 'add2.py'], all_files=False) | |||
add1 = model_file_download(self.model_id, 'add1.py') | |||
assert os.path.exists(add1) | |||
add2 = model_file_download(self.model_id, 'add2.py') | |||
assert os.path.exists(add2) | |||
with self.assertRaises(NotExistError) as cm: | |||
model_file_download(self.model_id, 'add3.py') | |||
print(cm.exception) | |||
if __name__ == '__main__': | |||
unittest.main() |