@@ -1,6 +1,6 @@ | |||
repos: | |||
- repo: https://gitlab.com/pycqa/flake8.git | |||
rev: 3.8.3 | |||
rev: 4.0.0 | |||
hooks: | |||
- id: flake8 | |||
exclude: thirdparty/|examples/ | |||
@@ -1,6 +1,6 @@ | |||
repos: | |||
- repo: /home/admin/pre-commit/flake8 | |||
rev: 3.8.3 | |||
rev: 4.0.0 | |||
hooks: | |||
- id: flake8 | |||
exclude: thirdparty/|examples/ | |||
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:e999c247bfebb03d556a31722f0ce7145cac20a67fac9da813ad336e1f549f9f | |||
size 38954 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:32eb8d4d537941bf0edea69cd6723e8ba489fa3df64e13e29f96e4fae0b856f4 | |||
size 93676 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:f57aee13ade70be6b2c6e4f5e5c7404bdb03057b63828baefbaadcf23855a4cb | |||
size 472012 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:fee8e0460ca707f108782be0d93c555bf34fb6b1cb297e5fceed70192cc65f9b | |||
size 71244 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:450e31f9df8c5b48c617900625f01cb64c484f079a9843179fe9feaa7d163e61 | |||
size 181964 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:255494c41bc1dfb0c954d827ec6ce775900e4f7a55fb0a7881bdf9d66a03b425 | |||
size 112078 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:22a55277908bbc3ef60a0cf56b230eb507b9e837574e8f493e93644b1d21c281 | |||
size 200556 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:ee92191836c76412463d8b282a7ab4e1aa57386ba699ec011a3e2c4d64f32f4b | |||
size 162636 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:77d1537fc584c1505d8aa10ec8c86af57ab661199e4f28fd7ffee3c22d1e4e61 | |||
size 160204 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:2bce1341f4b55d536771dad6e2b280458579f46c3216474ceb8a926022ab53d0 | |||
size 151572 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62 | |||
size 62231 | |||
oid sha256:6af5024a26337a440c7ea2935fce84af558dd982ee97a2f027bb922cc874292b | |||
size 61741 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a | |||
size 62235 | |||
oid sha256:bbce084781342ca7274c2e4d02ed5c5de43ba213a3b76328d5994404d6544c41 | |||
size 61745 |
@@ -23,12 +23,14 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||
def generate_dummy_inputs(self, | |||
shape: Tuple = None, | |||
pair: bool = False, | |||
**kwargs) -> Dict[str, Any]: | |||
"""Generate dummy inputs for model exportation to onnx or other formats by tracing. | |||
@param shape: A tuple of input shape which should have at most two dimensions. | |||
shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | |||
shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. | |||
@param pair: Generate sentence pairs or single sentences for dummy inputs. | |||
@return: Dummy inputs. | |||
""" | |||
@@ -55,7 +57,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||
**sequence_length | |||
}) | |||
preprocessor: Preprocessor = build_preprocessor(cfg, field_name) | |||
if preprocessor.pair: | |||
if pair: | |||
first_sequence = preprocessor.tokenizer.unk_token | |||
second_sequence = preprocessor.tokenizer.unk_token | |||
else: | |||
@@ -1,8 +1,11 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# yapf: disable | |||
import datetime | |||
import os | |||
import pickle | |||
import shutil | |||
import tempfile | |||
from collections import defaultdict | |||
from http import HTTPStatus | |||
from http.cookiejar import CookieJar | |||
@@ -16,17 +19,25 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | |||
API_RESPONSE_FIELD_MESSAGE, | |||
API_RESPONSE_FIELD_USERNAME, | |||
DEFAULT_CREDENTIALS_PATH) | |||
DEFAULT_CREDENTIALS_PATH, Licenses, | |||
ModelVisibility) | |||
from modelscope.hub.errors import (InvalidParameter, NotExistError, | |||
NotLoginException, RequestError, | |||
datahub_raise_on_error, | |||
handle_http_post_error, | |||
handle_http_response, is_ok, raise_on_error) | |||
from modelscope.hub.git import GitCommandWrapper | |||
from modelscope.hub.repository import Repository | |||
from modelscope.hub.utils.utils import (get_endpoint, | |||
model_id_to_group_owner_name) | |||
from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH | |||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
DEFAULT_MODEL_REVISION, | |||
DatasetFormations, DatasetMetaFormats, | |||
DownloadMode) | |||
DownloadMode, ModelFile) | |||
from modelscope.utils.logger import get_logger | |||
from .errors import (InvalidParameter, NotExistError, RequestError, | |||
datahub_raise_on_error, handle_http_post_error, | |||
handle_http_response, is_ok, raise_on_error) | |||
from .utils.utils import get_endpoint, model_id_to_group_owner_name | |||
# yapf: enable | |||
logger = get_logger() | |||
@@ -169,11 +180,106 @@ class HubApi: | |||
else: | |||
r.raise_for_status() | |||
def list_model(self, | |||
owner_or_group: str, | |||
page_number=1, | |||
page_size=10) -> dict: | |||
"""List model in owner or group. | |||
def push_model(self,
               model_id: str,
               model_dir: str,
               visibility: int = ModelVisibility.PUBLIC,
               license: str = Licenses.APACHE_V2,
               chinese_name: Optional[str] = None,
               commit_message: Optional[str] = 'upload model',
               revision: Optional[str] = DEFAULT_MODEL_REVISION):
    """Upload the content of a local model directory to a hub repository.

    A valid model directory must contain a configuration.json file.

    The repository is cloned into a temporary directory, the non-hidden
    top-level entries of ``model_dir`` are copied over it, and the result
    is committed and pushed to ``revision``. If the repository cannot be
    fetched from the hub it is created first with the given visibility,
    license and chinese_name. If ``revision`` is not among the remote
    branches, a new branch is created for it.

    HubApi's login must be called with a valid token (which can be
    obtained from ModelScope's website) before calling this function.

    Args:
        model_id (`str`):
            The model id to be uploaded, caller must have write permission for it.
        model_dir(`str`):
            The Absolute Path of the finetune result.
        visibility(`int`, defaults to `ModelVisibility.PUBLIC`):
            Visibility of the new created model(1-private, 5-public). Only
            used when the model does not exist yet and has to be created;
            it must not be None in that case.
        license(`str`, defaults to `Licenses.APACHE_V2`):
            License of the new created model(see License). Only used when
            the model does not exist yet and has to be created; it must
            not be None in that case.
        chinese_name(`str`, *optional*, defaults to `None`):
            chinese name of the new created model.
        commit_message(`str`, *optional*, defaults to `'upload model'`):
            commit message of the push request. When None or empty, a
            timestamped message is generated instead.
        revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION):
            which branch to push. If the branch is not exists, It will create a new
            branch and push to it.

    Raises:
        InvalidParameter: If model_id or model_dir is None, model_dir is
            not a directory, or the repository has to be created while
            visibility or license is None.
        ValueError: If model_dir does not contain the configuration file.
        NotLoginException: If no login cookies are available.
    """

    def _copy_entry(src: str, dst_dir: str):
        """Copy ``src`` into ``dst_dir``, merging into an existing directory."""
        dst = os.path.join(dst_dir, os.path.basename(src))
        if not os.path.isdir(src):
            shutil.copy(src, dst_dir)
        elif os.path.isdir(dst):
            # The cloned repository already has this directory; a plain
            # copytree would raise FileExistsError, so merge entry by entry
            # (files are overwritten, exactly like top-level files are).
            for child in os.listdir(src):
                _copy_entry(os.path.join(src, child), dst)
        else:
            shutil.copytree(src, dst)

    if model_id is None:
        raise InvalidParameter('model_id cannot be empty!')
    if model_dir is None:
        raise InvalidParameter('model_dir cannot be empty!')
    if not os.path.exists(model_dir) or os.path.isfile(model_dir):
        raise InvalidParameter('model_dir must be a valid directory.')
    cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
    if not os.path.exists(cfg_file):
        raise ValueError(f'{model_dir} must contain a configuration.json.')
    cookies = ModelScopeConfig.get_cookies()
    if cookies is None:
        raise NotLoginException('Must login before upload!')
    files_to_save = os.listdir(model_dir)
    try:
        self.get_model(model_id=model_id)
    except Exception:
        # NOTE(review): any failure of get_model (not only "model does not
        # exist") leads to a creation attempt; kept as-is because the
        # precise "not found" error type is not visible from this method.
        if visibility is None or license is None:
            raise InvalidParameter(
                'visibility and license cannot be empty if want to create new repo'
            )
        logger.info('Create new model %s', model_id)
        self.create_model(
            model_id=model_id,
            visibility=visibility,
            license=license,
            chinese_name=chinese_name)
    tmp_dir = tempfile.mkdtemp()
    git_wrapper = GitCommandWrapper()
    try:
        repo = Repository(model_dir=tmp_dir, clone_from=model_id)
        branches = git_wrapper.get_remote_branches(tmp_dir)
        if revision not in branches:
            logger.info('Create new branch %s', revision)
            git_wrapper.new_branch(tmp_dir, revision)
        git_wrapper.checkout(tmp_dir, revision)
        for f in files_to_save:
            # Hidden top-level entries (names starting with '.') are skipped.
            if f[0] != '.':
                _copy_entry(os.path.join(model_dir, f), tmp_dir)
        if not commit_message:
            date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
            commit_message = '[automsg] push model %s to hub at %s' % (
                model_id, date)
        repo.push(commit_message=commit_message, branch=revision)
    finally:
        # Always drop the temporary clone, whether or not the push succeeded.
        shutil.rmtree(tmp_dir, ignore_errors=True)
def list_models(self, | |||
owner_or_group: str, | |||
page_number=1, | |||
page_size=10) -> dict: | |||
"""List models in owner or group. | |||
Args: | |||
owner_or_group(`str`): owner or group. | |||
@@ -390,11 +496,13 @@ class HubApi: | |||
return resp['Data'] | |||
def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, | |||
is_recursive, is_filter_dir, revision, | |||
cookies): | |||
is_recursive, is_filter_dir, revision): | |||
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ | |||
f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' | |||
cookies = requests.utils.dict_from_cookiejar(cookies) | |||
cookies = ModelScopeConfig.get_cookies() | |||
if cookies: | |||
cookies = requests.utils.dict_from_cookiejar(cookies) | |||
resp = requests.get(url=url, cookies=cookies) | |||
resp = resp.json() | |||
@@ -11,13 +11,12 @@ from typing import Dict, Optional, Union | |||
from uuid import uuid4 | |||
import requests | |||
from filelock import FileLock | |||
from tqdm import tqdm | |||
from modelscope import __version__ | |||
from modelscope.hub.api import HubApi, ModelScopeConfig | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
from modelscope.utils.logger import get_logger | |||
from .api import HubApi, ModelScopeConfig | |||
from .constants import FILE_HASH | |||
from .errors import FileDownloadError, NotExistError | |||
from .utils.caching import ModelFileSystemCache | |||
@@ -1,13 +1,10 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import re | |||
import subprocess | |||
from typing import List | |||
from xmlrpc.client import Boolean | |||
from modelscope.utils.logger import get_logger | |||
from .api import ModelScopeConfig | |||
from .errors import GitError | |||
logger = get_logger() | |||
@@ -132,6 +129,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||
return response | |||
def add_user_info(self, repo_base_dir, repo_name): | |||
from modelscope.hub.api import ModelScopeConfig | |||
user_name, user_email = ModelScopeConfig.get_user_info() | |||
if user_name and user_email: | |||
# config user.name and user.email if exist | |||
@@ -184,8 +182,11 @@ class GitCommandWrapper(metaclass=Singleton): | |||
info = [ | |||
line.strip() | |||
for line in rsp.stdout.decode('utf8').strip().split(os.linesep) | |||
][1:] | |||
return ['/'.join(line.split('/')[1:]) for line in info] | |||
] | |||
if len(info) == 1: | |||
return ['/'.join(info[0].split('/')[1:])] | |||
else: | |||
return ['/'.join(line.split('/')[1:]) for line in info[1:]] | |||
def pull(self, repo_dir: str): | |||
cmds = ['-C', repo_dir, 'pull'] | |||
@@ -7,7 +7,6 @@ from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | |||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
DEFAULT_MODEL_REVISION) | |||
from modelscope.utils.logger import get_logger | |||
from .api import ModelScopeConfig | |||
from .git import GitCommandWrapper | |||
from .utils.utils import get_endpoint | |||
@@ -47,6 +46,7 @@ class Repository: | |||
err_msg = 'a non-default value of revision cannot be empty.' | |||
raise InvalidParameter(err_msg) | |||
from modelscope.hub.api import ModelScopeConfig | |||
if auth_token: | |||
self.auth_token = auth_token | |||
else: | |||
@@ -166,7 +166,7 @@ class DatasetRepository: | |||
err_msg = 'a non-default value of revision cannot be empty.' | |||
raise InvalidParameter(err_msg) | |||
self.revision = revision | |||
from modelscope.hub.api import ModelScopeConfig | |||
if auth_token: | |||
self.auth_token = auth_token | |||
else: | |||
@@ -5,9 +5,9 @@ import tempfile | |||
from pathlib import Path | |||
from typing import Dict, Optional, Union | |||
from modelscope.hub.api import HubApi, ModelScopeConfig | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
from modelscope.utils.logger import get_logger | |||
from .api import HubApi, ModelScopeConfig | |||
from .constants import FILE_HASH | |||
from .errors import NotExistError | |||
from .file_download import (get_file_download_url, http_get_file, | |||
@@ -1,117 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import datetime | |||
import os | |||
import shutil | |||
import tempfile | |||
import uuid | |||
from typing import Dict, Optional | |||
from uuid import uuid4 | |||
from filelock import FileLock | |||
from modelscope import __version__ | |||
from modelscope.hub.api import HubApi, ModelScopeConfig | |||
from modelscope.hub.errors import InvalidParameter, NotLoginException | |||
from modelscope.hub.git import GitCommandWrapper | |||
from modelscope.hub.repository import Repository | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
def upload_folder(model_id: str,
                  model_dir: str,
                  visibility: int = 0,
                  license: str = None,
                  chinese_name: Optional[str] = None,
                  commit_message: Optional[str] = None,
                  revision: Optional[str] = DEFAULT_MODEL_REVISION):
    """Upload the content of a local model directory to a hub repository.

    A valid model directory must contain a configuration.json file.

    The repository is cloned into a temporary directory, the non-hidden
    top-level entries of ``model_dir`` are copied over it, and the result
    is committed and pushed to ``revision``. If the repository cannot be
    fetched from the hub it is created first with the given visibility,
    license and chinese_name. If ``revision`` is not among the remote
    branches, a new branch is created for it.

    HubApi's login must be called with a valid token (which can be
    obtained from ModelScope's website) before calling this function.

    Args:
        model_id (`str`):
            The model id to be uploaded, caller must have write permission for it.
        model_dir(`str`):
            The Absolute Path of the finetune result.
        visibility(`int`, defaults to `0`):
            Visibility of the new created model(1-private, 5-public). If the model is
            not exists in ModelScope, this function will create a new model with this
            visibility and this parameter is required. You can ignore this parameter
            if you make sure the model's existence.
        license(`str`, defaults to `None`):
            License of the new created model(see License). If the model is not exists
            in ModelScope, this function will create a new model with this license
            and this parameter is required. You can ignore this parameter if you
            make sure the model's existence.
        chinese_name(`str`, *optional*, defaults to `None`):
            chinese name of the new created model.
        commit_message(`str`, *optional*, defaults to `None`):
            commit message of the push request. When None or empty, a
            timestamped message is generated instead.
        revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION):
            which branch to push. If the branch is not exists, It will create a new
            branch and push to it.

    Raises:
        InvalidParameter: If model_id or model_dir is None, model_dir is
            not a directory, or the repository has to be created while
            visibility or license is None.
        ValueError: If model_dir does not contain the configuration file.
        NotLoginException: If no login cookies are available.
    """

    def _copy_entry(src: str, dst_dir: str):
        """Copy ``src`` into ``dst_dir``, merging into an existing directory."""
        dst = os.path.join(dst_dir, os.path.basename(src))
        if not os.path.isdir(src):
            shutil.copy(src, dst_dir)
        elif os.path.isdir(dst):
            # The cloned repository already has this directory; a plain
            # copytree would raise FileExistsError, so merge entry by entry
            # (files are overwritten, exactly like top-level files are).
            for child in os.listdir(src):
                _copy_entry(os.path.join(src, child), dst)
        else:
            shutil.copytree(src, dst)

    if model_id is None:
        raise InvalidParameter('model_id cannot be empty!')
    if model_dir is None:
        raise InvalidParameter('model_dir cannot be empty!')
    if not os.path.exists(model_dir) or os.path.isfile(model_dir):
        raise InvalidParameter('model_dir must be a valid directory.')
    cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
    if not os.path.exists(cfg_file):
        raise ValueError(f'{model_dir} must contain a configuration.json.')
    cookies = ModelScopeConfig.get_cookies()
    if cookies is None:
        raise NotLoginException('Must login before upload!')
    files_to_save = os.listdir(model_dir)
    api = HubApi()
    try:
        api.get_model(model_id=model_id)
    except Exception:
        # NOTE(review): any failure of get_model (not only "model does not
        # exist") leads to a creation attempt; kept as-is because the
        # precise "not found" error type is not visible from this function.
        if visibility is None or license is None:
            raise InvalidParameter(
                'visibility and license cannot be empty if want to create new repo'
            )
        logger.info('Create new model %s', model_id)
        api.create_model(
            model_id=model_id,
            visibility=visibility,
            license=license,
            chinese_name=chinese_name)
    tmp_dir = tempfile.mkdtemp()
    git_wrapper = GitCommandWrapper()
    try:
        repo = Repository(model_dir=tmp_dir, clone_from=model_id)
        branches = git_wrapper.get_remote_branches(tmp_dir)
        if revision not in branches:
            logger.info('Create new branch %s', revision)
            git_wrapper.new_branch(tmp_dir, revision)
        git_wrapper.checkout(tmp_dir, revision)
        for f in files_to_save:
            # Hidden top-level entries (names starting with '.') are skipped.
            if f[0] != '.':
                _copy_entry(os.path.join(model_dir, f), tmp_dir)
        if not commit_message:
            date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
            commit_message = '[automsg] push model %s to hub at %s' % (
                model_id, date)
        repo.push(commit_message=commit_message, branch=revision)
    finally:
        # Always drop the temporary clone, whether or not the push succeeded.
        shutil.rmtree(tmp_dir, ignore_errors=True)
@@ -9,7 +9,9 @@ class Models(object): | |||
Model name should only contain model info but not task info. | |||
""" | |||
# tinynas models | |||
tinynas_detection = 'tinynas-detection' | |||
tinynas_damoyolo = 'tinynas-damoyolo' | |||
# vision models | |||
detection = 'detection' | |||
@@ -454,9 +456,9 @@ class Datasets(object): | |||
""" Names for different datasets. | |||
""" | |||
ClsDataset = 'ClsDataset' | |||
Face2dKeypointsDataset = 'Face2dKeypointsDataset' | |||
Face2dKeypointsDataset = 'FaceKeypointDataset' | |||
HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | |||
HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset' | |||
HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset' | |||
SegDataset = 'SegDataset' | |||
DetDataset = 'DetDataset' | |||
DetImagesMixDataset = 'DetImagesMixDataset' |
@@ -32,6 +32,7 @@ task_default_metrics = { | |||
Tasks.sentiment_classification: [Metrics.seq_cls_metric], | |||
Tasks.token_classification: [Metrics.token_cls_metric], | |||
Tasks.text_generation: [Metrics.text_gen_metric], | |||
Tasks.text_classification: [Metrics.seq_cls_metric], | |||
Tasks.image_denoising: [Metrics.image_denoise_metric], | |||
Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | |||
Tasks.image_portrait_enhancement: | |||
@@ -2,6 +2,7 @@ | |||
import os | |||
import pickle as pkl | |||
from threading import Lock | |||
import json | |||
import numpy as np | |||
@@ -27,6 +28,7 @@ class Voice: | |||
self.__am_config = AttrDict(**am_config) | |||
self.__voc_config = AttrDict(**voc_config) | |||
self.__model_loaded = False | |||
self.__lock = Lock() | |||
if 'am' not in self.__am_config: | |||
raise TtsModelConfigurationException( | |||
'modelscope error: am configuration invalid') | |||
@@ -71,34 +73,35 @@ class Voice: | |||
self.__generator.remove_weight_norm() | |||
def __am_forward(self, symbol_seq): | |||
with torch.no_grad(): | |||
inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( | |||
symbol_seq) | |||
inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( | |||
self.__device) | |||
inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( | |||
self.__device) | |||
inputs_syllable = torch.from_numpy(inputs_feat_lst[2]).long().to( | |||
self.__device) | |||
inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( | |||
self.__device) | |||
inputs_ling = torch.stack( | |||
[inputs_sy, inputs_tone, inputs_syllable, inputs_ws], | |||
dim=-1).unsqueeze(0) | |||
inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( | |||
self.__device).unsqueeze(0) | |||
inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( | |||
self.__device).unsqueeze(0) | |||
inputs_len = torch.zeros(1).to(self.__device).long( | |||
) + inputs_emo.size(1) - 1 # minus 1 for "~" | |||
res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], | |||
inputs_spk[:, :-1], inputs_len) | |||
postnet_outputs = res['postnet_outputs'] | |||
LR_length_rounded = res['LR_length_rounded'] | |||
valid_length = int(LR_length_rounded[0].item()) | |||
postnet_outputs = postnet_outputs[ | |||
0, :valid_length, :].cpu().numpy() | |||
return postnet_outputs | |||
with self.__lock: | |||
with torch.no_grad(): | |||
inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( | |||
symbol_seq) | |||
inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( | |||
self.__device) | |||
inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( | |||
self.__device) | |||
inputs_syllable = torch.from_numpy( | |||
inputs_feat_lst[2]).long().to(self.__device) | |||
inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( | |||
self.__device) | |||
inputs_ling = torch.stack( | |||
[inputs_sy, inputs_tone, inputs_syllable, inputs_ws], | |||
dim=-1).unsqueeze(0) | |||
inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( | |||
self.__device).unsqueeze(0) | |||
inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( | |||
self.__device).unsqueeze(0) | |||
inputs_len = torch.zeros(1).to(self.__device).long( | |||
) + inputs_emo.size(1) - 1 # minus 1 for "~" | |||
res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], | |||
inputs_spk[:, :-1], inputs_len) | |||
postnet_outputs = res['postnet_outputs'] | |||
LR_length_rounded = res['LR_length_rounded'] | |||
valid_length = int(LR_length_rounded[0].item()) | |||
postnet_outputs = postnet_outputs[ | |||
0, :valid_length, :].cpu().numpy() | |||
return postnet_outputs | |||
def __vocoder_forward(self, melspec): | |||
dim0 = list(melspec.shape)[-1] | |||
@@ -118,14 +121,15 @@ class Voice: | |||
return audio | |||
def forward(self, symbol_seq): | |||
if not self.__model_loaded: | |||
torch.manual_seed(self.__am_config.seed) | |||
if torch.cuda.is_available(): | |||
with self.__lock: | |||
if not self.__model_loaded: | |||
torch.manual_seed(self.__am_config.seed) | |||
self.__device = torch.device('cuda') | |||
else: | |||
self.__device = torch.device('cpu') | |||
self.__load_am() | |||
self.__load_vocoder() | |||
self.__model_loaded = True | |||
if torch.cuda.is_available(): | |||
torch.manual_seed(self.__am_config.seed) | |||
self.__device = torch.device('cuda') | |||
else: | |||
self.__device = torch.device('cpu') | |||
self.__load_am() | |||
self.__load_vocoder() | |||
self.__model_loaded = True | |||
return self.__vocoder_forward(self.__am_forward(symbol_seq)) |
@@ -93,7 +93,7 @@ class TextDrivenSeg(TorchModel): | |||
""" | |||
with torch.no_grad(): | |||
if self.device_id == -1: | |||
output = self.model(image) | |||
output = self.model(image, [text]) | |||
else: | |||
device = torch.device('cuda', self.device_id) | |||
output = self.model(image.to(device), [text]) | |||
@@ -7,10 +7,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .tinynas_detector import Tinynas_detector | |||
from .tinynas_damoyolo import DamoYolo | |||
else: | |||
_import_structure = { | |||
'tinynas_detector': ['TinynasDetector'], | |||
'tinynas_damoyolo': ['DamoYolo'], | |||
} | |||
import sys | |||
@@ -4,6 +4,7 @@ | |||
import torch | |||
import torch.nn as nn | |||
from modelscope.utils.file_utils import read_file | |||
from ..core.base_ops import Focus, SPPBottleneck, get_activation | |||
from ..core.repvgg_block import RepVggBlock | |||
@@ -49,12 +50,16 @@ class ResConvK1KX(nn.Module): | |||
kernel_size, | |||
stride, | |||
force_resproj=False, | |||
act='silu'): | |||
act='silu', | |||
reparam=False): | |||
super(ResConvK1KX, self).__init__() | |||
self.stride = stride | |||
self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) | |||
self.conv2 = RepVggBlock( | |||
btn_c, out_c, kernel_size, stride, act='identity') | |||
if not reparam: | |||
self.conv2 = ConvKXBN(btn_c, out_c, 3, stride) | |||
else: | |||
self.conv2 = RepVggBlock( | |||
btn_c, out_c, kernel_size, stride, act='identity') | |||
if act is None: | |||
self.activation_function = torch.relu | |||
@@ -97,7 +102,8 @@ class SuperResConvK1KX(nn.Module): | |||
stride, | |||
num_blocks, | |||
with_spp=False, | |||
act='silu'): | |||
act='silu', | |||
reparam=False): | |||
super(SuperResConvK1KX, self).__init__() | |||
if act is None: | |||
self.act = torch.relu | |||
@@ -124,7 +130,8 @@ class SuperResConvK1KX(nn.Module): | |||
this_kernel_size, | |||
this_stride, | |||
force_resproj, | |||
act=act) | |||
act=act, | |||
reparam=reparam) | |||
self.block_list.append(the_block) | |||
if block_id == 0 and with_spp: | |||
self.block_list.append( | |||
@@ -248,7 +255,8 @@ class TinyNAS(nn.Module): | |||
with_spp=False, | |||
use_focus=False, | |||
need_conv1=True, | |||
act='silu'): | |||
act='silu', | |||
reparam=False): | |||
super(TinyNAS, self).__init__() | |||
assert len(out_indices) == len(out_channels) | |||
self.out_indices = out_indices | |||
@@ -281,7 +289,8 @@ class TinyNAS(nn.Module): | |||
block_info['s'], | |||
block_info['L'], | |||
spp, | |||
act=act) | |||
act=act, | |||
reparam=reparam) | |||
self.block_list.append(the_block) | |||
elif the_block_class == 'SuperResConvKXKX': | |||
spp = with_spp if idx == len(structure_info) - 1 else False | |||
@@ -325,8 +334,8 @@ class TinyNAS(nn.Module): | |||
def load_tinynas_net(backbone_cfg): | |||
# load masternet model to path | |||
import ast | |||
struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str]) | |||
net_structure_str = read_file(backbone_cfg.structure_file) | |||
struct_str = ''.join([x.strip() for x in net_structure_str]) | |||
struct_info = ast.literal_eval(struct_str) | |||
for layer in struct_info: | |||
if 'nbitsA' in layer: | |||
@@ -342,6 +351,6 @@ def load_tinynas_net(backbone_cfg): | |||
use_focus=backbone_cfg.use_focus, | |||
act=backbone_cfg.act, | |||
need_conv1=backbone_cfg.need_conv1, | |||
) | |||
reparam=backbone_cfg.reparam) | |||
return model |
@@ -30,7 +30,7 @@ class SingleStageDetector(TorchModel): | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
config_path = osp.join(model_dir, 'airdet_s.py') | |||
config_path = osp.join(model_dir, self.config_name) | |||
config = parse_config(config_path) | |||
self.cfg = config | |||
model_path = osp.join(model_dir, config.model.name) | |||
@@ -41,6 +41,9 @@ class SingleStageDetector(TorchModel): | |||
self.conf_thre = config.model.head.nms_conf_thre | |||
self.nms_thre = config.model.head.nms_iou_thre | |||
if self.cfg.model.backbone.name == 'TinyNAS': | |||
self.cfg.model.backbone.structure_file = osp.join( | |||
model_dir, self.cfg.model.backbone.structure_file) | |||
self.backbone = build_backbone(self.cfg.model.backbone) | |||
self.neck = build_neck(self.cfg.model.neck) | |||
self.head = build_head(self.cfg.model.head) | |||
@@ -124,11 +124,13 @@ class GFocalHead_Tiny(nn.Module): | |||
simOTA_iou_weight=3.0, | |||
octbase=8, | |||
simlqe=False, | |||
use_lqe=True, | |||
**kwargs): | |||
self.simlqe = simlqe | |||
self.num_classes = num_classes | |||
self.in_channels = in_channels | |||
self.strides = strides | |||
self.use_lqe = use_lqe | |||
self.feat_channels = feat_channels if isinstance(feat_channels, list) \ | |||
else [feat_channels] * len(self.strides) | |||
@@ -181,15 +183,20 @@ class GFocalHead_Tiny(nn.Module): | |||
groups=self.conv_groups, | |||
norm=self.norm, | |||
act=self.act)) | |||
if not self.simlqe: | |||
conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)] | |||
if self.use_lqe: | |||
if not self.simlqe: | |||
conf_vector = [ | |||
nn.Conv2d(4 * self.total_dim, self.reg_channels, 1) | |||
] | |||
else: | |||
conf_vector = [ | |||
nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) | |||
] | |||
conf_vector += [self.relu] | |||
conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] | |||
reg_conf = nn.Sequential(*conf_vector) | |||
else: | |||
conf_vector = [ | |||
nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) | |||
] | |||
conf_vector += [self.relu] | |||
conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] | |||
reg_conf = nn.Sequential(*conf_vector) | |||
reg_conf = None | |||
return cls_convs, reg_convs, reg_conf | |||
@@ -290,21 +297,27 @@ class GFocalHead_Tiny(nn.Module): | |||
N, C, H, W = bbox_pred.size() | |||
prob = F.softmax( | |||
bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) | |||
if not self.simlqe: | |||
prob_topk, _ = prob.topk(self.reg_topk, dim=2) | |||
if self.add_mean: | |||
stat = torch.cat( | |||
[prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2) | |||
if self.use_lqe: | |||
if not self.simlqe: | |||
prob_topk, _ = prob.topk(self.reg_topk, dim=2) | |||
if self.add_mean: | |||
stat = torch.cat( | |||
[prob_topk, | |||
prob_topk.mean(dim=2, keepdim=True)], | |||
dim=2) | |||
else: | |||
stat = prob_topk | |||
quality_score = reg_conf( | |||
stat.reshape(N, 4 * self.total_dim, H, W)) | |||
else: | |||
stat = prob_topk | |||
quality_score = reg_conf( | |||
bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) | |||
quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W)) | |||
cls_score = gfl_cls(cls_feat).sigmoid() * quality_score | |||
else: | |||
quality_score = reg_conf( | |||
bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) | |||
cls_score = gfl_cls(cls_feat).sigmoid() * quality_score | |||
cls_score = gfl_cls(cls_feat).sigmoid() | |||
flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) | |||
flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) | |||
@@ -14,7 +14,6 @@ class GiraffeNeckV2(nn.Module): | |||
self, | |||
depth=1.0, | |||
width=1.0, | |||
in_features=[2, 3, 4], | |||
in_channels=[256, 512, 1024], | |||
out_channels=[256, 512, 1024], | |||
depthwise=False, | |||
@@ -24,7 +23,6 @@ class GiraffeNeckV2(nn.Module): | |||
block_name='BasicBlock', | |||
): | |||
super().__init__() | |||
self.in_features = in_features | |||
self.in_channels = in_channels | |||
Conv = DWConv if depthwise else BaseConv | |||
@@ -169,8 +167,7 @@ class GiraffeNeckV2(nn.Module): | |||
""" | |||
# backbone | |||
features = [out_features[f] for f in self.in_features] | |||
[x2, x1, x0] = features | |||
[x2, x1, x0] = out_features | |||
# node x3 | |||
x13 = self.bu_conv13(x1) | |||
@@ -0,0 +1,15 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import Tasks | |||
from .detector import SingleStageDetector | |||
@MODELS.register_module(
    Tasks.image_object_detection, module_name=Models.tinynas_damoyolo)
class DamoYolo(SingleStageDetector):
    """Single-stage detector registered as ``Models.tinynas_damoyolo``.

    Pins the config file name to ``damoyolo_s.py`` and forwards every
    constructor argument to ``SingleStageDetector``.
    """

    def __init__(self, model_dir, *args, **kwargs):
        # Assigned before the parent initializer runs; presumably the parent
        # reads ``config_name`` while initializing -- keep this order.
        self.config_name = 'damoyolo_s.py'
        super().__init__(model_dir, *args, **kwargs)
@@ -12,5 +12,5 @@ from .detector import SingleStageDetector | |||
class TinynasDetector(SingleStageDetector): | |||
def __init__(self, model_dir, *args, **kwargs): | |||
self.config_name = 'airdet_s.py' | |||
super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) |
@@ -15,7 +15,6 @@ | |||
"""PyTorch BERT model. """ | |||
import math | |||
import os | |||
import warnings | |||
from dataclasses import dataclass | |||
from typing import Optional, Tuple | |||
@@ -41,7 +40,6 @@ from transformers.modeling_utils import (PreTrainedModel, | |||
find_pruneable_heads_and_indices, | |||
prune_linear_layer) | |||
from modelscope.models.base import TorchModel | |||
from modelscope.utils.logger import get_logger | |||
from .configuration_bert import BertConfig | |||
@@ -50,81 +48,6 @@ logger = get_logger(__name__) | |||
_CONFIG_FOR_DOC = 'BertConfig' | |||
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model.

    Every variable stored in the TensorFlow checkpoint is read, its
    ``/``-separated name is walked attribute by attribute starting from
    ``model``, and the array is copied into the parameter that the walk
    ends on.

    Args:
        model: Root object whose attributes are traversed and whose
            parameters receive the checkpoint values.
        config: Not used by this function; kept for signature compatibility.
        tf_checkpoint_path: Path to the TensorFlow checkpoint.

    Returns:
        The same ``model`` object, with the matched weights overwritten.

    Raises:
        ImportError: If ``re``, ``numpy`` or ``tensorflow`` cannot be
            imported (an error is logged first).
        ValueError: If a checkpoint array's shape differs from the shape of
            the parameter it maps to.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see '
            'https://www.tensorflow.org/install/ for installation instructions.'
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f'Converting TensorFlow checkpoint from {tf_path}')
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f'Loading TF weight {name} with shape {shape}')
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
        # which are not required for using pretrained model
        if any(n in [
                'adam_v', 'adam_m', 'AdamWeightDecayOptimizer',
                'AdamWeightDecayOptimizer_1', 'global_step'
        ] for n in name):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            # Names such as ``layer_3`` split into a scope name and an index.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                scope_names = re.split(r'_(\d+)', m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == 'kernel' or scope_names[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif scope_names[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    # NOTE(review): this ``continue`` only advances to the
                    # next path component; it does not skip the whole
                    # variable despite the log message. Behavior is kept
                    # as-is -- confirm whether that is intended.
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            # Kernels are transposed before the shape check and the copy.
            array = np.transpose(array)
        # The previous ``except AssertionError`` wrapper around this check
        # could never fire because the check raises ValueError; raise directly.
        if pointer.shape != array.shape:
            raise ValueError(
                f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched'
            )
        logger.info(f'Initialize PyTorch weight {name}')
        pointer.data = torch.from_numpy(array)
    return model
class BertEmbeddings(nn.Module): | |||
"""Construct the embeddings from word, position and token_type embeddings.""" | |||
@@ -750,7 +673,6 @@ class BertPreTrainedModel(PreTrainedModel): | |||
""" | |||
config_class = BertConfig | |||
load_tf_weights = load_tf_weights_in_bert | |||
base_model_prefix = 'bert' | |||
supports_gradient_checkpointing = True | |||
_keys_to_ignore_on_load_missing = [r'position_ids'] | |||
@@ -26,11 +26,16 @@ class EasyCVBaseDataset(object): | |||
if self.split_config is not None: | |||
self._update_data_source(kwargs['data_source']) | |||
def _update_data_root(self, input_dict, data_root): | |||
for k, v in input_dict.items(): | |||
if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: | |||
input_dict.update( | |||
{k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) | |||
elif isinstance(v, dict): | |||
self._update_data_root(v, data_root) | |||
def _update_data_source(self, data_source): | |||
data_root = next(iter(self.split_config.values())) | |||
data_root = data_root.rstrip(osp.sep) | |||
for k, v in data_source.items(): | |||
if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: | |||
data_source.update( | |||
{k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) | |||
self._update_data_root(data_source, data_root) |
@@ -7,7 +7,7 @@ from typing import Any, Mapping, Optional, Sequence, Union | |||
from datasets.builder import DatasetBuilder | |||
from modelscope.hub.api import HubApi | |||
from modelscope.utils.constant import DEFAULT_DATASET_REVISION, DownloadParams | |||
from modelscope.utils.constant import DEFAULT_DATASET_REVISION | |||
from modelscope.utils.logger import get_logger | |||
from .dataset_builder import MsCsvDatasetBuilder, TaskSpecificDatasetBuilder | |||
@@ -95,15 +95,13 @@ def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool, | |||
res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...] | |||
""" | |||
res = [] | |||
cookies = hub_api.check_cookies_upload_data(use_cookies=True) | |||
objects = hub_api.list_oss_dataset_objects( | |||
dataset_name=dataset_name, | |||
namespace=namespace, | |||
max_limit=max_limit, | |||
is_recursive=is_recursive, | |||
is_filter_dir=True, | |||
revision=version, | |||
cookies=cookies) | |||
revision=version) | |||
for item in objects: | |||
object_key = item.get('Key') | |||
@@ -174,7 +172,7 @@ def get_dataset_files(subset_split_into: dict, | |||
modelscope_api = HubApi() | |||
objects = list_dataset_objects( | |||
hub_api=modelscope_api, | |||
max_limit=DownloadParams.MAX_LIST_OBJECTS_NUM.value, | |||
max_limit=-1, | |||
is_recursive=True, | |||
dataset_name=dataset_name, | |||
namespace=namespace, | |||
@@ -47,22 +47,28 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): | |||
if isinstance(audio_in, str): | |||
# load pcm data from url if audio_in is url str | |||
self.audio_in = load_bytes_from_url(audio_in) | |||
self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in) | |||
elif isinstance(audio_in, bytes): | |||
# load pcm data from wav data if audio_in is wave format | |||
self.audio_in = extract_pcm_from_wav(audio_in) | |||
self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in) | |||
else: | |||
self.audio_in = audio_in | |||
# set the sample_rate of audio_in if checking_audio_fs is valid | |||
if checking_audio_fs is not None: | |||
self.audio_fs = checking_audio_fs | |||
if recog_type is None or audio_format is None: | |||
self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( | |||
audio_in=self.audio_in, | |||
recog_type=recog_type, | |||
audio_format=audio_format) | |||
if hasattr(asr_utils, 'sample_rate_checking') and audio_fs is None: | |||
self.audio_fs = asr_utils.sample_rate_checking( | |||
if hasattr(asr_utils, 'sample_rate_checking'): | |||
checking_audio_fs = asr_utils.sample_rate_checking( | |||
self.audio_in, self.audio_format) | |||
if checking_audio_fs is not None: | |||
self.audio_fs = checking_audio_fs | |||
if self.preprocessor is None: | |||
self.preprocessor = WavToScp() | |||
@@ -80,7 +86,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): | |||
logger.info(f"Decoding with {inputs['audio_format']} files ...") | |||
data_cmd: Sequence[Tuple[str, str]] | |||
data_cmd: Sequence[Tuple[str, str, str]] | |||
if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm': | |||
data_cmd = ['speech', 'sound'] | |||
elif inputs['audio_format'] == 'kaldi_ark': | |||
@@ -88,6 +94,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): | |||
elif inputs['audio_format'] == 'tfrecord': | |||
data_cmd = ['speech', 'tfrecord'] | |||
if inputs.__contains__('mvn_file'): | |||
data_cmd.append(inputs['mvn_file']) | |||
# generate asr inference command | |||
cmd = { | |||
'model_type': inputs['model_type'], | |||
@@ -51,10 +51,10 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||
if isinstance(audio_in, str): | |||
# load pcm data from url if audio_in is url str | |||
audio_in = load_bytes_from_url(audio_in) | |||
audio_in, audio_fs = load_bytes_from_url(audio_in) | |||
elif isinstance(audio_in, bytes): | |||
# load pcm data from wav data if audio_in is wave format | |||
audio_in = extract_pcm_from_wav(audio_in) | |||
audio_in, audio_fs = extract_pcm_from_wav(audio_in) | |||
output = self.preprocessor.forward(self.model.forward(), audio_in) | |||
output = self.forward(output) | |||
@@ -12,6 +12,8 @@ from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.cv.image_utils import \ | |||
show_image_object_detection_auto_result | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@@ -52,10 +54,18 @@ class TinynasDetectionPipeline(Pipeline): | |||
bboxes, scores, labels = self.model.postprocess(inputs['data']) | |||
if bboxes is None: | |||
return None | |||
outputs = { | |||
OutputKeys.SCORES: scores, | |||
OutputKeys.LABELS: labels, | |||
OutputKeys.BOXES: bboxes | |||
} | |||
outputs = { | |||
OutputKeys.SCORES: [], | |||
OutputKeys.LABELS: [], | |||
OutputKeys.BOXES: [] | |||
} | |||
else: | |||
outputs = { | |||
OutputKeys.SCORES: scores, | |||
OutputKeys.LABELS: labels, | |||
OutputKeys.BOXES: bboxes | |||
} | |||
return outputs | |||
def show_result(self, img_path, result, save_path=None): | |||
show_image_object_detection_auto_result(img_path, result, save_path) |
@@ -133,6 +133,12 @@ class WavToScp(Preprocessor): | |||
else: | |||
inputs['asr_model_config'] = asr_model_config | |||
if inputs['model_config'].__contains__('mvn_file'): | |||
mvn_file = os.path.join(inputs['model_workspace'], | |||
inputs['model_config']['mvn_file']) | |||
assert os.path.exists(mvn_file), 'mvn_file does not exist' | |||
inputs['mvn_file'] = mvn_file | |||
elif inputs['model_type'] == Frameworks.tf: | |||
assert inputs['model_config'].__contains__( | |||
'vocab_file'), 'vocab_file does not exist' | |||
@@ -2,7 +2,7 @@ | |||
import os.path as osp | |||
import re | |||
from typing import Any, Dict, Iterable, Optional, Tuple, Union | |||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | |||
import numpy as np | |||
import sentencepiece as spm | |||
@@ -217,7 +217,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
return isinstance(label, str) or isinstance(label, int) | |||
if labels is not None: | |||
if isinstance(labels, Iterable) and all([label_can_be_mapped(label) for label in labels]) \ | |||
if isinstance(labels, (tuple, list)) and all([label_can_be_mapped(label) for label in labels]) \ | |||
and self.label2id is not None: | |||
output[OutputKeys.LABELS] = [ | |||
self.label2id[str(label)] for label in labels | |||
@@ -314,8 +314,7 @@ class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | |||
kwargs['truncation'] = kwargs.get('truncation', True) | |||
kwargs['padding'] = kwargs.get( | |||
'padding', False if mode == ModeKeys.INFERENCE else 'max_length') | |||
kwargs['padding'] = kwargs.get('padding', 'max_length') | |||
kwargs['max_length'] = kwargs.pop('sequence_length', 128) | |||
super().__init__(model_dir, mode=mode, **kwargs) | |||
@@ -1,5 +1,10 @@ | |||
import math | |||
import os | |||
import random | |||
import uuid | |||
from os.path import exists | |||
from tempfile import TemporaryDirectory | |||
from urllib.parse import urlparse | |||
import numpy as np | |||
import torch | |||
@@ -9,6 +14,7 @@ import torchvision.transforms._transforms_video as transforms | |||
from decord import VideoReader | |||
from torchvision.transforms import Compose | |||
from modelscope.hub.file_download import http_get_file | |||
from modelscope.metainfo import Preprocessors | |||
from modelscope.utils.constant import Fields, ModeKeys | |||
from modelscope.utils.type_assert import type_assert | |||
@@ -30,7 +36,22 @@ def ReadVideoData(cfg, | |||
Returns: | |||
data (Tensor): the normalized video clips for model inputs | |||
""" | |||
data = _decode_video(cfg, video_path, num_temporal_views_override) | |||
url_parsed = urlparse(video_path) | |||
if url_parsed.scheme in ('file', '') and exists( | |||
url_parsed.path): # Possibly a local file | |||
data = _decode_video(cfg, video_path, num_temporal_views_override) | |||
else: | |||
with TemporaryDirectory() as temporary_cache_dir: | |||
random_str = uuid.uuid4().hex | |||
http_get_file( | |||
url=video_path, | |||
local_dir=temporary_cache_dir, | |||
file_name=random_str, | |||
cookies=None) | |||
temp_file_path = os.path.join(temporary_cache_dir, random_str) | |||
data = _decode_video(cfg, temp_file_path, | |||
num_temporal_views_override) | |||
if num_spatial_crops_override is not None: | |||
num_spatial_crops = num_spatial_crops_override | |||
transform = kinetics400_tranform(cfg, num_spatial_crops_override) | |||
@@ -47,7 +47,7 @@ class LrSchedulerHook(Hook): | |||
return lr | |||
def before_train_iter(self, trainer): | |||
if not self.by_epoch: | |||
if not self.by_epoch and trainer.iter > 0: | |||
if self.warmup_lr_scheduler is not None: | |||
self.warmup_lr_scheduler.step() | |||
else: | |||
@@ -656,7 +656,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
# TODO: support MsDataset load for cv | |||
if hasattr(data_cfg, 'name'): | |||
dataset = MsDataset.load( | |||
dataset_name=data_cfg.name, | |||
dataset_name=data_cfg.pop('name'), | |||
**data_cfg, | |||
) | |||
cfg = ConfigDict(type=self.cfg.model.type, mode=mode) | |||
@@ -57,6 +57,7 @@ def update_conf(origin_config_file, new_config_file, conf_item: [str, str]): | |||
def extract_pcm_from_wav(wav: bytes) -> bytes: | |||
data = wav | |||
sample_rate = None | |||
if len(data) > 44: | |||
frame_len = 44 | |||
file_len = len(data) | |||
@@ -70,29 +71,33 @@ def extract_pcm_from_wav(wav: bytes) -> bytes: | |||
'Subchunk1ID'] == 'fmt ': | |||
header_fields['SubChunk1Size'] = struct.unpack( | |||
'<I', data[16:20])[0] | |||
header_fields['SampleRate'] = struct.unpack('<I', | |||
data[24:28])[0] | |||
sample_rate = header_fields['SampleRate'] | |||
if header_fields['SubChunk1Size'] == 16: | |||
frame_len = 44 | |||
elif header_fields['SubChunk1Size'] == 18: | |||
frame_len = 46 | |||
else: | |||
return data | |||
return data, sample_rate | |||
data = wav[frame_len:file_len] | |||
except Exception: | |||
# no treatment | |||
pass | |||
return data | |||
return data, sample_rate | |||
def load_bytes_from_url(url: str) -> Union[bytes, str]: | |||
sample_rate = None | |||
result = urlparse(url) | |||
if result.scheme is not None and len(result.scheme) > 0: | |||
storage = HTTPStorage() | |||
data = storage.read(url) | |||
data = extract_pcm_from_wav(data) | |||
data, sample_rate = extract_pcm_from_wav(data) | |||
else: | |||
data = url | |||
return data | |||
return data, sample_rate |
@@ -231,13 +231,6 @@ class DownloadMode(enum.Enum): | |||
FORCE_REDOWNLOAD = 'force_redownload' | |||
class DownloadParams(enum.Enum):
    """Constants that parameterize dataset downloads."""

    # Upper bound on the number of objects requested in one listing call.
    MAX_LIST_OBJECTS_NUM = 50_000
