Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9319772master
@@ -43,7 +43,7 @@ extensions = [ | |||
'sphinx.ext.autodoc', | |||
'sphinx.ext.napoleon', | |||
'sphinx.ext.viewcode', | |||
'recommonmark', | |||
'myst_parser', | |||
'sphinx_markdown_tables', | |||
'sphinx_copybutton', | |||
] | |||
@@ -11,8 +11,8 @@ import requests | |||
from modelscope.utils.logger import get_logger | |||
from ..msdatasets.config import DOWNLOADED_DATASETS_PATH, HUB_DATASET_ENDPOINT | |||
from ..utils.constant import DownloadMode | |||
from .constants import MODELSCOPE_URL_SCHEME | |||
from ..utils.constant import (DEFAULT_DATASET_REVISION, DEFAULT_MODEL_REVISION, | |||
DownloadMode) | |||
from .errors import (InvalidParameter, NotExistError, datahub_raise_on_error, | |||
handle_http_response, is_ok, raise_on_error) | |||
from .utils.utils import get_endpoint, model_id_to_group_owner_name | |||
@@ -35,7 +35,7 @@ class HubApi: | |||
Login with username and password | |||
Args: | |||
username(`str`): user name on modelscope | |||
user_name(`str`): user name on modelscope | |||
password(`str`): password | |||
Returns: | |||
@@ -135,7 +135,7 @@ class HubApi: | |||
def get_model( | |||
self, | |||
model_id: str, | |||
revision: str = 'master', | |||
revision: str = DEFAULT_MODEL_REVISION, | |||
) -> str: | |||
""" | |||
Get model information at modelscope_hub | |||
@@ -144,7 +144,7 @@ class HubApi: | |||
model_id(`str`): The model id. | |||
revision(`str`): revision of model | |||
Returns: | |||
The model details information. | |||
The model detail information. | |||
Raises: | |||
NotExistError: If the model is not exist, will throw NotExistError | |||
<Tip> | |||
@@ -207,7 +207,7 @@ class HubApi: | |||
def get_model_files(self, | |||
model_id: str, | |||
revision: Optional[str] = 'master', | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
root: Optional[str] = None, | |||
recursive: Optional[str] = False, | |||
use_cookies: Union[bool, CookieJar] = False, | |||
@@ -216,12 +216,12 @@ class HubApi: | |||
Args: | |||
model_id (str): The model id | |||
revision (Optional[str], optional): The branch or tag name. Defaults to 'master'. | |||
revision (Optional[str], optional): The branch or tag name. | |||
root (Optional[str], optional): The root path. Defaults to None. | |||
recursive (Optional[str], optional): Is recurive list files. Defaults to False. | |||
use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will | |||
recursive (Optional[str], optional): Is recursive list files. Defaults to False. | |||
use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, | |||
will load cookie from local. Defaults to False. | |||
is_snapshot(Optional[bool], optional): when snapshot_download set to True, otherwise False. | |||
headers: request headers | |||
Raises: | |||
ValueError: If user_cookies is True, but no local cookie. | |||
@@ -258,18 +258,19 @@ class HubApi: | |||
dataset_list = r.json()['Data'] | |||
return [x['Name'] for x in dataset_list] | |||
def fetch_dataset_scripts(self, | |||
dataset_name: str, | |||
namespace: str, | |||
download_mode: Optional[DownloadMode], | |||
version: Optional[str] = 'master'): | |||
def fetch_dataset_scripts( | |||
self, | |||
dataset_name: str, | |||
namespace: str, | |||
download_mode: Optional[DownloadMode], | |||
revision: Optional[str] = DEFAULT_DATASET_REVISION): | |||
if namespace is None: | |||
raise ValueError( | |||
f'Dataset from Hubs.modelscope should have a valid "namespace", but get {namespace}' | |||
) | |||
version = version or 'master' | |||
revision = revision or DEFAULT_DATASET_REVISION | |||
cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, dataset_name, | |||
namespace, version) | |||
namespace, revision) | |||
download_mode = DownloadMode(download_mode | |||
or DownloadMode.REUSE_DATASET_IF_EXISTS) | |||
if download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists( | |||
@@ -281,7 +282,7 @@ class HubApi: | |||
resp = r.json() | |||
datahub_raise_on_error(datahub_url, resp) | |||
dataset_id = resp['Data']['Id'] | |||
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={version}' | |||
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' | |||
r = requests.get(datahub_url) | |||
resp = r.json() | |||
datahub_raise_on_error(datahub_url, resp) | |||
@@ -289,7 +290,7 @@ class HubApi: | |||
if file_list is None: | |||
raise NotExistError( | |||
f'The modelscope dataset [dataset_name = {dataset_name}, namespace = {namespace}, ' | |||
f'version = {version}] dose not exist') | |||
f'version = {revision}] dose not exist') | |||
file_list = file_list['Files'] | |||
local_paths = defaultdict(list) | |||
@@ -297,7 +298,7 @@ class HubApi: | |||
file_path = file_info['Path'] | |||
if file_path.endswith('.py'): | |||
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/files?' \ | |||
f'Revision={version}&Path={file_path}' | |||
f'Revision={revision}&Path={file_path}' | |||
r = requests.get(datahub_url) | |||
r.raise_for_status() | |||
content = r.json()['Data']['Content'] | |||
@@ -11,8 +11,12 @@ LOGGER_NAME = 'ModelScopeHub' | |||
class Licenses(object): | |||
APACHE_V2 = 'Apache License 2.0' | |||
GPL = 'GPL' | |||
LGPL = 'LGPL' | |||
GPL_V2 = 'GPL-2.0' | |||
GPL_V3 = 'GPL-3.0' | |||
LGPL_V2_1 = 'LGPL-2.1' | |||
LGPL_V3 = 'LGPL-3.0' | |||
AFL_V3 = 'AFL-3.0' | |||
ECL_V2 = 'ECL-2.0' | |||
MIT = 'MIT' | |||
@@ -20,6 +20,7 @@ from tqdm import tqdm | |||
from modelscope import __version__ | |||
from modelscope.utils.logger import get_logger | |||
from ..utils.constant import DEFAULT_MODEL_REVISION | |||
from .api import HubApi, ModelScopeConfig | |||
from .constants import (DEFAULT_MODELSCOPE_GROUP, LOGGER_NAME, | |||
MODEL_ID_SEPARATOR) | |||
@@ -35,7 +36,7 @@ logger = get_logger() | |||
def model_file_download( | |||
model_id: str, | |||
file_path: str, | |||
revision: Optional[str] = 'master', | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
cache_dir: Optional[str] = None, | |||
user_agent: Union[Dict, str, None] = None, | |||
local_files_only: Optional[bool] = False, | |||
@@ -55,7 +56,7 @@ def model_file_download( | |||
Path of the file to be downloaded, relative to the root of model repo | |||
revision(`str`, *optional*): | |||
revision of the model file to be downloaded. | |||
Can be any of a branch, tag or commit hash, default to `master` | |||
Can be any of a branch, tag or commit hash | |||
cache_dir (`str`, `Path`, *optional*): | |||
Path to the folder where cached files are stored. | |||
user_agent (`dict`, `str`, *optional*): | |||
@@ -120,8 +121,7 @@ def model_file_download( | |||
model_id=model_id, | |||
revision=revision, | |||
recursive=True, | |||
use_cookies=False if cookies is None else cookies, | |||
) | |||
use_cookies=False if cookies is None else cookies) | |||
for model_file in model_files: | |||
if model_file['Type'] == 'tree': | |||
@@ -3,6 +3,7 @@ from typing import Optional | |||
from modelscope.hub.errors import GitError, InvalidParameter | |||
from modelscope.utils.logger import get_logger | |||
from ..utils.constant import DEFAULT_MODEL_REVISION | |||
from .api import ModelScopeConfig | |||
from .git import GitCommandWrapper | |||
from .utils.utils import get_endpoint | |||
@@ -18,7 +19,7 @@ class Repository: | |||
self, | |||
model_dir: str, | |||
clone_from: str, | |||
revision: Optional[str] = 'master', | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
auth_token: Optional[str] = None, | |||
git_path: Optional[str] = None, | |||
): | |||
@@ -76,15 +77,15 @@ class Repository: | |||
def push(self, | |||
commit_message: str, | |||
branch: Optional[str] = 'master', | |||
force: Optional[bool] = False): | |||
branch: Optional[str] = DEFAULT_MODEL_REVISION, | |||
force: bool = False): | |||
"""Push local files to remote, this method will do. | |||
git add | |||
git commit | |||
git push | |||
Args: | |||
commit_message (str): commit message | |||
branch (Optional[str]): which branch to push. Defaults to 'master'. | |||
branch (Optional[str], optional): which branch to push. | |||
force (Optional[bool]): whether to use forced-push. | |||
""" | |||
if commit_message is None or not isinstance(commit_message, str): | |||
@@ -4,6 +4,7 @@ from pathlib import Path | |||
from typing import Dict, Optional, Union | |||
from modelscope.utils.logger import get_logger | |||
from ..utils.constant import DEFAULT_MODEL_REVISION | |||
from .api import HubApi, ModelScopeConfig | |||
from .errors import NotExistError | |||
from .file_download import (get_file_download_url, http_get_file, | |||
@@ -15,7 +16,7 @@ logger = get_logger() | |||
def snapshot_download(model_id: str, | |||
revision: Optional[str] = 'master', | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
cache_dir: Union[str, Path, None] = None, | |||
user_agent: Optional[Union[Dict, str]] = None, | |||
local_files_only: Optional[bool] = False) -> str: | |||
@@ -7,7 +7,7 @@ from typing import Dict, Optional, Union | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models.builder import build_model | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@@ -44,7 +44,7 @@ class Model(ABC): | |||
@classmethod | |||
def from_pretrained(cls, | |||
model_name_or_path: str, | |||
revision: Optional[str] = 'master', | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
*model_args, | |||
**kwargs): | |||
""" Instantiate a model from local directory or remote model repo. Note | |||
@@ -6,7 +6,7 @@ from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.base import Model | |||
from modelscope.utils.config import ConfigDict | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Tasks | |||
from modelscope.utils.hub import read_config | |||
from modelscope.utils.registry import Registry, build_from_cfg | |||
from .base import Pipeline | |||
@@ -105,7 +105,7 @@ def pipeline(task: str = None, | |||
pipeline_name: str = None, | |||
framework: str = None, | |||
device: int = -1, | |||
model_revision: Optional[str] = 'master', | |||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
**kwargs) -> Pipeline: | |||
""" Factory method to build an obj:`Pipeline`. | |||
@@ -5,7 +5,7 @@ from typing import List, Optional, Union | |||
from modelscope.hub.api import HubApi | |||
from modelscope.hub.file_download import model_file_download | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@@ -21,7 +21,7 @@ def is_config_has_model(cfg_file): | |||
def is_official_hub_path(path: Union[str, List], | |||
revision: Optional[str] = 'master'): | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION): | |||
""" Whether path is an official hub name or a valid local | |||
path to official hub directory. | |||
""" | |||
@@ -117,3 +117,6 @@ class Requirements(object): | |||
TENSORFLOW = 'tensorflow' | |||
PYTORCH = 'pytorch' | |||
DEFAULT_MODEL_REVISION = 'master' | |||
DEFAULT_DATASET_REVISION = 'master' |
@@ -10,7 +10,7 @@ from modelscope.hub.constants import Licenses, ModelVisibility | |||
from modelscope.hub.file_download import model_file_download | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
from .logger import get_logger | |||
logger = get_logger(__name__) | |||
@@ -22,7 +22,7 @@ def create_model_if_not_exist( | |||
chinese_name: str, | |||
visibility: Optional[int] = ModelVisibility.PUBLIC, | |||
license: Optional[str] = Licenses.APACHE_V2, | |||
revision: Optional[str] = 'master'): | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION): | |||
exists = True | |||
try: | |||
api.get_model(model_id=model_id, revision=revision) | |||
@@ -42,12 +42,13 @@ def create_model_if_not_exist( | |||
return True | |||
def read_config(model_id_or_path: str, revision: Optional[str] = 'master'): | |||
def read_config(model_id_or_path: str, | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION): | |||
""" Read config from hub or local path | |||
Args: | |||
model_id_or_path (str): Model repo name or local directory path. | |||
revision: revision of the model when getting from the hub | |||
Return: | |||
config (:obj:`Config`): config object | |||
""" | |||