@@ -24,6 +24,8 @@ do | |||||
-v /home/admin/pre-commit:/home/admin/pre-commit \ | -v /home/admin/pre-commit:/home/admin/pre-commit \ | ||||
-e CI_TEST=True \ | -e CI_TEST=True \ | ||||
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | ||||
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||||
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||||
--workdir=$CODE_DIR_IN_CONTAINER \ | --workdir=$CODE_DIR_IN_CONTAINER \ | ||||
--net host \ | --net host \ | ||||
${IMAGE_NAME}:${IMAGE_VERSION} \ | ${IMAGE_NAME}:${IMAGE_VERSION} \ | ||||
@@ -3,12 +3,19 @@ import pickle | |||||
import shutil | import shutil | ||||
import subprocess | import subprocess | ||||
from collections import defaultdict | from collections import defaultdict | ||||
from http import HTTPStatus | |||||
from http.cookiejar import CookieJar | from http.cookiejar import CookieJar | ||||
from os.path import expanduser | from os.path import expanduser | ||||
from typing import List, Optional, Tuple, Union | from typing import List, Optional, Tuple, Union | ||||
import requests | import requests | ||||
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||||
API_RESPONSE_FIELD_EMAIL, | |||||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | |||||
API_RESPONSE_FIELD_MESSAGE, | |||||
API_RESPONSE_FIELD_USERNAME, | |||||
DEFAULT_CREDENTIALS_PATH) | |||||
from modelscope.msdatasets.config import (DOWNLOADED_DATASETS_PATH, | from modelscope.msdatasets.config import (DOWNLOADED_DATASETS_PATH, | ||||
HUB_DATASET_ENDPOINT) | HUB_DATASET_ENDPOINT) | ||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | ||||
@@ -32,16 +39,13 @@ class HubApi: | |||||
def login( | def login( | ||||
self, | self, | ||||
user_name: str, | |||||
password: str, | |||||
access_token: str, | |||||
) -> tuple(): | ) -> tuple(): | ||||
""" | """ | ||||
Login with username and password | Login with username and password | ||||
Args: | Args: | ||||
user_name(`str`): user name on modelscope | |||||
password(`str`): password | |||||
access_token(`str`): user access token on modelscope. | |||||
Returns: | Returns: | ||||
cookies: to authenticate yourself to ModelScope open-api | cookies: to authenticate yourself to ModelScope open-api | ||||
gitlab token: to access private repos | gitlab token: to access private repos | ||||
@@ -51,24 +55,23 @@ class HubApi: | |||||
</Tip> | </Tip> | ||||
""" | """ | ||||
path = f'{self.endpoint}/api/v1/login' | path = f'{self.endpoint}/api/v1/login' | ||||
r = requests.post( | |||||
path, json={ | |||||
'username': user_name, | |||||
'password': password | |||||
}) | |||||
r = requests.post(path, json={'AccessToken': access_token}) | |||||
r.raise_for_status() | r.raise_for_status() | ||||
d = r.json() | d = r.json() | ||||
raise_on_error(d) | raise_on_error(d) | ||||
token = d['Data']['AccessToken'] | |||||
token = d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_GIT_ACCESS_TOKEN] | |||||
cookies = r.cookies | cookies = r.cookies | ||||
# save token and cookie | # save token and cookie | ||||
ModelScopeConfig.save_token(token) | ModelScopeConfig.save_token(token) | ||||
ModelScopeConfig.save_cookies(cookies) | ModelScopeConfig.save_cookies(cookies) | ||||
ModelScopeConfig.write_to_git_credential(user_name, password) | |||||
ModelScopeConfig.save_user_info( | |||||
d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_USERNAME], | |||||
d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_EMAIL]) | |||||
return d['Data']['AccessToken'], cookies | |||||
return d[API_RESPONSE_FIELD_DATA][ | |||||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN], cookies | |||||
def create_model( | def create_model( | ||||
self, | self, | ||||
@@ -161,11 +164,11 @@ class HubApi: | |||||
r = requests.get(path, cookies=cookies) | r = requests.get(path, cookies=cookies) | ||||
handle_http_response(r, logger, cookies, model_id) | handle_http_response(r, logger, cookies, model_id) | ||||
if r.status_code == 200: | |||||
if r.status_code == HTTPStatus.OK: | |||||
if is_ok(r.json()): | if is_ok(r.json()): | ||||
return r.json()['Data'] | |||||
return r.json()[API_RESPONSE_FIELD_DATA] | |||||
else: | else: | ||||
raise NotExistError(r.json()['Message']) | |||||
raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
else: | else: | ||||
r.raise_for_status() | r.raise_for_status() | ||||
@@ -189,12 +192,12 @@ class HubApi: | |||||
data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % | data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % | ||||
(owner_or_group, page_number, page_size)) | (owner_or_group, page_number, page_size)) | ||||
handle_http_response(r, logger, cookies, 'list_model') | handle_http_response(r, logger, cookies, 'list_model') | ||||
if r.status_code == 200: | |||||
if r.status_code == HTTPStatus.OK: | |||||
if is_ok(r.json()): | if is_ok(r.json()): | ||||
data = r.json()['Data'] | |||||
data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
return data | return data | ||||
else: | else: | ||||
raise RequestError(r.json()['Message']) | |||||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
else: | else: | ||||
r.raise_for_status() | r.raise_for_status() | ||||
return None | return None | ||||
@@ -232,7 +235,7 @@ class HubApi: | |||||
handle_http_response(r, logger, cookies, model_id) | handle_http_response(r, logger, cookies, model_id) | ||||
d = r.json() | d = r.json() | ||||
raise_on_error(d) | raise_on_error(d) | ||||
info = d['Data'] | |||||
info = d[API_RESPONSE_FIELD_DATA] | |||||
branches = [x['Revision'] for x in info['RevisionMap']['Branches'] | branches = [x['Revision'] for x in info['RevisionMap']['Branches'] | ||||
] if info['RevisionMap']['Branches'] else [] | ] if info['RevisionMap']['Branches'] else [] | ||||
tags = [x['Revision'] for x in info['RevisionMap']['Tags'] | tags = [x['Revision'] for x in info['RevisionMap']['Tags'] | ||||
@@ -276,7 +279,7 @@ class HubApi: | |||||
raise_on_error(d) | raise_on_error(d) | ||||
files = [] | files = [] | ||||
for file in d['Data']['Files']: | |||||
for file in d[API_RESPONSE_FIELD_DATA]['Files']: | |||||
if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes': | if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes': | ||||
continue | continue | ||||
@@ -289,7 +292,7 @@ class HubApi: | |||||
params = {} | params = {} | ||||
r = requests.get(path, params=params, headers=headers) | r = requests.get(path, params=params, headers=headers) | ||||
r.raise_for_status() | r.raise_for_status() | ||||
dataset_list = r.json()['Data'] | |||||
dataset_list = r.json()[API_RESPONSE_FIELD_DATA] | |||||
return [x['Name'] for x in dataset_list] | return [x['Name'] for x in dataset_list] | ||||
def fetch_dataset_scripts( | def fetch_dataset_scripts( | ||||
@@ -379,21 +382,27 @@ class HubApi: | |||||
class ModelScopeConfig: | class ModelScopeConfig: | ||||
path_credential = expanduser('~/.modelscope/credentials') | |||||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) | |||||
COOKIES_FILE_NAME = 'cookies' | |||||
GIT_TOKEN_FILE_NAME = 'git_token' | |||||
USER_INFO_FILE_NAME = 'user' | |||||
@classmethod | |||||
def make_sure_credential_path_exist(cls): | |||||
os.makedirs(cls.path_credential, exist_ok=True) | |||||
@staticmethod | |||||
def make_sure_credential_path_exist(): | |||||
os.makedirs(ModelScopeConfig.path_credential, exist_ok=True) | |||||
@classmethod | |||||
def save_cookies(cls, cookies: CookieJar): | |||||
cls.make_sure_credential_path_exist() | |||||
with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f: | |||||
@staticmethod | |||||
def save_cookies(cookies: CookieJar): | |||||
ModelScopeConfig.make_sure_credential_path_exist() | |||||
with open( | |||||
os.path.join(ModelScopeConfig.path_credential, | |||||
ModelScopeConfig.COOKIES_FILE_NAME), 'wb+') as f: | |||||
pickle.dump(cookies, f) | pickle.dump(cookies, f) | ||||
@classmethod | |||||
def get_cookies(cls): | |||||
cookies_path = os.path.join(cls.path_credential, 'cookies') | |||||
@staticmethod | |||||
def get_cookies(): | |||||
cookies_path = os.path.join(ModelScopeConfig.path_credential, | |||||
ModelScopeConfig.COOKIES_FILE_NAME) | |||||
if os.path.exists(cookies_path): | if os.path.exists(cookies_path): | ||||
with open(cookies_path, 'rb') as f: | with open(cookies_path, 'rb') as f: | ||||
cookies = pickle.load(f) | cookies = pickle.load(f) | ||||
@@ -405,14 +414,38 @@ class ModelScopeConfig: | |||||
return cookies | return cookies | ||||
return None | return None | ||||
@classmethod | |||||
def save_token(cls, token: str): | |||||
cls.make_sure_credential_path_exist() | |||||
with open(os.path.join(cls.path_credential, 'token'), 'w+') as f: | |||||
@staticmethod | |||||
def save_token(token: str): | |||||
ModelScopeConfig.make_sure_credential_path_exist() | |||||
with open( | |||||
os.path.join(ModelScopeConfig.path_credential, | |||||
ModelScopeConfig.GITLAB_TOKEN_FILE_NAME), | |||||
'w+') as f: | |||||
f.write(token) | f.write(token) | ||||
@classmethod | |||||
def get_token(cls) -> Optional[str]: | |||||
@staticmethod | |||||
def save_user_info(user_name: str, user_email: str): | |||||
ModelScopeConfig.make_sure_credential_path_exist() | |||||
with open( | |||||
os.path.join(ModelScopeConfig.path_credential, | |||||
ModelScopeConfig.USER_INFO_FILE_NAME), 'w+') as f: | |||||
f.write('%s:%s' % (user_name, user_email)) | |||||
@staticmethod | |||||
def get_user_info() -> Tuple[str, str]: | |||||
try: | |||||
with open( | |||||
os.path.join(ModelScopeConfig.path_credential, | |||||
ModelScopeConfig.USER_INFO_FILE_NAME), | |||||
'r') as f: | |||||
info = f.read() | |||||
return info.split(':')[0], info.split(':')[1] | |||||
except FileNotFoundError: | |||||
pass | |||||
return None, None | |||||
@staticmethod | |||||
def get_token() -> Optional[str]: | |||||
""" | """ | ||||
Get token or None if not existent. | Get token or None if not existent. | ||||
@@ -422,24 +455,11 @@ class ModelScopeConfig: | |||||
""" | """ | ||||
token = None | token = None | ||||
try: | try: | ||||
with open(os.path.join(cls.path_credential, 'token'), 'r') as f: | |||||
with open( | |||||
os.path.join(ModelScopeConfig.path_credential, | |||||
ModelScopeConfig.GITLAB_TOKEN_FILE_NAME), | |||||
'r') as f: | |||||
token = f.read() | token = f.read() | ||||
except FileNotFoundError: | except FileNotFoundError: | ||||
pass | pass | ||||
return token | return token | ||||
@staticmethod | |||||
def write_to_git_credential(username: str, password: str): | |||||
with subprocess.Popen( | |||||
'git credential-store store'.split(), | |||||
stdin=subprocess.PIPE, | |||||
stdout=subprocess.PIPE, | |||||
stderr=subprocess.STDOUT, | |||||
) as process: | |||||
input_username = f'username={username.lower()}' | |||||
input_password = f'password={password}' | |||||
process.stdin.write( | |||||
f'url={get_endpoint()}\n{input_username}\n{input_password}\n\n' | |||||
.encode('utf-8')) | |||||
process.stdin.flush() |
@@ -1,12 +1,17 @@ | |||||
MODELSCOPE_URL_SCHEME = 'http://' | MODELSCOPE_URL_SCHEME = 'http://' | ||||
DEFAULT_MODELSCOPE_IP = '123.57.147.185' | |||||
DEFAULT_MODELSCOPE_DOMAIN = DEFAULT_MODELSCOPE_IP + ':31090' | |||||
DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_IP + ':31090' | |||||
DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn' | |||||
DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DOMAIN | |||||
DEFAULT_MODELSCOPE_GROUP = 'damo' | DEFAULT_MODELSCOPE_GROUP = 'damo' | ||||
MODEL_ID_SEPARATOR = '/' | MODEL_ID_SEPARATOR = '/' | ||||
LOGGER_NAME = 'ModelScopeHub' | LOGGER_NAME = 'ModelScopeHub' | ||||
DEFAULT_CREDENTIALS_PATH = '~/.modelscope/credentials' | |||||
API_RESPONSE_FIELD_DATA = 'Data' | |||||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken' | |||||
API_RESPONSE_FIELD_USERNAME = 'Username' | |||||
API_RESPONSE_FIELD_EMAIL = 'Email' | |||||
API_RESPONSE_FIELD_MESSAGE = 'Message' | |||||
class Licenses(object): | class Licenses(object): | ||||
@@ -1,3 +1,5 @@ | |||||
from http import HTTPStatus | |||||
from requests.exceptions import HTTPError | from requests.exceptions import HTTPError | ||||
@@ -17,6 +19,10 @@ class InvalidParameter(Exception): | |||||
pass | pass | ||||
class NotLoginException(Exception): | |||||
pass | |||||
def is_ok(rsp): | def is_ok(rsp): | ||||
""" Check the request is ok | """ Check the request is ok | ||||
@@ -26,7 +32,7 @@ def is_ok(rsp): | |||||
'RequestId': '', 'Success': False} | 'RequestId': '', 'Success': False} | ||||
Success: {'Code': 200, 'Data': {}, 'Message': 'success', 'RequestId': '', 'Success': True} | Success: {'Code': 200, 'Data': {}, 'Message': 'success', 'RequestId': '', 'Success': True} | ||||
""" | """ | ||||
return rsp['Code'] == 200 and rsp['Success'] | |||||
return rsp['Code'] == HTTPStatus.OK and rsp['Success'] | |||||
def handle_http_response(response, logger, cookies, model_id): | def handle_http_response(response, logger, cookies, model_id): | ||||
@@ -46,7 +52,7 @@ def raise_on_error(rsp): | |||||
Args: | Args: | ||||
rsp (_type_): The server response | rsp (_type_): The server response | ||||
""" | """ | ||||
if rsp['Code'] == 200 and rsp['Success']: | |||||
if rsp['Code'] == HTTPStatus.OK and rsp['Success']: | |||||
return True | return True | ||||
else: | else: | ||||
raise RequestError(rsp['Message']) | raise RequestError(rsp['Message']) | ||||
@@ -59,7 +65,7 @@ def datahub_raise_on_error(url, rsp): | |||||
Args: | Args: | ||||
rsp (_type_): The server response | rsp (_type_): The server response | ||||
""" | """ | ||||
if rsp.get('Code') == 200: | |||||
if rsp.get('Code') == HTTPStatus.OK: | |||||
return True | return True | ||||
else: | else: | ||||
raise RequestError( | raise RequestError( | ||||
@@ -1,30 +1,22 @@ | |||||
import copy | import copy | ||||
import fnmatch | |||||
import logging | |||||
import os | import os | ||||
import sys | import sys | ||||
import tempfile | import tempfile | ||||
import time | |||||
from functools import partial | from functools import partial | ||||
from hashlib import sha256 | |||||
from http.cookiejar import CookieJar | from http.cookiejar import CookieJar | ||||
from pathlib import Path | from pathlib import Path | ||||
from typing import BinaryIO, Dict, Optional, Union | |||||
from typing import Dict, Optional, Union | |||||
from uuid import uuid4 | from uuid import uuid4 | ||||
import json | |||||
import requests | import requests | ||||
from filelock import FileLock | from filelock import FileLock | ||||
from requests.exceptions import HTTPError | |||||
from tqdm import tqdm | from tqdm import tqdm | ||||
from modelscope import __version__ | from modelscope import __version__ | ||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from .api import HubApi, ModelScopeConfig | from .api import HubApi, ModelScopeConfig | ||||
from .constants import (DEFAULT_MODELSCOPE_GROUP, LOGGER_NAME, | |||||
MODEL_ID_SEPARATOR) | |||||
from .errors import NotExistError, RequestError, raise_on_error | |||||
from .errors import NotExistError | |||||
from .utils.caching import ModelFileSystemCache | from .utils.caching import ModelFileSystemCache | ||||
from .utils.utils import (get_cache_dir, get_endpoint, | from .utils.utils import (get_cache_dir, get_endpoint, | ||||
model_id_to_group_owner_name) | model_id_to_group_owner_name) | ||||
@@ -4,6 +4,7 @@ from typing import List | |||||
from xmlrpc.client import Boolean | from xmlrpc.client import Boolean | ||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from .api import ModelScopeConfig | |||||
from .errors import GitError | from .errors import GitError | ||||
logger = get_logger() | logger = get_logger() | ||||
@@ -37,7 +38,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
Returns: | Returns: | ||||
subprocess.CompletedProcess: the command response | subprocess.CompletedProcess: the command response | ||||
""" | """ | ||||
logger.info(' '.join(args)) | |||||
logger.debug(' '.join(args)) | |||||
response = subprocess.run( | response = subprocess.run( | ||||
[self.git_path, *args], | [self.git_path, *args], | ||||
stdout=subprocess.PIPE, | stdout=subprocess.PIPE, | ||||
@@ -50,6 +51,15 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
'stdout: %s, stderr: %s' % | 'stdout: %s, stderr: %s' % | ||||
(response.stdout.decode('utf8'), error.stderr.decode('utf8'))) | (response.stdout.decode('utf8'), error.stderr.decode('utf8'))) | ||||
def config_auth_token(self, repo_dir, auth_token): | |||||
url = self.get_repo_remote_url(repo_dir) | |||||
if '//oauth2' not in url: | |||||
auth_url = self._add_token(auth_token, url) | |||||
cmd_args = '-C %s remote set-url origin %s' % (repo_dir, auth_url) | |||||
cmd_args = cmd_args.split(' ') | |||||
rsp = self._run_git_command(*cmd_args) | |||||
logger.debug(rsp.stdout.decode('utf8')) | |||||
def _add_token(self, token: str, url: str): | def _add_token(self, token: str, url: str): | ||||
if token: | if token: | ||||
if '//oauth2' not in url: | if '//oauth2' not in url: | ||||
@@ -104,9 +114,23 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
logger.debug(clone_args) | logger.debug(clone_args) | ||||
clone_args = clone_args.split(' ') | clone_args = clone_args.split(' ') | ||||
response = self._run_git_command(*clone_args) | response = self._run_git_command(*clone_args) | ||||
logger.info(response.stdout.decode('utf8')) | |||||
logger.debug(response.stdout.decode('utf8')) | |||||
return response | return response | ||||
def add_user_info(self, repo_base_dir, repo_name): | |||||
user_name, user_email = ModelScopeConfig.get_user_info() | |||||
if user_name and user_email: | |||||
# config user.name and user.email if exist | |||||
config_user_name_args = '-C %s/%s config user.name %s' % ( | |||||
repo_base_dir, repo_name, user_name) | |||||
response = self._run_git_command(*config_user_name_args.split(' ')) | |||||
logger.debug(response.stdout.decode('utf8')) | |||||
config_user_email_args = '-C %s/%s config user.name %s' % ( | |||||
repo_base_dir, repo_name, user_name) | |||||
response = self._run_git_command( | |||||
*config_user_email_args.split(' ')) | |||||
logger.debug(response.stdout.decode('utf8')) | |||||
def add(self, | def add(self, | ||||
repo_dir: str, | repo_dir: str, | ||||
files: List[str] = list(), | files: List[str] = list(), | ||||
@@ -118,7 +142,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
add_args = '-C %s add %s' % (repo_dir, files_str) | add_args = '-C %s add %s' % (repo_dir, files_str) | ||||
add_args = add_args.split(' ') | add_args = add_args.split(' ') | ||||
rsp = self._run_git_command(*add_args) | rsp = self._run_git_command(*add_args) | ||||
logger.info(rsp.stdout.decode('utf8')) | |||||
logger.debug(rsp.stdout.decode('utf8')) | |||||
return rsp | return rsp | ||||
def commit(self, repo_dir: str, message: str): | def commit(self, repo_dir: str, message: str): | ||||
@@ -159,7 +183,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
push_args += ' -f' | push_args += ' -f' | ||||
push_args = push_args.split(' ') | push_args = push_args.split(' ') | ||||
rsp = self._run_git_command(*push_args) | rsp = self._run_git_command(*push_args) | ||||
logger.info(rsp.stdout.decode('utf8')) | |||||
logger.debug(rsp.stdout.decode('utf8')) | |||||
return rsp | return rsp | ||||
def get_repo_remote_url(self, repo_dir: str): | def get_repo_remote_url(self, repo_dir: str): | ||||
@@ -1,7 +1,7 @@ | |||||
import os | import os | ||||
from typing import Optional | from typing import Optional | ||||
from modelscope.hub.errors import GitError, InvalidParameter | |||||
from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | |||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from .api import ModelScopeConfig | from .api import ModelScopeConfig | ||||
@@ -64,6 +64,12 @@ class Repository: | |||||
if git_wrapper.is_lfs_installed(): | if git_wrapper.is_lfs_installed(): | ||||
git_wrapper.git_lfs_install(self.model_dir) # init repo lfs | git_wrapper.git_lfs_install(self.model_dir) # init repo lfs | ||||
# add user info if login | |||||
self.git_wrapper.add_user_info(self.model_base_dir, | |||||
self.model_repo_name) | |||||
if self.auth_token: # config remote with auth token | |||||
self.git_wrapper.config_auth_token(self.model_dir, self.auth_token) | |||||
def _get_model_id_url(self, model_id): | def _get_model_id_url(self, model_id): | ||||
url = f'{get_endpoint()}/{model_id}.git' | url = f'{get_endpoint()}/{model_id}.git' | ||||
return url | return url | ||||
@@ -93,6 +99,14 @@ class Repository: | |||||
raise InvalidParameter(msg) | raise InvalidParameter(msg) | ||||
if not isinstance(force, bool): | if not isinstance(force, bool): | ||||
raise InvalidParameter('force must be bool') | raise InvalidParameter('force must be bool') | ||||
if not self.auth_token: | |||||
raise NotLoginException('Must login to push, please login first.') | |||||
self.git_wrapper.config_auth_token(self.model_dir, self.auth_token) | |||||
self.git_wrapper.add_user_info(self.model_base_dir, | |||||
self.model_repo_name) | |||||
url = self.git_wrapper.get_repo_remote_url(self.model_dir) | url = self.git_wrapper.get_repo_remote_url(self.model_dir) | ||||
self.git_wrapper.pull(self.model_dir) | self.git_wrapper.pull(self.model_dir) | ||||
self.git_wrapper.add(self.model_dir, all_files=True) | self.git_wrapper.add(self.model_dir, all_files=True) | ||||
@@ -1,5 +1,4 @@ | |||||
import os | import os | ||||
import tempfile | |||||
from pathlib import Path | from pathlib import Path | ||||
from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
@@ -4,15 +4,14 @@ from modelscope.hub.api import HubApi | |||||
from modelscope.utils.hub import create_model_if_not_exist | from modelscope.utils.hub import create_model_if_not_exist | ||||
# note this is temporary before official account management is ready | # note this is temporary before official account management is ready | ||||
USER_NAME = 'maasadmin' | |||||
PASSWORD = '12345678' | |||||
YOUR_ACCESS_TOKEN = 'token' | |||||
class HubExampleTest(unittest.TestCase): | class HubExampleTest(unittest.TestCase): | ||||
def setUp(self): | def setUp(self): | ||||
self.api = HubApi() | self.api = HubApi() | ||||
self.api.login(USER_NAME, PASSWORD) | |||||
self.api.login(YOUR_ACCESS_TOKEN) | |||||
@unittest.skip('to be used for local test only') | @unittest.skip('to be used for local test only') | ||||
def test_example_model_creation(self): | def test_example_model_creation(self): | ||||
@@ -12,20 +12,24 @@ from modelscope.hub.constants import Licenses, ModelVisibility | |||||
from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
from modelscope.hub.repository import Repository | from modelscope.hub.repository import Repository | ||||
from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
from .test_utils import (TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | |||||
TEST_PASSWORD, TEST_USER_NAME1) | |||||
from modelscope.utils.constant import ModelFile | |||||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, | |||||
TEST_MODEL_ORG) | |||||
DEFAULT_GIT_PATH = 'git' | DEFAULT_GIT_PATH = 'git' | ||||
download_model_file_name = 'test.bin' | download_model_file_name = 'test.bin' | ||||
@unittest.skip( | |||||
"Access token is always change, we can't login with same access token, so skip!" | |||||
) | |||||
class HubOperationTest(unittest.TestCase): | class HubOperationTest(unittest.TestCase): | ||||
def setUp(self): | def setUp(self): | ||||
self.api = HubApi() | self.api = HubApi() | ||||
# note this is temporary before official account management is ready | # note this is temporary before official account management is ready | ||||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||||
self.api.login(TEST_ACCESS_TOKEN1) | |||||
self.model_name = uuid.uuid4().hex | self.model_name = uuid.uuid4().hex | ||||
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | ||||
self.api.create_model( | self.api.create_model( | ||||
@@ -92,7 +96,7 @@ class HubOperationTest(unittest.TestCase): | |||||
file_path=download_model_file_name, | file_path=download_model_file_name, | ||||
cache_dir=temporary_dir) | cache_dir=temporary_dir) | ||||
assert os.path.exists(downloaded_file) | assert os.path.exists(downloaded_file) | ||||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||||
self.api.login(TEST_ACCESS_TOKEN1) | |||||
def test_snapshot_delete_download_cache_file(self): | def test_snapshot_delete_download_cache_file(self): | ||||
snapshot_path = snapshot_download(model_id=self.model_id) | snapshot_path = snapshot_download(model_id=self.model_id) | ||||
@@ -102,7 +106,7 @@ class HubOperationTest(unittest.TestCase): | |||||
os.remove(downloaded_file_path) | os.remove(downloaded_file_path) | ||||
# download again in cache | # download again in cache | ||||
file_download_path = model_file_download( | file_download_path = model_file_download( | ||||
model_id=self.model_id, file_path='README.md') | |||||
model_id=self.model_id, file_path=ModelFile.README) | |||||
assert os.path.exists(file_download_path) | assert os.path.exists(file_download_path) | ||||
# deleted file need download again | # deleted file need download again | ||||
file_download_path = model_file_download( | file_download_path = model_file_download( | ||||
@@ -13,18 +13,21 @@ from modelscope.hub.file_download import model_file_download | |||||
from modelscope.hub.repository import Repository | from modelscope.hub.repository import Repository | ||||
from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
from modelscope.utils.constant import ModelFile | from modelscope.utils.constant import ModelFile | ||||
from .test_utils import (TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | |||||
TEST_PASSWORD, TEST_USER_NAME1, TEST_USER_NAME2, | |||||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, | |||||
TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | |||||
delete_credential) | delete_credential) | ||||
@unittest.skip( | |||||
"Access token is always change, we can't login with same access token, so skip!" | |||||
) | |||||
class HubPrivateFileDownloadTest(unittest.TestCase): | class HubPrivateFileDownloadTest(unittest.TestCase): | ||||
def setUp(self): | def setUp(self): | ||||
self.old_cwd = os.getcwd() | self.old_cwd = os.getcwd() | ||||
self.api = HubApi() | self.api = HubApi() | ||||
# note this is temporary before official account management is ready | # note this is temporary before official account management is ready | ||||
self.token, _ = self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||||
self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) | |||||
self.model_name = uuid.uuid4().hex | self.model_name = uuid.uuid4().hex | ||||
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | ||||
self.api.create_model( | self.api.create_model( | ||||
@@ -37,7 +40,7 @@ class HubPrivateFileDownloadTest(unittest.TestCase): | |||||
def tearDown(self): | def tearDown(self): | ||||
# credential may deleted or switch login name, we need re-login here | # credential may deleted or switch login name, we need re-login here | ||||
# to ensure the temporary model is deleted. | # to ensure the temporary model is deleted. | ||||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||||
self.api.login(TEST_ACCESS_TOKEN1) | |||||
os.chdir(self.old_cwd) | os.chdir(self.old_cwd) | ||||
self.api.delete_model(model_id=self.model_id) | self.api.delete_model(model_id=self.model_id) | ||||
@@ -46,7 +49,7 @@ class HubPrivateFileDownloadTest(unittest.TestCase): | |||||
assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | ||||
def test_snapshot_download_private_model_no_permission(self): | def test_snapshot_download_private_model_no_permission(self): | ||||
self.token, _ = self.api.login(TEST_USER_NAME2, TEST_PASSWORD) | |||||
self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) | |||||
with self.assertRaises(HTTPError): | with self.assertRaises(HTTPError): | ||||
snapshot_download(self.model_id) | snapshot_download(self.model_id) | ||||
@@ -60,7 +63,7 @@ class HubPrivateFileDownloadTest(unittest.TestCase): | |||||
assert os.path.exists(file_path) | assert os.path.exists(file_path) | ||||
def test_download_file_private_model_no_permission(self): | def test_download_file_private_model_no_permission(self): | ||||
self.token, _ = self.api.login(TEST_USER_NAME2, TEST_PASSWORD) | |||||
self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) | |||||
with self.assertRaises(HTTPError): | with self.assertRaises(HTTPError): | ||||
model_file_download(self.model_id, ModelFile.README) | model_file_download(self.model_id, ModelFile.README) | ||||
@@ -8,19 +8,23 @@ from modelscope.hub.api import HubApi | |||||
from modelscope.hub.constants import Licenses, ModelVisibility | from modelscope.hub.constants import Licenses, ModelVisibility | ||||
from modelscope.hub.errors import GitError | from modelscope.hub.errors import GitError | ||||
from modelscope.hub.repository import Repository | from modelscope.hub.repository import Repository | ||||
from .test_utils import (TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | |||||
TEST_PASSWORD, TEST_USER_NAME1, TEST_USER_NAME2) | |||||
from modelscope.utils.constant import ModelFile | |||||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, | |||||
TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG) | |||||
DEFAULT_GIT_PATH = 'git' | DEFAULT_GIT_PATH = 'git' | ||||
@unittest.skip( | |||||
"Access token is always change, we can't login with same access token, so skip!" | |||||
) | |||||
class HubPrivateRepositoryTest(unittest.TestCase): | class HubPrivateRepositoryTest(unittest.TestCase): | ||||
def setUp(self): | def setUp(self): | ||||
self.old_cwd = os.getcwd() | self.old_cwd = os.getcwd() | ||||
self.api = HubApi() | self.api = HubApi() | ||||
# note this is temporary before official account management is ready | # note this is temporary before official account management is ready | ||||
self.token, _ = self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||||
self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) | |||||
self.model_name = uuid.uuid4().hex | self.model_name = uuid.uuid4().hex | ||||
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | ||||
self.api.create_model( | self.api.create_model( | ||||
@@ -31,27 +35,25 @@ class HubPrivateRepositoryTest(unittest.TestCase): | |||||
) | ) | ||||
def tearDown(self): | def tearDown(self): | ||||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||||
self.api.login(TEST_ACCESS_TOKEN1) | |||||
os.chdir(self.old_cwd) | os.chdir(self.old_cwd) | ||||
self.api.delete_model(model_id=self.model_id) | self.api.delete_model(model_id=self.model_id) | ||||
def test_clone_private_repo_no_permission(self): | def test_clone_private_repo_no_permission(self): | ||||
token, _ = self.api.login(TEST_USER_NAME2, TEST_PASSWORD) | |||||
token, _ = self.api.login(TEST_ACCESS_TOKEN2) | |||||
temporary_dir = tempfile.mkdtemp() | temporary_dir = tempfile.mkdtemp() | ||||
local_dir = os.path.join(temporary_dir, self.model_name) | local_dir = os.path.join(temporary_dir, self.model_name) | ||||
with self.assertRaises(GitError) as cm: | with self.assertRaises(GitError) as cm: | ||||
Repository(local_dir, clone_from=self.model_id, auth_token=token) | Repository(local_dir, clone_from=self.model_id, auth_token=token) | ||||
print(cm.exception) | print(cm.exception) | ||||
assert not os.path.exists(os.path.join(local_dir, 'README.md')) | |||||
assert not os.path.exists(os.path.join(local_dir, ModelFile.README)) | |||||
def test_clone_private_repo_has_permission(self): | def test_clone_private_repo_has_permission(self): | ||||
temporary_dir = tempfile.mkdtemp() | temporary_dir = tempfile.mkdtemp() | ||||
local_dir = os.path.join(temporary_dir, self.model_name) | local_dir = os.path.join(temporary_dir, self.model_name) | ||||
repo1 = Repository( | |||||
local_dir, clone_from=self.model_id, auth_token=self.token) | |||||
print(repo1.model_dir) | |||||
assert os.path.exists(os.path.join(local_dir, 'README.md')) | |||||
Repository(local_dir, clone_from=self.model_id, auth_token=self.token) | |||||
assert os.path.exists(os.path.join(local_dir, ModelFile.README)) | |||||
def test_initlize_repo_multiple_times(self): | def test_initlize_repo_multiple_times(self): | ||||
temporary_dir = tempfile.mkdtemp() | temporary_dir = tempfile.mkdtemp() | ||||
@@ -59,7 +61,7 @@ class HubPrivateRepositoryTest(unittest.TestCase): | |||||
repo1 = Repository( | repo1 = Repository( | ||||
local_dir, clone_from=self.model_id, auth_token=self.token) | local_dir, clone_from=self.model_id, auth_token=self.token) | ||||
print(repo1.model_dir) | print(repo1.model_dir) | ||||
assert os.path.exists(os.path.join(local_dir, 'README.md')) | |||||
assert os.path.exists(os.path.join(local_dir, ModelFile.README)) | |||||
repo2 = Repository( | repo2 = Repository( | ||||
local_dir, clone_from=self.model_id, | local_dir, clone_from=self.model_id, | ||||
auth_token=self.token) # skip clone | auth_token=self.token) # skip clone | ||||
@@ -14,23 +14,26 @@ from modelscope.hub.errors import NotExistError | |||||
from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
from modelscope.hub.git import GitCommandWrapper | from modelscope.hub.git import GitCommandWrapper | ||||
from modelscope.hub.repository import Repository | from modelscope.hub.repository import Repository | ||||
from modelscope.utils.constant import ModelFile | |||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from .test_utils import (TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | |||||
TEST_PASSWORD, TEST_USER_NAME1, TEST_USER_NAME2, | |||||
delete_credential, delete_stored_git_credential) | |||||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, | |||||
TEST_MODEL_ORG, delete_credential) | |||||
logger = get_logger() | logger = get_logger() | ||||
logger.setLevel('DEBUG') | logger.setLevel('DEBUG') | ||||
DEFAULT_GIT_PATH = 'git' | DEFAULT_GIT_PATH = 'git' | ||||
@unittest.skip( | |||||
"Access token is always change, we can't login with same access token, so skip!" | |||||
) | |||||
class HubRepositoryTest(unittest.TestCase): | class HubRepositoryTest(unittest.TestCase): | ||||
def setUp(self): | def setUp(self): | ||||
self.old_cwd = os.getcwd() | self.old_cwd = os.getcwd() | ||||
self.api = HubApi() | self.api = HubApi() | ||||
# note this is temporary before official account management is ready | # note this is temporary before official account management is ready | ||||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||||
self.api.login(TEST_ACCESS_TOKEN1) | |||||
self.model_name = uuid.uuid4().hex | self.model_name = uuid.uuid4().hex | ||||
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | ||||
self.api.create_model( | self.api.create_model( | ||||
@@ -48,18 +51,17 @@ class HubRepositoryTest(unittest.TestCase): | |||||
def test_clone_repo(self): | def test_clone_repo(self): | ||||
Repository(self.model_dir, clone_from=self.model_id) | Repository(self.model_dir, clone_from=self.model_id) | ||||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||||
assert os.path.exists(os.path.join(self.model_dir, ModelFile.README)) | |||||
def test_clone_public_model_without_token(self): | def test_clone_public_model_without_token(self): | ||||
delete_credential() | delete_credential() | ||||
delete_stored_git_credential(TEST_USER_NAME1) | |||||
Repository(self.model_dir, clone_from=self.model_id) | Repository(self.model_dir, clone_from=self.model_id) | ||||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) # re-login for delete | |||||
assert os.path.exists(os.path.join(self.model_dir, ModelFile.README)) | |||||
self.api.login(TEST_ACCESS_TOKEN1) # re-login for delete | |||||
def test_push_all(self): | def test_push_all(self): | ||||
repo = Repository(self.model_dir, clone_from=self.model_id) | repo = Repository(self.model_dir, clone_from=self.model_id) | ||||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||||
assert os.path.exists(os.path.join(self.model_dir, ModelFile.README)) | |||||
os.chdir(self.model_dir) | os.chdir(self.model_dir) | ||||
lfs_file1 = 'test1.bin' | lfs_file1 = 'test1.bin' | ||||
lfs_file2 = 'test2.bin' | lfs_file2 = 'test2.bin' | ||||
@@ -3,25 +3,16 @@ import shutil | |||||
from codecs import ignore_errors | from codecs import ignore_errors | ||||
from os.path import expanduser | from os.path import expanduser | ||||
TEST_USER_NAME1 = 'citest' | |||||
TEST_USER_NAME2 = 'sdkdev' | |||||
TEST_PASSWORD = '12345678' | |||||
from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH | |||||
# for user citest and sdkdev | |||||
TEST_ACCESS_TOKEN1 = 'OVAzNU9aZ2FYbXFhdGNzZll6VHRtalQ0T1BpZTNGeWVhMkxSSGpTSzU0dkM5WE5ObDFKdFRQWGc2U2ZIdjdPdg==' | |||||
TEST_ACCESS_TOKEN2 = 'aXRocHhGeG0rNXRWQWhBSnJpTTZUQ0RDbUlkcUJRS1dQR2lNb0xIa0JjRDBrT1JKYklZV05DVzROTTdtamxWcg==' | |||||
TEST_MODEL_CHINESE_NAME = '内部测试模型' | TEST_MODEL_CHINESE_NAME = '内部测试模型' | ||||
TEST_MODEL_ORG = 'citest' | TEST_MODEL_ORG = 'citest' | ||||
def delete_credential(): | def delete_credential(): | ||||
path_credential = expanduser('~/.modelscope/credentials') | |||||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) | |||||
shutil.rmtree(path_credential, ignore_errors=True) | shutil.rmtree(path_credential, ignore_errors=True) | ||||
def delete_stored_git_credential(user): | |||||
credential_path = expanduser('~/.git-credentials') | |||||
if os.path.exists(credential_path): | |||||
with open(credential_path, 'r+') as f: | |||||
lines = f.readlines() | |||||
lines = [line for line in lines if user not in line] | |||||
f.seek(0) | |||||
f.write(''.join(lines)) | |||||
f.truncate() |