class DatasetFormations(enum.Enum): | |||
""" How a dataset is organized and interpreted | |||
""" | |||
@@ -61,8 +61,8 @@ def device_placement(framework, device_name='gpu:0'): | |||
if framework == Frameworks.tf: | |||
import tensorflow as tf | |||
if device_type == Devices.gpu and not tf.test.is_gpu_available(): | |||
logger.warning( | |||
'tensorflow cuda is not available, using cpu instead.') | |||
logger.debug( | |||
'tensorflow: cuda is not available, using cpu instead.') | |||
device_type = Devices.cpu | |||
if device_type == Devices.cpu: | |||
with tf.device('/CPU:0'): | |||
@@ -78,7 +78,8 @@ def device_placement(framework, device_name='gpu:0'): | |||
if torch.cuda.is_available(): | |||
torch.cuda.set_device(f'cuda:{device_id}') | |||
else: | |||
logger.warning('cuda is not available, using cpu instead.') | |||
logger.debug( | |||
'pytorch: cuda is not available, using cpu instead.') | |||
yield | |||
else: | |||
yield | |||
@@ -96,9 +97,7 @@ def create_device(device_name): | |||
if device_type == Devices.gpu: | |||
use_cuda = True | |||
if not torch.cuda.is_available(): | |||
logger.warning( | |||
'cuda is not available, create gpu device failed, using cpu instead.' | |||
) | |||
logger.info('cuda is not available, using cpu instead.') | |||
use_cuda = False | |||
if use_cuda: | |||
@@ -1,6 +1,7 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import inspect | |||
import os | |||
from pathlib import Path | |||
@@ -35,3 +36,10 @@ def get_default_cache_dir(): | |||
""" | |||
default_cache_dir = Path.home().joinpath('.cache', 'modelscope') | |||
return default_cache_dir | |||
def read_file(path):
    """Return the full contents of the text file at ``path``.

    The file is opened in text read mode with the default encoding and is
    closed before the function returns.
    """
    with open(path, 'r') as handle:
        return handle.read()
@@ -176,7 +176,7 @@ def build_from_cfg(cfg, | |||
raise TypeError('default_args must be a dict or None, ' | |||
f'but got {type(default_args)}') | |||
# dynamic load installation reqruiements for this module | |||
# dynamic load installation requirements for this module | |||
from modelscope.utils.import_utils import LazyImportModule | |||
sig = (registry.name.upper(), group_key, cfg['type']) | |||
LazyImportModule.import_module(sig) | |||
@@ -193,8 +193,11 @@ def build_from_cfg(cfg, | |||
if isinstance(obj_type, str): | |||
obj_cls = registry.get(obj_type, group_key=group_key) | |||
if obj_cls is None: | |||
raise KeyError(f'{obj_type} is not in the {registry.name}' | |||
f' registry group {group_key}') | |||
raise KeyError( | |||
f'{obj_type} is not in the {registry.name}' | |||
f' registry group {group_key}. Please make' | |||
f' sure the correct version of ModelScope library is used.' | |||
) | |||
obj_cls.group_key = group_key | |||
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | |||
obj_cls = obj_type | |||
@@ -65,7 +65,8 @@ class RegressTool: | |||
def monitor_module_single_forward(self, | |||
module: nn.Module, | |||
file_name: str, | |||
compare_fn=None): | |||
compare_fn=None, | |||
**kwargs): | |||
"""Monitor a pytorch module in a single forward. | |||
@param module: A torch module | |||
@@ -107,7 +108,7 @@ class RegressTool: | |||
baseline = os.path.join(tempfile.gettempdir(), name) | |||
self.load(baseline, name) | |||
with open(baseline, 'rb') as f: | |||
baseline_json = pickle.load(f) | |||
base = pickle.load(f) | |||
class NumpyEncoder(json.JSONEncoder): | |||
"""Special json encoder for numpy types | |||
@@ -122,9 +123,9 @@ class RegressTool: | |||
return obj.tolist() | |||
return json.JSONEncoder.default(self, obj) | |||
print(f'baseline: {json.dumps(baseline_json, cls=NumpyEncoder)}') | |||
print(f'baseline: {json.dumps(base, cls=NumpyEncoder)}') | |||
print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}') | |||
if not compare_io_and_print(baseline_json, io_json, compare_fn): | |||
if not compare_io_and_print(base, io_json, compare_fn, **kwargs): | |||
raise ValueError('Result not match!') | |||
@contextlib.contextmanager | |||
@@ -136,7 +137,8 @@ class RegressTool: | |||
ignore_keys=None, | |||
compare_random=True, | |||
reset_dropout=True, | |||
lazy_stop_callback=None): | |||
lazy_stop_callback=None, | |||
**kwargs): | |||
"""Monitor a pytorch module's backward data and cfg data within a step of the optimizer. | |||
This is usually useful when you try to change some dangerous code | |||
@@ -265,14 +267,15 @@ class RegressTool: | |||
baseline_json = pickle.load(f) | |||
if level == 'strict' and not compare_io_and_print( | |||
baseline_json['forward'], io_json, compare_fn): | |||
baseline_json['forward'], io_json, compare_fn, **kwargs): | |||
raise RuntimeError('Forward not match!') | |||
if not compare_backward_and_print( | |||
baseline_json['backward'], | |||
bw_json, | |||
compare_fn=compare_fn, | |||
ignore_keys=ignore_keys, | |||
level=level): | |||
level=level, | |||
**kwargs): | |||
raise RuntimeError('Backward not match!') | |||
cfg_opt1 = { | |||
'optimizer': baseline_json['optimizer'], | |||
@@ -286,7 +289,8 @@ class RegressTool: | |||
'cfg': summary['cfg'], | |||
'state': None if not compare_random else summary['state'] | |||
} | |||
if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn): | |||
if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn, | |||
**kwargs): | |||
raise RuntimeError('Cfg or optimizers not match!') | |||
@@ -303,7 +307,8 @@ class MsRegressTool(RegressTool): | |||
compare_fn=None, | |||
ignore_keys=None, | |||
compare_random=True, | |||
lazy_stop_callback=None): | |||
lazy_stop_callback=None, | |||
**kwargs): | |||
if lazy_stop_callback is None: | |||
@@ -319,7 +324,7 @@ class MsRegressTool(RegressTool): | |||
trainer.register_hook(EarlyStopHook()) | |||
def _train_loop(trainer, *args, **kwargs): | |||
def _train_loop(trainer, *args_train, **kwargs_train): | |||
with self.monitor_module_train( | |||
trainer, | |||
file_name, | |||
@@ -327,9 +332,11 @@ class MsRegressTool(RegressTool): | |||
compare_fn=compare_fn, | |||
ignore_keys=ignore_keys, | |||
compare_random=compare_random, | |||
lazy_stop_callback=lazy_stop_callback): | |||
lazy_stop_callback=lazy_stop_callback, | |||
**kwargs): | |||
try: | |||
return trainer.train_loop_origin(*args, **kwargs) | |||
return trainer.train_loop_origin(*args_train, | |||
**kwargs_train) | |||
except MsRegressTool.EarlyStopError: | |||
pass | |||
@@ -530,7 +537,8 @@ def compare_arguments_nested(print_content, | |||
) | |||
return False | |||
if not all([ | |||
compare_arguments_nested(None, sub_arg1, sub_arg2) | |||
compare_arguments_nested( | |||
None, sub_arg1, sub_arg2, rtol=rtol, atol=atol) | |||
for sub_arg1, sub_arg2 in zip(arg1, arg2) | |||
]): | |||
if print_content is not None: | |||
@@ -551,7 +559,8 @@ def compare_arguments_nested(print_content, | |||
print(f'{print_content}, key diff:{set(keys1) - set(keys2)}') | |||
return False | |||
if not all([ | |||
compare_arguments_nested(None, arg1[key], arg2[key]) | |||
compare_arguments_nested( | |||
None, arg1[key], arg2[key], rtol=rtol, atol=atol) | |||
for key in keys1 | |||
]): | |||
if print_content is not None: | |||
@@ -574,7 +583,7 @@ def compare_arguments_nested(print_content, | |||
raise ValueError(f'type not supported: {type1}') | |||
def compare_io_and_print(baseline_json, io_json, compare_fn=None): | |||
def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs): | |||
if compare_fn is None: | |||
def compare_fn(*args, **kwargs): | |||
@@ -602,10 +611,10 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): | |||
else: | |||
match = compare_arguments_nested( | |||
f'unmatched module {key} input args', v1input['args'], | |||
v2input['args']) and match | |||
v2input['args'], **kwargs) and match | |||
match = compare_arguments_nested( | |||
f'unmatched module {key} input kwargs', v1input['kwargs'], | |||
v2input['kwargs']) and match | |||
v2input['kwargs'], **kwargs) and match | |||
v1output = numpify_tensor_nested(v1['output']) | |||
v2output = numpify_tensor_nested(v2['output']) | |||
res = compare_fn(v1output, v2output, key, 'output') | |||
@@ -615,8 +624,11 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): | |||
) | |||
match = match and res | |||
else: | |||
match = compare_arguments_nested(f'unmatched module {key} outputs', | |||
v1output, v2output) and match | |||
match = compare_arguments_nested( | |||
f'unmatched module {key} outputs', | |||
arg1=v1output, | |||
arg2=v2output, | |||
**kwargs) and match | |||
return match | |||
@@ -624,7 +636,8 @@ def compare_backward_and_print(baseline_json, | |||
bw_json, | |||
level, | |||
ignore_keys=None, | |||
compare_fn=None): | |||
compare_fn=None, | |||
**kwargs): | |||
if compare_fn is None: | |||
def compare_fn(*args, **kwargs): | |||
@@ -653,18 +666,26 @@ def compare_backward_and_print(baseline_json, | |||
data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][ | |||
'grad'], bw_json[key]['data_after'] | |||
match = compare_arguments_nested( | |||
f'unmatched module {key} tensor data', data1, data2) and match | |||
f'unmatched module {key} tensor data', | |||
arg1=data1, | |||
arg2=data2, | |||
**kwargs) and match | |||
if level == 'strict': | |||
match = compare_arguments_nested( | |||
f'unmatched module {key} grad data', grad1, | |||
grad2) and match | |||
f'unmatched module {key} grad data', | |||
arg1=grad1, | |||
arg2=grad2, | |||
**kwargs) and match | |||
match = compare_arguments_nested( | |||
f'unmatched module {key} data after step', data_after1, | |||
data_after2) and match | |||
data_after2, **kwargs) and match | |||
return match | |||
def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||
def compare_cfg_and_optimizers(baseline_json, | |||
cfg_json, | |||
compare_fn=None, | |||
**kwargs): | |||
if compare_fn is None: | |||
def compare_fn(*args, **kwargs): | |||
@@ -686,12 +707,12 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||
print( | |||
f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}" | |||
) | |||
match = compare_arguments_nested('unmatched optimizer defaults', | |||
optimizer1['defaults'], | |||
optimizer2['defaults']) and match | |||
match = compare_arguments_nested('unmatched optimizer state_dict', | |||
optimizer1['state_dict'], | |||
optimizer2['state_dict']) and match | |||
match = compare_arguments_nested( | |||
'unmatched optimizer defaults', optimizer1['defaults'], | |||
optimizer2['defaults'], **kwargs) and match | |||
match = compare_arguments_nested( | |||
'unmatched optimizer state_dict', optimizer1['state_dict'], | |||
optimizer2['state_dict'], **kwargs) and match | |||
res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler') | |||
if res is not None: | |||
@@ -703,16 +724,17 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||
print( | |||
f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}" | |||
) | |||
match = compare_arguments_nested('unmatched lr_scheduler state_dict', | |||
lr_scheduler1['state_dict'], | |||
lr_scheduler2['state_dict']) and match | |||
match = compare_arguments_nested( | |||
'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'], | |||
lr_scheduler2['state_dict'], **kwargs) and match | |||
res = compare_fn(cfg1, cfg2, None, 'cfg') | |||
if res is not None: | |||
print(f'cfg compared with user compare_fn with result:{res}\n') | |||
match = match and res | |||
else: | |||
match = compare_arguments_nested('unmatched cfg', cfg1, cfg2) and match | |||
match = compare_arguments_nested( | |||
'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match | |||
res = compare_fn(state1, state2, None, 'state') | |||
if res is not None: | |||
@@ -721,6 +743,6 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||
match = match and res | |||
else: | |||
match = compare_arguments_nested('unmatched random state', state1, | |||
state2) and match | |||
state2, **kwargs) and match | |||
return match |
@@ -19,7 +19,7 @@ moviepy>=1.0.3 | |||
networkx>=2.5 | |||
numba | |||
onnxruntime>=1.10 | |||
pai-easycv>=0.6.3.7 | |||
pai-easycv>=0.6.3.9 | |||
pandas | |||
psutil | |||
regex | |||
@@ -127,7 +127,7 @@ class HubOperationTest(unittest.TestCase): | |||
return None | |||
def test_list_model(self): | |||
data = self.api.list_model(TEST_MODEL_ORG) | |||
data = self.api.list_models(TEST_MODEL_ORG) | |||
assert len(data['Models']) >= 1 | |||
@@ -7,12 +7,12 @@ import uuid | |||
from modelscope.hub.api import HubApi | |||
from modelscope.hub.constants import Licenses, ModelVisibility | |||
from modelscope.hub.errors import HTTPError, NotLoginException | |||
from modelscope.hub.repository import Repository | |||
from modelscope.hub.upload import upload_folder | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.test_utils import test_level | |||
from .test_utils import TEST_ACCESS_TOKEN1, delete_credential | |||
from .test_utils import TEST_ACCESS_TOKEN1, TEST_MODEL_ORG, delete_credential | |||
logger = get_logger() | |||
@@ -22,7 +22,7 @@ class HubUploadTest(unittest.TestCase): | |||
def setUp(self): | |||
logger.info('SetUp') | |||
self.api = HubApi() | |||
self.user = os.environ.get('TEST_MODEL_ORG', 'citest') | |||
self.user = TEST_MODEL_ORG | |||
logger.info(self.user) | |||
self.create_model_name = '%s/%s_%s' % (self.user, 'test_model_upload', | |||
uuid.uuid4().hex) | |||
@@ -39,7 +39,10 @@ class HubUploadTest(unittest.TestCase): | |||
def tearDown(self): | |||
logger.info('TearDown') | |||
shutil.rmtree(self.model_dir, ignore_errors=True) | |||
self.api.delete_model(model_id=self.create_model_name) | |||
try: | |||
self.api.delete_model(model_id=self.create_model_name) | |||
except Exception: | |||
pass | |||
def test_upload_exits_repo_master(self): | |||
logger.info('basic test for upload!') | |||
@@ -50,14 +53,14 @@ class HubUploadTest(unittest.TestCase): | |||
license=Licenses.APACHE_V2) | |||
os.system("echo '111'>%s" | |||
% os.path.join(self.finetune_path, 'add1.py')) | |||
upload_folder( | |||
self.api.push_model( | |||
model_id=self.create_model_name, model_dir=self.finetune_path) | |||
Repository(model_dir=self.repo_path, clone_from=self.create_model_name) | |||
assert os.path.exists(os.path.join(self.repo_path, 'add1.py')) | |||
shutil.rmtree(self.repo_path, ignore_errors=True) | |||
os.system("echo '222'>%s" | |||
% os.path.join(self.finetune_path, 'add2.py')) | |||
upload_folder( | |||
self.api.push_model( | |||
model_id=self.create_model_name, | |||
model_dir=self.finetune_path, | |||
revision='new_revision/version1') | |||
@@ -69,7 +72,7 @@ class HubUploadTest(unittest.TestCase): | |||
shutil.rmtree(self.repo_path, ignore_errors=True) | |||
os.system("echo '333'>%s" | |||
% os.path.join(self.finetune_path, 'add3.py')) | |||
upload_folder( | |||
self.api.push_model( | |||
model_id=self.create_model_name, | |||
model_dir=self.finetune_path, | |||
revision='new_revision/version2', | |||
@@ -84,7 +87,7 @@ class HubUploadTest(unittest.TestCase): | |||
add4_path = os.path.join(self.finetune_path, 'temp') | |||
os.mkdir(add4_path) | |||
os.system("echo '444'>%s" % os.path.join(add4_path, 'add4.py')) | |||
upload_folder( | |||
self.api.push_model( | |||
model_id=self.create_model_name, | |||
model_dir=self.finetune_path, | |||
revision='new_revision/version1') | |||
@@ -101,7 +104,7 @@ class HubUploadTest(unittest.TestCase): | |||
self.api.login(TEST_ACCESS_TOKEN1) | |||
os.system("echo '111'>%s" | |||
% os.path.join(self.finetune_path, 'add1.py')) | |||
upload_folder( | |||
self.api.push_model( | |||
model_id=self.create_model_name, | |||
model_dir=self.finetune_path, | |||
revision='new_model_new_revision', | |||
@@ -119,48 +122,23 @@ class HubUploadTest(unittest.TestCase): | |||
logger.info('test upload without login!') | |||
self.api.login(TEST_ACCESS_TOKEN1) | |||
delete_credential() | |||
try: | |||
upload_folder( | |||
model_id=self.create_model_name, | |||
model_dir=self.finetune_path, | |||
visibility=ModelVisibility.PUBLIC, | |||
license=Licenses.APACHE_V2) | |||
except Exception as e: | |||
logger.info(e) | |||
self.api.login(TEST_ACCESS_TOKEN1) | |||
upload_folder( | |||
with self.assertRaises(NotLoginException): | |||
self.api.push_model( | |||
model_id=self.create_model_name, | |||
model_dir=self.finetune_path, | |||
visibility=ModelVisibility.PUBLIC, | |||
license=Licenses.APACHE_V2) | |||
Repository( | |||
model_dir=self.repo_path, clone_from=self.create_model_name) | |||
assert os.path.exists( | |||
os.path.join(self.repo_path, 'configuration.json')) | |||
shutil.rmtree(self.repo_path, ignore_errors=True) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_upload_invalid_repo(self): | |||
logger.info('test upload to invalid repo!') | |||
self.api.login(TEST_ACCESS_TOKEN1) | |||
try: | |||
upload_folder( | |||
with self.assertRaises(HTTPError): | |||
self.api.push_model( | |||
model_id='%s/%s' % ('speech_tts', 'invalid_model_test'), | |||
model_dir=self.finetune_path, | |||
visibility=ModelVisibility.PUBLIC, | |||
license=Licenses.APACHE_V2) | |||
except Exception as e: | |||
logger.info(e) | |||
upload_folder( | |||
model_id=self.create_model_name, | |||
model_dir=self.finetune_path, | |||
visibility=ModelVisibility.PUBLIC, | |||
license=Licenses.APACHE_V2) | |||
Repository( | |||
model_dir=self.repo_path, clone_from=self.create_model_name) | |||
assert os.path.exists( | |||
os.path.join(self.repo_path, 'configuration.json')) | |||
shutil.rmtree(self.repo_path, ignore_errors=True) | |||
if __name__ == '__main__': | |||
@@ -52,7 +52,8 @@ class MsDatasetTest(unittest.TestCase): | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_ms_csv_basic(self): | |||
ms_ds_train = MsDataset.load( | |||
'afqmc_small', namespace='userxiaoming', split='train') | |||
'clue', subset_name='afqmc', | |||
split='train').to_hf_dataset().select(range(5)) | |||
print(next(iter(ms_ds_train))) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
@@ -45,6 +45,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'wav_example' | |||
}, | |||
'test_run_with_url_pytorch': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'wav_example' | |||
}, | |||
'test_run_with_url_tf': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'wav_example' | |||
@@ -74,6 +78,170 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||
} | |||
} | |||
all_models_info = [ | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', | |||
'wav_path': 'data/test/audios/asr_example.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': 'speech_paraformer_asr_nat-aishell1-pytorch', | |||
'wav_path': 'data/test/audios/asr_example.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', | |||
'wav_path': 'data/test/audios/asr_example.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1', | |||
'wav_path': 'data/test/audios/asr_example_8K.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', | |||
'wav_path': 'data/test/audios/asr_example.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example_8K.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline', | |||
'wav_path': 'data/test/audios/asr_example_8K.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', | |||
'wav_path': 'data/test/audios/asr_example.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example_cn_en.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline', | |||
'wav_path': 'data/test/audios/asr_example_cn_en.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example_cn_dialect.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline', | |||
'wav_path': 'data/test/audios/asr_example_cn_dialect.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example_8K.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline', | |||
'wav_path': 'data/test/audios/asr_example_en.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example_en.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline', | |||
'wav_path': 'data/test/audios/asr_example_ru.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example_ru.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline', | |||
'wav_path': 'data/test/audios/asr_example_es.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example_es.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline', | |||
'wav_path': 'data/test/audios/asr_example_ko.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example_ko.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example_ja.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline', | |||
'wav_path': 'data/test/audios/asr_example_ja.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online', | |||
'wav_path': 'data/test/audios/asr_example_id.wav' | |||
}, | |||
{ | |||
'model_group': 'damo', | |||
'model_id': | |||
'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline', | |||
'wav_path': 'data/test/audios/asr_example_id.wav' | |||
}, | |||
] | |||
def setUp(self) -> None: | |||
self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' | |||
self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1' | |||
@@ -90,7 +258,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||
def run_pipeline(self, | |||
model_id: str, | |||
audio_in: Union[str, bytes], | |||
sr: int = 16000) -> Dict[str, Any]: | |||
sr: int = None) -> Dict[str, Any]: | |||
inference_16k_pipline = pipeline( | |||
task=Tasks.auto_speech_recognition, model=model_id) | |||
@@ -136,33 +304,26 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||
return audio, fs | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_wav_pytorch(self): | |||
"""run with single waveform file | |||
def test_run_with_pcm(self): | |||
"""run with wav data | |||
""" | |||
logger.info('Run ASR test with waveform file (pytorch)...') | |||
logger.info('Run ASR test with wav data (tensorflow)...') | |||
wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | |||
audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_pytorch_model_id, audio_in=wav_file_path) | |||
self.check_result('test_run_with_wav_pytorch', rec_result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_pcm_pytorch(self): | |||
"""run with wav data | |||
""" | |||
model_id=self.am_tf_model_id, audio_in=audio, sr=sr) | |||
self.check_result('test_run_with_pcm_tf', rec_result) | |||
logger.info('Run ASR test with wav data (pytorch)...') | |||
audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr) | |||
self.check_result('test_run_with_pcm_pytorch', rec_result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_wav_tf(self): | |||
def test_run_with_wav(self): | |||
"""run with single waveform file | |||
""" | |||
@@ -174,21 +335,14 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||
model_id=self.am_tf_model_id, audio_in=wav_file_path) | |||
self.check_result('test_run_with_wav_tf', rec_result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_pcm_tf(self): | |||
"""run with wav data | |||
""" | |||
logger.info('Run ASR test with wav data (tensorflow)...') | |||
audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||
logger.info('Run ASR test with waveform file (pytorch)...') | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_tf_model_id, audio_in=audio, sr=sr) | |||
self.check_result('test_run_with_pcm_tf', rec_result) | |||
model_id=self.am_pytorch_model_id, audio_in=wav_file_path) | |||
self.check_result('test_run_with_wav_pytorch', rec_result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_url_tf(self): | |||
def test_run_with_url(self): | |||
"""run with single url file | |||
""" | |||
@@ -198,6 +352,12 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||
model_id=self.am_tf_model_id, audio_in=URL_FILE) | |||
self.check_result('test_run_with_url_tf', rec_result) | |||
logger.info('Run ASR test with url file (pytorch)...') | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_pytorch_model_id, audio_in=URL_FILE) | |||
self.check_result('test_run_with_url_pytorch', rec_result) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_wav_dataset_pytorch(self): | |||
"""run with datasets, and audio format is waveform | |||
@@ -217,7 +377,6 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||
data.text # hypothesis text | |||
""" | |||
logger.info('Run ASR test with waveform dataset (pytorch)...') | |||
logger.info('Downloading waveform testsets file ...') | |||
dataset_path = download_and_untar( | |||
@@ -225,40 +384,38 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||
LITTLE_TESTSETS_URL, self.workspace) | |||
dataset_path = os.path.join(dataset_path, 'wav', 'test') | |||
logger.info('Run ASR test with waveform dataset (tensorflow)...') | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_tf_model_id, audio_in=dataset_path) | |||
self.check_result('test_run_with_wav_dataset_tf', rec_result) | |||
logger.info('Run ASR test with waveform dataset (pytorch)...') | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_pytorch_model_id, audio_in=dataset_path) | |||
self.check_result('test_run_with_wav_dataset_pytorch', rec_result) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_wav_dataset_tf(self): | |||
"""run with datasets, and audio format is waveform | |||
datasets directory: | |||
<dataset_path> | |||
wav | |||
test # testsets | |||
xx.wav | |||
... | |||
dev # devsets | |||
yy.wav | |||
... | |||
train # trainsets | |||
zz.wav | |||
... | |||
transcript | |||
data.text # hypothesis text | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_all_models(self): | |||
"""run with all models | |||
""" | |||
logger.info('Run ASR test with waveform dataset (tensorflow)...') | |||
logger.info('Downloading waveform testsets file ...') | |||
dataset_path = download_and_untar( | |||
os.path.join(self.workspace, LITTLE_TESTSETS_FILE), | |||
LITTLE_TESTSETS_URL, self.workspace) | |||
dataset_path = os.path.join(dataset_path, 'wav', 'test') | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_tf_model_id, audio_in=dataset_path) | |||
self.check_result('test_run_with_wav_dataset_tf', rec_result) | |||
logger.info('Run ASR test with all models') | |||
for item in self.all_models_info: | |||
model_id = item['model_group'] + '/' + item['model_id'] | |||
wav_path = item['wav_path'] | |||
rec_result = self.run_pipeline( | |||
model_id=model_id, audio_in=wav_path) | |||
if rec_result.__contains__(OutputKeys.TEXT): | |||
logger.info(ColorCodes.MAGENTA + str(item['model_id']) + ' ' | |||
+ ColorCodes.YELLOW | |||
+ str(rec_result[OutputKeys.TEXT]) | |||
+ ColorCodes.END) | |||
else: | |||
logger.info(ColorCodes.MAGENTA + str(rec_result) | |||
+ ColorCodes.END) | |||
@unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
def test_demo_compatibility(self): | |||
@@ -26,6 +26,20 @@ class TranslationTest(unittest.TestCase, DemoCompatibilityCheck): | |||
pipeline_ins = pipeline(self.task, model=model_id) | |||
print(pipeline_ins(input=inputs)) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name_for_en2fr(self):
    """Run the ``damo/nlp_csanmt_translation_en2fr`` pipeline and print its output."""
    translator = pipeline(
        self.task, model='damo/nlp_csanmt_translation_en2fr')
    source_text = 'When I was in my 20s, I saw my very first psychotherapy client.'
    print(translator(input=source_text))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name_for_fr2en(self):
    """Run the ``damo/nlp_csanmt_translation_fr2en`` pipeline and print its output."""
    translator = pipeline(
        self.task, model='damo/nlp_csanmt_translation_fr2en')
    source_text = "Quand j'avais la vingtaine, j'ai vu mes tout premiers clients comme psychothérapeute."
    print(translator(input=source_text))
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_default_model(self): | |||
inputs = '声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。' | |||
@@ -4,22 +4,45 @@ import unittest | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
from modelscope.utils.test_utils import test_level | |||
class TinynasObjectDetectionTest(unittest.TestCase): | |||
class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
def setUp(self) -> None: | |||
self.task = Tasks.image_object_detection | |||
self.model_id = 'damo/cv_tinynas_object-detection_damoyolo' | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run(self): | |||
def test_run_airdet(self): | |||
tinynas_object_detection = pipeline( | |||
Tasks.image_object_detection, model='damo/cv_tinynas_detection') | |||
result = tinynas_object_detection( | |||
'data/test/images/image_detection.jpg') | |||
print(result) | |||
@unittest.skip('will be enabled after damoyolo officially released')
def test_run_damoyolo(self):
    """Run the DAMO-YOLO detection pipeline on the sample image and print the result."""
    detector = pipeline(
        Tasks.image_object_detection,
        model='damo/cv_tinynas_object-detection_damoyolo')
    print(detector('data/test/images/image_detection.jpg'))
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
    """Run the demo test, then the compatibility check (skipped by default)."""
    self.test_demo()
    self.compatibility_check()
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_image_object_detection_auto_pipeline(self):
    """Detect on the sample image and hand the result to ``show_result``."""
    image_path = 'data/test/images/image_detection.jpg'
    detector = pipeline(
        Tasks.image_object_detection, model='damo/cv_tinynas_detection')
    detections = detector(image_path)
    # show_result receives the input image, the detections and the output name.
    detector.show_result(image_path, detections, 'demo_ret.jpg')
if __name__ == '__main__': | |||
@@ -0,0 +1,71 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import glob | |||
import os | |||
import shutil | |||
import tempfile | |||
import unittest | |||
import torch | |||
from modelscope.metainfo import Trainers | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.constant import DownloadMode, LogKeys, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.test_utils import test_level | |||
@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase):
    """Training smoke test for the face 2D keypoints model via the EasyCV trainer.

    The whole class is skipped when CUDA is not available.
    """

    model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment'

    def setUp(self):
        self.logger = get_logger()
        self.logger.info('Testing %s.%s' %
                         (type(self).__name__, self._testMethodName))

    def _train(self, tmp_dir):
        """Build an EasyCV trainer for ``model_id`` and train it for two epochs.

        Args:
            tmp_dir: Directory passed to the trainer as ``work_dir``.
        """
        cfg_options = {'train.max_epochs': 2}

        trainer_name = Trainers.easycv
        train_dataset = MsDataset.load(
            dataset_name='face_2d_keypoints_dataset',
            namespace='modelscope',
            split='train',
            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
        # NOTE(review): evaluation also loads the 'train' split -- confirm
        # this is intentional (e.g. no separate validation split published).
        eval_dataset = MsDataset.load(
            dataset_name='face_2d_keypoints_dataset',
            namespace='modelscope',
            split='train',
            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
        kwargs = dict(
            model=self.model_id,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            work_dir=tmp_dir,
            cfg_options=cfg_options)

        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_single_gpu(self):
        """Train in a scratch directory and check the expected artifacts exist."""
        # The context manager creates the directory and removes it even when
        # training or an assertion fails, so no explicit makedirs/cleanup is
        # needed.
        with tempfile.TemporaryDirectory() as tmp_dir:
            self._train(tmp_dir)

            results_files = os.listdir(tmp_dir)
            json_files = glob.glob(os.path.join(tmp_dir, '*.log.json'))
            self.assertEqual(len(json_files), 1)
            self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
if __name__ == '__main__': | |||
unittest.main() |
@@ -16,7 +16,8 @@ from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \ | |||
calculate_fisher | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.data_utils import to_device | |||
from modelscope.utils.regress_test_utils import MsRegressTool | |||
from modelscope.utils.regress_test_utils import (MsRegressTool, | |||
compare_arguments_nested) | |||
from modelscope.utils.test_utils import test_level | |||
@@ -41,6 +42,33 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
def test_trainer_repeatable(self): | |||
import torch # noqa | |||
def compare_fn(value1, value2, key, type):
    """Compare two optimizer snapshots on the fields they have in common.

    Only keys present in both snapshots are compared, so options that exist
    in just one of them (e.g. added by a different torch version) are ignored.

    Args:
        value1: First snapshot; read as a mapping with 'type', 'defaults'
            and 'state_dict' -> 'param_groups'.
        value2: Second snapshot with the same layout.
        key: Name of the compared entry (not used by this comparison).
        type: Kind of the compared entry; only 'optimizer' is handled here.

    Returns:
        None when ``type`` is not 'optimizer' (presumably letting the caller
        fall back to its default comparison -- confirm against the regression
        tool), otherwise a bool telling whether the snapshots match.
    """
    # Ignore the differences between optimizers of two torch versions
    if type != 'optimizer':
        return None

    def _shared_items_match(label, left, right):
        # Compare every key both mappings share. The list is built in full
        # before all() so that each mismatching key is still evaluated (and
        # can be reported by compare_arguments_nested) instead of stopping
        # at the first failure.
        shared = set(left.keys()).intersection(set(right.keys()))
        outcomes = [
            compare_arguments_nested(f'{label} {name} not match', left[name],
                                     right[name]) for name in shared
        ]
        return all(outcomes)

    match = (value1['type'] == value2['type'])
    match = _shared_items_match('Optimizer defaults', value1['defaults'],
                                value2['defaults']) and match

    groups1 = value1['state_dict']['param_groups']
    groups2 = value2['state_dict']['param_groups']
    # zip() below silently truncates, so a length mismatch is checked first.
    match = (len(groups1) == len(groups2)) and match
    for group1, group2 in zip(groups1, groups2):
        match = _shared_items_match('Optimizer param_groups', group1,
                                    group2) and match

    return match
def cfg_modify_fn(cfg): | |||
cfg.task = 'nli' | |||
cfg['preprocessor'] = {'type': 'nli-tokenizer'} | |||
@@ -98,7 +126,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
name=Trainers.nlp_base_trainer, default_args=kwargs) | |||
with self.regress_tool.monitor_ms_train( | |||
trainer, 'sbert-base-tnews', level='strict'): | |||
trainer, 'sbert-base-tnews', level='strict', | |||
compare_fn=compare_fn): | |||
trainer.train() | |||
def finetune(self, | |||
@@ -51,7 +51,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase): | |||
shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
super().tearDown() | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_trainer(self): | |||
kwargs = dict( | |||
model=self.model_id, | |||
@@ -65,7 +65,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase): | |||
for i in range(2): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_trainer_with_model_and_args(self): | |||
model = NAFNetForImageDenoise.from_pretrained(self.cache_path) | |||
kwargs = dict( | |||
@@ -29,7 +29,8 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
os.makedirs(self.tmp_dir) | |||
self.dataset = MsDataset.load( | |||
'afqmc_small', namespace='userxiaoming', split='train') | |||
'clue', subset_name='afqmc', | |||
split='train').to_hf_dataset().select(range(2)) | |||
def tearDown(self): | |||
shutil.rmtree(self.tmp_dir) | |||
@@ -73,7 +74,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | |||
pipeline_sentence_similarity(output_dir) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 3, 'skip test in current test level') | |||
def test_trainer_with_backbone_head(self): | |||
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | |||
kwargs = dict( | |||
@@ -99,6 +100,8 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | |||
cfg = read_config(model_id, revision='beta') | |||
cfg.train.max_epochs = 20 | |||
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} | |||
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} | |||
cfg.train.work_dir = self.tmp_dir | |||
cfg_file = os.path.join(self.tmp_dir, 'config.json') | |||
cfg.dump(cfg_file) | |||
@@ -120,22 +123,24 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) | |||
self.assertTrue(Metrics.accuracy in eval_results) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_trainer_with_configured_datasets(self): | |||
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||
cfg: Config = read_config(model_id) | |||
cfg.train.max_epochs = 20 | |||
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} | |||
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} | |||
cfg.train.work_dir = self.tmp_dir | |||
cfg.dataset = { | |||
'train': { | |||
'name': 'afqmc_small', | |||
'name': 'clue', | |||
'subset_name': 'afqmc', | |||
'split': 'train', | |||
'namespace': 'userxiaoming' | |||
}, | |||
'val': { | |||
'name': 'afqmc_small', | |||
'name': 'clue', | |||
'subset_name': 'afqmc', | |||
'split': 'train', | |||
'namespace': 'userxiaoming' | |||
}, | |||
} | |||
cfg_file = os.path.join(self.tmp_dir, 'config.json') | |||
@@ -159,6 +164,11 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||
cfg: Config = read_config(model_id) | |||
cfg.train.max_epochs = 3 | |||
cfg.preprocessor.first_sequence = 'sentence1' | |||
cfg.preprocessor.second_sequence = 'sentence2' | |||
cfg.preprocessor.label = 'label' | |||
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} | |||
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} | |||
cfg.train.work_dir = self.tmp_dir | |||
cfg_file = os.path.join(self.tmp_dir, 'config.json') | |||
cfg.dump(cfg_file) | |||
@@ -0,0 +1,19 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
class CompatibilityTest(unittest.TestCase):
    """Checks that third-party packages the project relies on can be imported."""

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def test_xtcocotools(self):
        """``xtcocotools.coco`` must import and expose a callable ``COCO``."""
        # The import itself is the compatibility check; asserting on the name
        # keeps it in use and gives the test an explicit pass condition.
        from xtcocotools.coco import COCO
        self.assertTrue(callable(COCO))
if __name__ == '__main__': | |||
unittest.main() |