@@ -24,6 +24,8 @@ do | |||
-v /home/admin/pre-commit:/home/admin/pre-commit \ | |||
-e CI_TEST=True \ | |||
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||
--workdir=$CODE_DIR_IN_CONTAINER \ | |||
--net host \ | |||
${IMAGE_NAME}:${IMAGE_VERSION} \ | |||
@@ -3,12 +3,19 @@ import pickle | |||
import shutil | |||
import subprocess | |||
from collections import defaultdict | |||
from http import HTTPStatus | |||
from http.cookiejar import CookieJar | |||
from os.path import expanduser | |||
from typing import List, Optional, Tuple, Union | |||
import requests | |||
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
API_RESPONSE_FIELD_EMAIL, | |||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | |||
API_RESPONSE_FIELD_MESSAGE, | |||
API_RESPONSE_FIELD_USERNAME, | |||
DEFAULT_CREDENTIALS_PATH) | |||
from modelscope.msdatasets.config import (DOWNLOADED_DATASETS_PATH, | |||
HUB_DATASET_ENDPOINT) | |||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
@@ -32,16 +39,13 @@ class HubApi: | |||
def login( | |||
self, | |||
user_name: str, | |||
password: str, | |||
access_token: str, | |||
) -> tuple(): | |||
""" | |||
Login with username and password | |||
Args: | |||
user_name(`str`): user name on modelscope | |||
password(`str`): password | |||
access_token(`str`): user access token on modelscope. | |||
Returns: | |||
cookies: to authenticate yourself to ModelScope open-api | |||
gitlab token: to access private repos | |||
@@ -51,24 +55,23 @@ class HubApi: | |||
</Tip> | |||
""" | |||
path = f'{self.endpoint}/api/v1/login' | |||
r = requests.post( | |||
path, json={ | |||
'username': user_name, | |||
'password': password | |||
}) | |||
r = requests.post(path, json={'AccessToken': access_token}) | |||
r.raise_for_status() | |||
d = r.json() | |||
raise_on_error(d) | |||
token = d['Data']['AccessToken'] | |||
token = d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_GIT_ACCESS_TOKEN] | |||
cookies = r.cookies | |||
# save token and cookie | |||
ModelScopeConfig.save_token(token) | |||
ModelScopeConfig.save_cookies(cookies) | |||
ModelScopeConfig.write_to_git_credential(user_name, password) | |||
ModelScopeConfig.save_user_info( | |||
d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_USERNAME], | |||
d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_EMAIL]) | |||
return d['Data']['AccessToken'], cookies | |||
return d[API_RESPONSE_FIELD_DATA][ | |||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN], cookies | |||
def create_model( | |||
self, | |||
@@ -161,11 +164,11 @@ class HubApi: | |||
r = requests.get(path, cookies=cookies) | |||
handle_http_response(r, logger, cookies, model_id) | |||
if r.status_code == 200: | |||
if r.status_code == HTTPStatus.OK: | |||
if is_ok(r.json()): | |||
return r.json()['Data'] | |||
return r.json()[API_RESPONSE_FIELD_DATA] | |||
else: | |||
raise NotExistError(r.json()['Message']) | |||
raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
else: | |||
r.raise_for_status() | |||
@@ -189,12 +192,12 @@ class HubApi: | |||
data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % | |||
(owner_or_group, page_number, page_size)) | |||
handle_http_response(r, logger, cookies, 'list_model') | |||
if r.status_code == 200: | |||
if r.status_code == HTTPStatus.OK: | |||
if is_ok(r.json()): | |||
data = r.json()['Data'] | |||
data = r.json()[API_RESPONSE_FIELD_DATA] | |||
return data | |||
else: | |||
raise RequestError(r.json()['Message']) | |||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
else: | |||
r.raise_for_status() | |||
return None | |||
@@ -232,7 +235,7 @@ class HubApi: | |||
handle_http_response(r, logger, cookies, model_id) | |||
d = r.json() | |||
raise_on_error(d) | |||
info = d['Data'] | |||
info = d[API_RESPONSE_FIELD_DATA] | |||
branches = [x['Revision'] for x in info['RevisionMap']['Branches'] | |||
] if info['RevisionMap']['Branches'] else [] | |||
tags = [x['Revision'] for x in info['RevisionMap']['Tags'] | |||
@@ -276,7 +279,7 @@ class HubApi: | |||
raise_on_error(d) | |||
files = [] | |||
for file in d['Data']['Files']: | |||
for file in d[API_RESPONSE_FIELD_DATA]['Files']: | |||
if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes': | |||
continue | |||
@@ -289,7 +292,7 @@ class HubApi: | |||
params = {} | |||
r = requests.get(path, params=params, headers=headers) | |||
r.raise_for_status() | |||
dataset_list = r.json()['Data'] | |||
dataset_list = r.json()[API_RESPONSE_FIELD_DATA] | |||
return [x['Name'] for x in dataset_list] | |||
def fetch_dataset_scripts( | |||
@@ -379,21 +382,27 @@ class HubApi: | |||
class ModelScopeConfig: | |||
path_credential = expanduser('~/.modelscope/credentials') | |||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) | |||
COOKIES_FILE_NAME = 'cookies' | |||
GIT_TOKEN_FILE_NAME = 'git_token' | |||
USER_INFO_FILE_NAME = 'user' | |||
@classmethod | |||
def make_sure_credential_path_exist(cls): | |||
os.makedirs(cls.path_credential, exist_ok=True) | |||
@staticmethod | |||
def make_sure_credential_path_exist(): | |||
os.makedirs(ModelScopeConfig.path_credential, exist_ok=True) | |||
@classmethod | |||
def save_cookies(cls, cookies: CookieJar): | |||
cls.make_sure_credential_path_exist() | |||
with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f: | |||
@staticmethod | |||
def save_cookies(cookies: CookieJar): | |||
ModelScopeConfig.make_sure_credential_path_exist() | |||
with open( | |||
os.path.join(ModelScopeConfig.path_credential, | |||
ModelScopeConfig.COOKIES_FILE_NAME), 'wb+') as f: | |||
pickle.dump(cookies, f) | |||
@classmethod | |||
def get_cookies(cls): | |||
cookies_path = os.path.join(cls.path_credential, 'cookies') | |||
@staticmethod | |||
def get_cookies(): | |||
cookies_path = os.path.join(ModelScopeConfig.path_credential, | |||
ModelScopeConfig.COOKIES_FILE_NAME) | |||
if os.path.exists(cookies_path): | |||
with open(cookies_path, 'rb') as f: | |||
cookies = pickle.load(f) | |||
@@ -405,14 +414,38 @@ class ModelScopeConfig: | |||
return cookies | |||
return None | |||
@classmethod | |||
def save_token(cls, token: str): | |||
cls.make_sure_credential_path_exist() | |||
with open(os.path.join(cls.path_credential, 'token'), 'w+') as f: | |||
@staticmethod | |||
def save_token(token: str): | |||
ModelScopeConfig.make_sure_credential_path_exist() | |||
with open( | |||
os.path.join(ModelScopeConfig.path_credential, | |||
ModelScopeConfig.GITLAB_TOKEN_FILE_NAME), | |||
'w+') as f: | |||
f.write(token) | |||
@classmethod | |||
def get_token(cls) -> Optional[str]: | |||
@staticmethod | |||
def save_user_info(user_name: str, user_email: str): | |||
ModelScopeConfig.make_sure_credential_path_exist() | |||
with open( | |||
os.path.join(ModelScopeConfig.path_credential, | |||
ModelScopeConfig.USER_INFO_FILE_NAME), 'w+') as f: | |||
f.write('%s:%s' % (user_name, user_email)) | |||
@staticmethod | |||
def get_user_info() -> Tuple[str, str]: | |||
try: | |||
with open( | |||
os.path.join(ModelScopeConfig.path_credential, | |||
ModelScopeConfig.USER_INFO_FILE_NAME), | |||
'r') as f: | |||
info = f.read() | |||
return info.split(':')[0], info.split(':')[1] | |||
except FileNotFoundError: | |||
pass | |||
return None, None | |||
@staticmethod | |||
def get_token() -> Optional[str]: | |||
""" | |||
Get token or None if not existent. | |||
@@ -422,24 +455,11 @@ class ModelScopeConfig: | |||
""" | |||
token = None | |||
try: | |||
with open(os.path.join(cls.path_credential, 'token'), 'r') as f: | |||
with open( | |||
os.path.join(ModelScopeConfig.path_credential, | |||
ModelScopeConfig.GITLAB_TOKEN_FILE_NAME), | |||
'r') as f: | |||
token = f.read() | |||
except FileNotFoundError: | |||
pass | |||
return token | |||
@staticmethod | |||
def write_to_git_credential(username: str, password: str): | |||
with subprocess.Popen( | |||
'git credential-store store'.split(), | |||
stdin=subprocess.PIPE, | |||
stdout=subprocess.PIPE, | |||
stderr=subprocess.STDOUT, | |||
) as process: | |||
input_username = f'username={username.lower()}' | |||
input_password = f'password={password}' | |||
process.stdin.write( | |||
f'url={get_endpoint()}\n{input_username}\n{input_password}\n\n' | |||
.encode('utf-8')) | |||
process.stdin.flush() |
@@ -1,12 +1,17 @@ | |||
MODELSCOPE_URL_SCHEME = 'http://' | |||
DEFAULT_MODELSCOPE_IP = '123.57.147.185' | |||
DEFAULT_MODELSCOPE_DOMAIN = DEFAULT_MODELSCOPE_IP + ':31090' | |||
DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_IP + ':31090' | |||
DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn' | |||
DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DOMAIN | |||
DEFAULT_MODELSCOPE_GROUP = 'damo' | |||
MODEL_ID_SEPARATOR = '/' | |||
LOGGER_NAME = 'ModelScopeHub' | |||
DEFAULT_CREDENTIALS_PATH = '~/.modelscope/credentials' | |||
API_RESPONSE_FIELD_DATA = 'Data' | |||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken' | |||
API_RESPONSE_FIELD_USERNAME = 'Username' | |||
API_RESPONSE_FIELD_EMAIL = 'Email' | |||
API_RESPONSE_FIELD_MESSAGE = 'Message' | |||
class Licenses(object): | |||
@@ -1,3 +1,5 @@ | |||
from http import HTTPStatus | |||
from requests.exceptions import HTTPError | |||
@@ -17,6 +19,10 @@ class InvalidParameter(Exception): | |||
pass | |||
class NotLoginException(Exception): | |||
pass | |||
def is_ok(rsp): | |||
""" Check the request is ok | |||
@@ -26,7 +32,7 @@ def is_ok(rsp): | |||
'RequestId': '', 'Success': False} | |||
Success: {'Code': 200, 'Data': {}, 'Message': 'success', 'RequestId': '', 'Success': True} | |||
""" | |||
return rsp['Code'] == 200 and rsp['Success'] | |||
return rsp['Code'] == HTTPStatus.OK and rsp['Success'] | |||
def handle_http_response(response, logger, cookies, model_id): | |||
@@ -46,7 +52,7 @@ def raise_on_error(rsp): | |||
Args: | |||
rsp (_type_): The server response | |||
""" | |||
if rsp['Code'] == 200 and rsp['Success']: | |||
if rsp['Code'] == HTTPStatus.OK and rsp['Success']: | |||
return True | |||
else: | |||
raise RequestError(rsp['Message']) | |||
@@ -59,7 +65,7 @@ def datahub_raise_on_error(url, rsp): | |||
Args: | |||
rsp (_type_): The server response | |||
""" | |||
if rsp.get('Code') == 200: | |||
if rsp.get('Code') == HTTPStatus.OK: | |||
return True | |||
else: | |||
raise RequestError( | |||
@@ -1,30 +1,22 @@ | |||
import copy | |||
import fnmatch | |||
import logging | |||
import os | |||
import sys | |||
import tempfile | |||
import time | |||
from functools import partial | |||
from hashlib import sha256 | |||
from http.cookiejar import CookieJar | |||
from pathlib import Path | |||
from typing import BinaryIO, Dict, Optional, Union | |||
from typing import Dict, Optional, Union | |||
from uuid import uuid4 | |||
import json | |||
import requests | |||
from filelock import FileLock | |||
from requests.exceptions import HTTPError | |||
from tqdm import tqdm | |||
from modelscope import __version__ | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
from modelscope.utils.logger import get_logger | |||
from .api import HubApi, ModelScopeConfig | |||
from .constants import (DEFAULT_MODELSCOPE_GROUP, LOGGER_NAME, | |||
MODEL_ID_SEPARATOR) | |||
from .errors import NotExistError, RequestError, raise_on_error | |||
from .errors import NotExistError | |||
from .utils.caching import ModelFileSystemCache | |||
from .utils.utils import (get_cache_dir, get_endpoint, | |||
model_id_to_group_owner_name) | |||
@@ -4,6 +4,7 @@ from typing import List | |||
from xmlrpc.client import Boolean | |||
from modelscope.utils.logger import get_logger | |||
from .api import ModelScopeConfig | |||
from .errors import GitError | |||
logger = get_logger() | |||
@@ -37,7 +38,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||
Returns: | |||
subprocess.CompletedProcess: the command response | |||
""" | |||
logger.info(' '.join(args)) | |||
logger.debug(' '.join(args)) | |||
response = subprocess.run( | |||
[self.git_path, *args], | |||
stdout=subprocess.PIPE, | |||
@@ -50,6 +51,15 @@ class GitCommandWrapper(metaclass=Singleton): | |||
'stdout: %s, stderr: %s' % | |||
(response.stdout.decode('utf8'), error.stderr.decode('utf8'))) | |||
def config_auth_token(self, repo_dir, auth_token): | |||
url = self.get_repo_remote_url(repo_dir) | |||
if '//oauth2' not in url: | |||
auth_url = self._add_token(auth_token, url) | |||
cmd_args = '-C %s remote set-url origin %s' % (repo_dir, auth_url) | |||
cmd_args = cmd_args.split(' ') | |||
rsp = self._run_git_command(*cmd_args) | |||
logger.debug(rsp.stdout.decode('utf8')) | |||
def _add_token(self, token: str, url: str): | |||
if token: | |||
if '//oauth2' not in url: | |||
@@ -104,9 +114,23 @@ class GitCommandWrapper(metaclass=Singleton): | |||
logger.debug(clone_args) | |||
clone_args = clone_args.split(' ') | |||
response = self._run_git_command(*clone_args) | |||
logger.info(response.stdout.decode('utf8')) | |||
logger.debug(response.stdout.decode('utf8')) | |||
return response | |||
def add_user_info(self, repo_base_dir, repo_name): | |||
user_name, user_email = ModelScopeConfig.get_user_info() | |||
if user_name and user_email: | |||
# config user.name and user.email if exist | |||
config_user_name_args = '-C %s/%s config user.name %s' % ( | |||
repo_base_dir, repo_name, user_name) | |||
response = self._run_git_command(*config_user_name_args.split(' ')) | |||
logger.debug(response.stdout.decode('utf8')) | |||
config_user_email_args = '-C %s/%s config user.name %s' % ( | |||
repo_base_dir, repo_name, user_name) | |||
response = self._run_git_command( | |||
*config_user_email_args.split(' ')) | |||
logger.debug(response.stdout.decode('utf8')) | |||
def add(self, | |||
repo_dir: str, | |||
files: List[str] = list(), | |||
@@ -118,7 +142,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||
add_args = '-C %s add %s' % (repo_dir, files_str) | |||
add_args = add_args.split(' ') | |||
rsp = self._run_git_command(*add_args) | |||
logger.info(rsp.stdout.decode('utf8')) | |||
logger.debug(rsp.stdout.decode('utf8')) | |||
return rsp | |||
def commit(self, repo_dir: str, message: str): | |||
@@ -159,7 +183,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||
push_args += ' -f' | |||
push_args = push_args.split(' ') | |||
rsp = self._run_git_command(*push_args) | |||
logger.info(rsp.stdout.decode('utf8')) | |||
logger.debug(rsp.stdout.decode('utf8')) | |||
return rsp | |||
def get_repo_remote_url(self, repo_dir: str): | |||
@@ -1,7 +1,7 @@ | |||
import os | |||
from typing import Optional | |||
from modelscope.hub.errors import GitError, InvalidParameter | |||
from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
from modelscope.utils.logger import get_logger | |||
from .api import ModelScopeConfig | |||
@@ -64,6 +64,12 @@ class Repository: | |||
if git_wrapper.is_lfs_installed(): | |||
git_wrapper.git_lfs_install(self.model_dir) # init repo lfs | |||
# add user info if login | |||
self.git_wrapper.add_user_info(self.model_base_dir, | |||
self.model_repo_name) | |||
if self.auth_token: # config remote with auth token | |||
self.git_wrapper.config_auth_token(self.model_dir, self.auth_token) | |||
def _get_model_id_url(self, model_id): | |||
url = f'{get_endpoint()}/{model_id}.git' | |||
return url | |||
@@ -93,6 +99,14 @@ class Repository: | |||
raise InvalidParameter(msg) | |||
if not isinstance(force, bool): | |||
raise InvalidParameter('force must be bool') | |||
if not self.auth_token: | |||
raise NotLoginException('Must login to push, please login first.') | |||
self.git_wrapper.config_auth_token(self.model_dir, self.auth_token) | |||
self.git_wrapper.add_user_info(self.model_base_dir, | |||
self.model_repo_name) | |||
url = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||
self.git_wrapper.pull(self.model_dir) | |||
self.git_wrapper.add(self.model_dir, all_files=True) | |||
@@ -1,5 +1,4 @@ | |||
import os | |||
import tempfile | |||
from pathlib import Path | |||
from typing import Dict, Optional, Union | |||
@@ -4,15 +4,14 @@ from modelscope.hub.api import HubApi | |||
from modelscope.utils.hub import create_model_if_not_exist | |||
# note this is temporary before official account management is ready | |||
USER_NAME = 'maasadmin' | |||
PASSWORD = '12345678' | |||
YOUR_ACCESS_TOKEN = 'token' | |||
class HubExampleTest(unittest.TestCase): | |||
def setUp(self): | |||
self.api = HubApi() | |||
self.api.login(USER_NAME, PASSWORD) | |||
self.api.login(YOUR_ACCESS_TOKEN) | |||
@unittest.skip('to be used for local test only') | |||
def test_example_model_creation(self): | |||
@@ -12,20 +12,24 @@ from modelscope.hub.constants import Licenses, ModelVisibility | |||
from modelscope.hub.file_download import model_file_download | |||
from modelscope.hub.repository import Repository | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from .test_utils import (TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | |||
TEST_PASSWORD, TEST_USER_NAME1) | |||
from modelscope.utils.constant import ModelFile | |||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, | |||
TEST_MODEL_ORG) | |||
DEFAULT_GIT_PATH = 'git' | |||
download_model_file_name = 'test.bin' | |||
@unittest.skip( | |||
"Access token is always change, we can't login with same access token, so skip!" | |||
) | |||
class HubOperationTest(unittest.TestCase): | |||
def setUp(self): | |||
self.api = HubApi() | |||
# note this is temporary before official account management is ready | |||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||
self.api.login(TEST_ACCESS_TOKEN1) | |||
self.model_name = uuid.uuid4().hex | |||
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | |||
self.api.create_model( | |||
@@ -92,7 +96,7 @@ class HubOperationTest(unittest.TestCase): | |||
file_path=download_model_file_name, | |||
cache_dir=temporary_dir) | |||
assert os.path.exists(downloaded_file) | |||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||
self.api.login(TEST_ACCESS_TOKEN1) | |||
def test_snapshot_delete_download_cache_file(self): | |||
snapshot_path = snapshot_download(model_id=self.model_id) | |||
@@ -102,7 +106,7 @@ class HubOperationTest(unittest.TestCase): | |||
os.remove(downloaded_file_path) | |||
# download again in cache | |||
file_download_path = model_file_download( | |||
model_id=self.model_id, file_path='README.md') | |||
model_id=self.model_id, file_path=ModelFile.README) | |||
assert os.path.exists(file_download_path) | |||
# deleted file need download again | |||
file_download_path = model_file_download( | |||
@@ -13,18 +13,21 @@ from modelscope.hub.file_download import model_file_download | |||
from modelscope.hub.repository import Repository | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.utils.constant import ModelFile | |||
from .test_utils import (TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | |||
TEST_PASSWORD, TEST_USER_NAME1, TEST_USER_NAME2, | |||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, | |||
TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | |||
delete_credential) | |||
@unittest.skip( | |||
"Access token is always change, we can't login with same access token, so skip!" | |||
) | |||
class HubPrivateFileDownloadTest(unittest.TestCase): | |||
def setUp(self): | |||
self.old_cwd = os.getcwd() | |||
self.api = HubApi() | |||
# note this is temporary before official account management is ready | |||
self.token, _ = self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||
self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) | |||
self.model_name = uuid.uuid4().hex | |||
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | |||
self.api.create_model( | |||
@@ -37,7 +40,7 @@ class HubPrivateFileDownloadTest(unittest.TestCase): | |||
def tearDown(self): | |||
# credential may deleted or switch login name, we need re-login here | |||
# to ensure the temporary model is deleted. | |||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||
self.api.login(TEST_ACCESS_TOKEN1) | |||
os.chdir(self.old_cwd) | |||
self.api.delete_model(model_id=self.model_id) | |||
@@ -46,7 +49,7 @@ class HubPrivateFileDownloadTest(unittest.TestCase): | |||
assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | |||
def test_snapshot_download_private_model_no_permission(self): | |||
self.token, _ = self.api.login(TEST_USER_NAME2, TEST_PASSWORD) | |||
self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) | |||
with self.assertRaises(HTTPError): | |||
snapshot_download(self.model_id) | |||
@@ -60,7 +63,7 @@ class HubPrivateFileDownloadTest(unittest.TestCase): | |||
assert os.path.exists(file_path) | |||
def test_download_file_private_model_no_permission(self): | |||
self.token, _ = self.api.login(TEST_USER_NAME2, TEST_PASSWORD) | |||
self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) | |||
with self.assertRaises(HTTPError): | |||
model_file_download(self.model_id, ModelFile.README) | |||
@@ -8,19 +8,23 @@ from modelscope.hub.api import HubApi | |||
from modelscope.hub.constants import Licenses, ModelVisibility | |||
from modelscope.hub.errors import GitError | |||
from modelscope.hub.repository import Repository | |||
from .test_utils import (TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | |||
TEST_PASSWORD, TEST_USER_NAME1, TEST_USER_NAME2) | |||
from modelscope.utils.constant import ModelFile | |||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, | |||
TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG) | |||
DEFAULT_GIT_PATH = 'git' | |||
@unittest.skip( | |||
"Access token is always change, we can't login with same access token, so skip!" | |||
) | |||
class HubPrivateRepositoryTest(unittest.TestCase): | |||
def setUp(self): | |||
self.old_cwd = os.getcwd() | |||
self.api = HubApi() | |||
# note this is temporary before official account management is ready | |||
self.token, _ = self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||
self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) | |||
self.model_name = uuid.uuid4().hex | |||
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | |||
self.api.create_model( | |||
@@ -31,27 +35,25 @@ class HubPrivateRepositoryTest(unittest.TestCase): | |||
) | |||
def tearDown(self): | |||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||
self.api.login(TEST_ACCESS_TOKEN1) | |||
os.chdir(self.old_cwd) | |||
self.api.delete_model(model_id=self.model_id) | |||
def test_clone_private_repo_no_permission(self): | |||
token, _ = self.api.login(TEST_USER_NAME2, TEST_PASSWORD) | |||
token, _ = self.api.login(TEST_ACCESS_TOKEN2) | |||
temporary_dir = tempfile.mkdtemp() | |||
local_dir = os.path.join(temporary_dir, self.model_name) | |||
with self.assertRaises(GitError) as cm: | |||
Repository(local_dir, clone_from=self.model_id, auth_token=token) | |||
print(cm.exception) | |||
assert not os.path.exists(os.path.join(local_dir, 'README.md')) | |||
assert not os.path.exists(os.path.join(local_dir, ModelFile.README)) | |||
def test_clone_private_repo_has_permission(self): | |||
temporary_dir = tempfile.mkdtemp() | |||
local_dir = os.path.join(temporary_dir, self.model_name) | |||
repo1 = Repository( | |||
local_dir, clone_from=self.model_id, auth_token=self.token) | |||
print(repo1.model_dir) | |||
assert os.path.exists(os.path.join(local_dir, 'README.md')) | |||
Repository(local_dir, clone_from=self.model_id, auth_token=self.token) | |||
assert os.path.exists(os.path.join(local_dir, ModelFile.README)) | |||
def test_initlize_repo_multiple_times(self): | |||
temporary_dir = tempfile.mkdtemp() | |||
@@ -59,7 +61,7 @@ class HubPrivateRepositoryTest(unittest.TestCase): | |||
repo1 = Repository( | |||
local_dir, clone_from=self.model_id, auth_token=self.token) | |||
print(repo1.model_dir) | |||
assert os.path.exists(os.path.join(local_dir, 'README.md')) | |||
assert os.path.exists(os.path.join(local_dir, ModelFile.README)) | |||
repo2 = Repository( | |||
local_dir, clone_from=self.model_id, | |||
auth_token=self.token) # skip clone | |||
@@ -14,23 +14,26 @@ from modelscope.hub.errors import NotExistError | |||
from modelscope.hub.file_download import model_file_download | |||
from modelscope.hub.git import GitCommandWrapper | |||
from modelscope.hub.repository import Repository | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger | |||
from .test_utils import (TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | |||
TEST_PASSWORD, TEST_USER_NAME1, TEST_USER_NAME2, | |||
delete_credential, delete_stored_git_credential) | |||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, | |||
TEST_MODEL_ORG, delete_credential) | |||
logger = get_logger() | |||
logger.setLevel('DEBUG') | |||
DEFAULT_GIT_PATH = 'git' | |||
@unittest.skip( | |||
"Access token is always change, we can't login with same access token, so skip!" | |||
) | |||
class HubRepositoryTest(unittest.TestCase): | |||
def setUp(self): | |||
self.old_cwd = os.getcwd() | |||
self.api = HubApi() | |||
# note this is temporary before official account management is ready | |||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) | |||
self.api.login(TEST_ACCESS_TOKEN1) | |||
self.model_name = uuid.uuid4().hex | |||
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | |||
self.api.create_model( | |||
@@ -48,18 +51,17 @@ class HubRepositoryTest(unittest.TestCase): | |||
def test_clone_repo(self): | |||
Repository(self.model_dir, clone_from=self.model_id) | |||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
assert os.path.exists(os.path.join(self.model_dir, ModelFile.README)) | |||
def test_clone_public_model_without_token(self): | |||
delete_credential() | |||
delete_stored_git_credential(TEST_USER_NAME1) | |||
Repository(self.model_dir, clone_from=self.model_id) | |||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
self.api.login(TEST_USER_NAME1, TEST_PASSWORD) # re-login for delete | |||
assert os.path.exists(os.path.join(self.model_dir, ModelFile.README)) | |||
self.api.login(TEST_ACCESS_TOKEN1) # re-login for delete | |||
def test_push_all(self): | |||
repo = Repository(self.model_dir, clone_from=self.model_id) | |||
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
assert os.path.exists(os.path.join(self.model_dir, ModelFile.README)) | |||
os.chdir(self.model_dir) | |||
lfs_file1 = 'test1.bin' | |||
lfs_file2 = 'test2.bin' | |||
@@ -3,25 +3,16 @@ import shutil | |||
from codecs import ignore_errors | |||
from os.path import expanduser | |||
TEST_USER_NAME1 = 'citest' | |||
TEST_USER_NAME2 = 'sdkdev' | |||
TEST_PASSWORD = '12345678' | |||
from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH | |||
# for user citest and sdkdev | |||
TEST_ACCESS_TOKEN1 = 'OVAzNU9aZ2FYbXFhdGNzZll6VHRtalQ0T1BpZTNGeWVhMkxSSGpTSzU0dkM5WE5ObDFKdFRQWGc2U2ZIdjdPdg==' | |||
TEST_ACCESS_TOKEN2 = 'aXRocHhGeG0rNXRWQWhBSnJpTTZUQ0RDbUlkcUJRS1dQR2lNb0xIa0JjRDBrT1JKYklZV05DVzROTTdtamxWcg==' | |||
TEST_MODEL_CHINESE_NAME = '内部测试模型' | |||
TEST_MODEL_ORG = 'citest' | |||
def delete_credential(): | |||
path_credential = expanduser('~/.modelscope/credentials') | |||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) | |||
shutil.rmtree(path_credential, ignore_errors=True) | |||
def delete_stored_git_credential(user): | |||
credential_path = expanduser('~/.git-credentials') | |||
if os.path.exists(credential_path): | |||
with open(credential_path, 'r+') as f: | |||
lines = f.readlines() | |||
lines = [line for line in lines if user not in line] | |||
f.seek(0) | |||
f.write(''.join(lines)) | |||
f.truncate() |