
[to #42966122] requirements enhancement and self-hosted repo support

* add self-hosted repo
* add extra requirements for different fields and reduce the mandatory requirements
* update dockerfile with the shared library (libsndfile1) required by audio
* add requirements checker, which will be used later when implementing lazy import
* remove repeated requirements and replace opencv-python-headless with opencv-python

example usage:
```shell
pip install model_scope[all] -f https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/repo.html
pip install model_scope[cv] -f https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/repo.html
pip install model_scope[nlp] -f https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/repo.html
pip install model_scope[audio] -f https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/repo.html
pip install model_scope[multi-modal] -f https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/repo.html

```
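
Each field-specific `__init__.py` now tolerates missing optional dependencies, so a base install without the extras still imports cleanly. A condensed sketch of the guard pattern used throughout this change, as it appears inside `modelscope/pipelines/cv/__init__.py` (single module shown for illustration):

```python
# Optional-import guard added to the field-level __init__.py files.
try:
    from .image_matting_pipeline import ImageMattingPipeline  # needs tensorflow
except ModuleNotFoundError as e:
    if str(e) == "No module named 'tensorflow'":
        pass  # the field's extra requirements are not installed; skip the import
    else:
        raise ModuleNotFoundError(e)
```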
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9211383
master
wenmeng.zwm 3 years ago
parent commit 8e51a073a6
23 changed files with 580 additions and 151 deletions
  1. +7 -0 .dev_scripts/run_docker.sh
  2. +2 -1 docker/pytorch.dockerfile
  3. +3 -0 docs/source/index.rst
  4. +23 -8 modelscope/models/__init__.py
  5. +16 -3 modelscope/pipelines/audio/__init__.py
  6. +18 -5 modelscope/pipelines/cv/__init__.py
  7. +9 -3 modelscope/pipelines/multi_modal/__init__.py
  8. +12 -6 modelscope/pipelines/nlp/__init__.py
  9. +10 -3 modelscope/preprocessors/__init__.py
  10. +79 -0 modelscope/utils/check_requirements.py
  11. +3 -2 modelscope/utils/config.py
  12. +13 -0 modelscope/utils/constant.py
  13. +324 -0 modelscope/utils/import_utils.py
  14. +0 -90 modelscope/utils/pymod.py
  15. +15 -5 modelscope/utils/registry.py
  16. +0 -5 requirements.txt
  17. +2 -5 requirements/audio.txt
  18. +2 -4 requirements/multi-modal.txt
  19. +1 -1 requirements/nlp.txt
  20. +0 -6 requirements/pipeline.txt
  21. +6 -4 requirements/runtime.txt
  22. +13 -0 setup.py
  23. +22 -0 tests/utils/test_check_requirements.py

+7 -0 .dev_scripts/run_docker.sh

@@ -0,0 +1,7 @@
#sudo docker run --name zwm_maas -v /home/wenmeng.zwm/workspace:/home/wenmeng.zwm/workspace --net host -ti reg.docker.alibaba-inc.com/pai-dlc/tensorflow-training:2.3-gpu-py36-cu101-ubuntu18.04 bash
#sudo docker run --name zwm_maas_pytorch -v /home/wenmeng.zwm/workspace:/home/wenmeng.zwm/workspace --net host -ti reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 bash
CONTAINER_NAME=modelscope-dev
IMAGE_NAME=registry.cn-shanghai.aliyuncs.com/modelscope/modelscope
IMAGE_VERSION=v0.1.1-16-g62856fa-devel
MOUNT_DIR=/home/wenmeng.zwm/workspace
sudo docker run --name $CONTAINER_NAME -v $MOUNT_DIR:$MOUNT_DIR --net host -ti ${IMAGE_NAME}:${IMAGE_VERSION} bash

+2 -1 docker/pytorch.dockerfile

@@ -30,7 +30,8 @@ RUN apt-get update &&\
     zip \
     zlib1g-dev \
     unzip \
-    pkg-config
+    pkg-config \
+    libsndfile1

 # install modelscope and its python env
 WORKDIR /opt/modelscope


+3 -0 docs/source/index.rst

@@ -13,6 +13,7 @@ ModelScope doc


    quick_start.md
    develop.md
+   faq.md

 .. toctree::
    :maxdepth: 2
@@ -20,6 +21,8 @@ ModelScope doc

    tutorials/index
+
+
 .. toctree::
    :maxdepth: 2
    :caption: Changelog


+23 -8 modelscope/models/__init__.py

@@ -1,11 +1,26 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

-from .audio.ans.frcrn import FRCRNModel
-from .audio.kws import GenericKeyWordSpotting
-from .audio.tts.am import SambertNetHifi16k
-from .audio.tts.vocoder import Hifigan16k
 from .base import Model
 from .builder import MODELS, build_model
-from .multi_modal import OfaForImageCaptioning
-from .nlp import (BertForSequenceClassification, SbertForSentenceSimilarity,
-                  SbertForZeroShotClassification)
+
+try:
+    from .audio.tts.am import SambertNetHifi16k
+    from .audio.tts.vocoder import Hifigan16k
+
+except ModuleNotFoundError as e:
+    if str(e) == "No module named 'tensorflow'":
+        pass
+    else:
+        raise ModuleNotFoundError(e)
+
+try:
+    from .audio.kws import GenericKeyWordSpotting
+    from .multi_modal import OfaForImageCaptioning
+    from .nlp import (BertForSequenceClassification,
+                      SbertForSentenceSimilarity,
+                      SbertForZeroShotClassification)
+    from .audio.ans.frcrn import FRCRNModel
+except ModuleNotFoundError as e:
+    if str(e) == "No module named 'pytorch'":
+        pass
+    else:
+        raise ModuleNotFoundError(e)

+16 -3 modelscope/pipelines/audio/__init__.py

@@ -1,3 +1,16 @@
-from .kws_kwsbp_pipeline import *  # noqa F403
-from .linear_aec_pipeline import LinearAECPipeline
-from .text_to_speech_pipeline import *  # noqa F403
+try:
+    from .kws_kwsbp_pipeline import *  # noqa F403
+    from .linear_aec_pipeline import LinearAECPipeline
+except ModuleNotFoundError as e:
+    if str(e) == "No module named 'torch'":
+        pass
+    else:
+        raise ModuleNotFoundError(e)
+
+try:
+    from .text_to_speech_pipeline import *  # noqa F403
+except ModuleNotFoundError as e:
+    if str(e) == "No module named 'tensorflow'":
+        pass
+    else:
+        raise ModuleNotFoundError(e)

+18 -5 modelscope/pipelines/cv/__init__.py

@@ -1,5 +1,18 @@
-from .action_recognition_pipeline import ActionRecognitionPipeline
-from .animal_recog_pipeline import AnimalRecogPipeline
-from .image_cartoon_pipeline import ImageCartoonPipeline
-from .image_matting_pipeline import ImageMattingPipeline
-from .ocr_detection_pipeline import OCRDetectionPipeline
+try:
+    from .action_recognition_pipeline import ActionRecognitionPipeline
+    from .animal_recog_pipeline import AnimalRecogPipeline
+except ModuleNotFoundError as e:
+    if str(e) == "No module named 'torch'":
+        pass
+    else:
+        raise ModuleNotFoundError(e)
+
+try:
+    from .image_cartoon_pipeline import ImageCartoonPipeline
+    from .image_matting_pipeline import ImageMattingPipeline
+    from .ocr_detection_pipeline import OCRDetectionPipeline
+except ModuleNotFoundError as e:
+    if str(e) == "No module named 'tensorflow'":
+        pass
+    else:
+        raise ModuleNotFoundError(e)

+9 -3 modelscope/pipelines/multi_modal/__init__.py

@@ -1,3 +1,9 @@
-from .image_captioning_pipeline import ImageCaptionPipeline
-from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
-from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline
+try:
+    from .image_captioning_pipeline import ImageCaptionPipeline
+    from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
+    from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline
+except ModuleNotFoundError as e:
+    if str(e) == "No module named 'torch'":
+        pass
+    else:
+        raise ModuleNotFoundError(e)

+12 -6 modelscope/pipelines/nlp/__init__.py

@@ -1,6 +1,12 @@
-from .fill_mask_pipeline import *  # noqa F403
-from .sentence_similarity_pipeline import *  # noqa F403
-from .sequence_classification_pipeline import *  # noqa F403
-from .text_generation_pipeline import *  # noqa F403
-from .word_segmentation_pipeline import *  # noqa F403
-from .zero_shot_classification_pipeline import *  # noqa F403
+try:
+    from .fill_mask_pipeline import *  # noqa F403
+    from .sentence_similarity_pipeline import *  # noqa F403
+    from .sequence_classification_pipeline import *  # noqa F403
+    from .text_generation_pipeline import *  # noqa F403
+    from .word_segmentation_pipeline import *  # noqa F403
+    from .zero_shot_classification_pipeline import *  # noqa F403
+except ModuleNotFoundError as e:
+    if str(e) == "No module named 'torch'":
+        pass
+    else:
+        raise ModuleNotFoundError(e)

+10 -3 modelscope/preprocessors/__init__.py

@@ -1,11 +1,18 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

-from .audio import LinearAECAndFbank
 from .base import Preprocessor
 from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .kws import WavToLists
-from .multi_modal import *  # noqa F403
-from .nlp import *  # noqa F403
 from .text_to_speech import *  # noqa F403
+
+try:
+    from .audio import LinearAECAndFbank
+    from .multi_modal import *  # noqa F403
+    from .nlp import *  # noqa F403
+except ModuleNotFoundError as e:
+    if str(e) == "No module named 'tensorflow'":
+        pass
+    else:
+        raise ModuleNotFoundError(e)

+79 -0 modelscope/utils/check_requirements.py

@@ -0,0 +1,79 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from modelscope.utils.constant import Fields, Requirements
from modelscope.utils.import_utils import requires


def get_msg(field):
    msg = f'\n{field} requirements not installed, please execute ' \
          f'`pip install requirements/{field}.txt` or ' \
          f'`pip install modelscope[{field}]`'
    return msg


class NLPModuleNotFoundError(ModuleNotFoundError):

    def __init__(self, e: ModuleNotFoundError) -> None:
        e.msg += get_msg(Fields.nlp)
        super().__init__(e)


class CVModuleNotFoundError(ModuleNotFoundError):

    def __init__(self, e: ModuleNotFoundError) -> None:
        e.msg += get_msg(Fields.cv)
        super().__init__(e)


class AudioModuleNotFoundError(ModuleNotFoundError):

    def __init__(self, e: ModuleNotFoundError) -> None:
        e.msg += get_msg(Fields.audio)
        super().__init__(e)


class MultiModalModuleNotFoundError(ModuleNotFoundError):

    def __init__(self, e: ModuleNotFoundError) -> None:
        e.msg += get_msg(Fields.multi_modal)
        super().__init__(e)


def check_nlp():
    try:
        requires('nlp models', (
            Requirements.torch,
            Requirements.tokenizers,
        ))
    except ImportError as e:
        raise NLPModuleNotFoundError(e)


def check_cv():
    try:
        requires('cv models', (
            Requirements.torch,
            Requirements.tokenizers,
        ))
    except ImportError as e:
        raise CVModuleNotFoundError(e)


def check_audio():
    try:
        requires('audio models', (
            Requirements.torch,
            Requirements.tf,
        ))
    except ImportError as e:
        raise AudioModuleNotFoundError(e)


def check_multi_modal():
    try:
        requires('multi-modal models', (
            Requirements.torch,
            Requirements.tokenizers,
        ))
    except ImportError as e:
        raise MultiModalModuleNotFoundError(e)
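
Once lazy import is in place, these checkers are intended to run before a field's modules are imported, so a missing dependency surfaces with an actionable install hint instead of a bare `ModuleNotFoundError`. A hypothetical call site (not part of this diff):

```python
# Hypothetical usage: verify the NLP extras before importing NLP models.
from modelscope.utils.check_requirements import check_nlp

check_nlp()  # raises NLPModuleNotFoundError suggesting `pip install modelscope[nlp]`
from modelscope.models.nlp import SbertForSentenceSimilarity  # safe to import now
```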

+3 -2 modelscope/utils/config.py

@@ -17,9 +17,10 @@ from typing import Dict
 import addict
 from yapf.yapflib.yapf_api import FormatCode

+from modelscope.utils.import_utils import (import_modules,
+                                           import_modules_from_file,
+                                           validate_py_syntax)
 from modelscope.utils.logger import get_logger
-from modelscope.utils.pymod import (import_modules, import_modules_from_file,
-                                    validate_py_syntax)

 if platform.system() == 'Windows':
     import regex as re  # type: ignore


+13 -0 modelscope/utils/constant.py

@@ -97,5 +97,18 @@ class ModelFile(object):
     TORCH_MODEL_BIN_FILE = 'pytorch_model.bin'


+class Requirements(object):
+    """Requirement names for each module
+    """
+    protobuf = 'protobuf'
+    sentencepiece = 'sentencepiece'
+    sklearn = 'sklearn'
+    scipy = 'scipy'
+    timm = 'timm'
+    tokenizers = 'tokenizers'
+    tf = 'tf'
+    torch = 'torch'
+
+
 TENSORFLOW = 'tensorflow'
 PYTORCH = 'pytorch'

+324 -0 modelscope/utils/import_utils.py

@@ -0,0 +1,324 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
import ast
import functools
import importlib.util
import os
import os.path as osp
import sys
import types
from collections import OrderedDict
from functools import wraps
from importlib import import_module
from itertools import chain
from types import ModuleType
from typing import Any

import json
from packaging import version

from modelscope.utils.constant import Fields
from modelscope.utils.logger import get_logger

if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata

logger = get_logger()


def import_modules_from_file(py_file: str):
""" Import module from a certrain file

Args:
py_file: path to a python file to be imported

Return:

"""
dirname, basefile = os.path.split(py_file)
if dirname == '':
dirname == './'
module_name = osp.splitext(basefile)[0]
sys.path.insert(0, dirname)
validate_py_syntax(py_file)
mod = import_module(module_name)
sys.path.pop(0)
return module_name, mod


def import_modules(imports, allow_failed_imports=False):
"""Import modules from the given list of strings.

Args:
imports (list | str | None): The given module names to be imported.
allow_failed_imports (bool): If True, the failed imports will return
None. Otherwise, an ImportError is raise. Default: False.

Returns:
list[module] | module | None: The imported modules.

Examples:
>>> osp, sys = import_modules(
... ['os.path', 'sys'])
>>> import os.path as osp_
>>> import sys as sys_
>>> assert osp == osp_
>>> assert sys == sys_
"""
if not imports:
return
single_import = False
if isinstance(imports, str):
single_import = True
imports = [imports]
if not isinstance(imports, list):
raise TypeError(
f'custom_imports must be a list but got type {type(imports)}')
imported = []
for imp in imports:
if not isinstance(imp, str):
raise TypeError(
f'{imp} is of type {type(imp)} and cannot be imported.')
try:
imported_tmp = import_module(imp)
except ImportError:
if allow_failed_imports:
logger.warning(f'{imp} failed to import and is ignored.')
imported_tmp = None
else:
raise ImportError
imported.append(imported_tmp)
if single_import:
imported = imported[0]
return imported


def validate_py_syntax(filename):
with open(filename, 'r', encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
content = f.read()
try:
ast.parse(content)
except SyntaxError as e:
raise SyntaxError('There are syntax errors in config '
f'file {filename}: {e}')


# following code borrows implementation from huggingface/transformers
ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'})
USE_TF = os.environ.get('USE_TF', 'AUTO').upper()
USE_TORCH = os.environ.get('USE_TORCH', 'AUTO').upper()
_torch_version = 'N/A'
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available = importlib.util.find_spec('torch') is not None
if _torch_available:
try:
_torch_version = importlib_metadata.version('torch')
logger.info(f'PyTorch version {_torch_version} available.')
except importlib_metadata.PackageNotFoundError:
_torch_available = False
else:
logger.info('Disabling PyTorch because USE_TF is set')
_torch_available = False

_tf_version = 'N/A'
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec('tensorflow') is not None
if _tf_available:
candidates = (
'tensorflow',
'tensorflow-cpu',
'tensorflow-gpu',
'tf-nightly',
'tf-nightly-cpu',
'tf-nightly-gpu',
'intel-tensorflow',
'intel-tensorflow-avx512',
'tensorflow-rocm',
'tensorflow-macos',
)
_tf_version = None
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
for pkg in candidates:
try:
_tf_version = importlib_metadata.version(pkg)
break
except importlib_metadata.PackageNotFoundError:
pass
_tf_available = _tf_version is not None
if _tf_available:
if version.parse(_tf_version) < version.parse('2'):
pass
else:
logger.info(f'TensorFlow version {_tf_version} available.')
else:
logger.info('Disabling Tensorflow because USE_TORCH is set')
_tf_available = False

_timm_available = importlib.util.find_spec('timm') is not None
try:
_timm_version = importlib_metadata.version('timm')
logger.debug(f'Successfully imported timm version {_timm_version}')
except importlib_metadata.PackageNotFoundError:
_timm_available = False


def is_scipy_available():
return importlib.util.find_spec('scipy') is not None


def is_sklearn_available():
if importlib.util.find_spec('sklearn') is None:
return False
return is_scipy_available() and importlib.util.find_spec('sklearn.metrics')


def is_sentencepiece_available():
return importlib.util.find_spec('sentencepiece') is not None


def is_protobuf_available():
if importlib.util.find_spec('google') is None:
return False
return importlib.util.find_spec('google.protobuf') is not None


def is_tokenizers_available():
return importlib.util.find_spec('tokenizers') is not None


def is_timm_available():
return _timm_available


def is_torch_available():
return _torch_available


def is_torch_cuda_available():
if is_torch_available():
import torch

return torch.cuda.is_available()
else:
return False


def is_tf_available():
return _tf_available


# docstyle-ignore
PROTOBUF_IMPORT_ERROR = """
{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and
follow the ones that match your environment.
"""

# docstyle-ignore
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
that match your environment.
"""

# docstyle-ignore
SKLEARN_IMPORT_ERROR = """
{0} requires the scikit-learn library but it was not found in your environment. You can install it with:
```
pip install -U scikit-learn
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install -U scikit-learn
```
"""

# docstyle-ignore
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
"""

# docstyle-ignore
TIMM_IMPORT_ERROR = """
{0} requires the timm library but it was not found in your environment. You can install it with pip:
`pip install timm`
"""

# docstyle-ignore
TOKENIZERS_IMPORT_ERROR = """
{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
```
pip install tokenizers
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install tokenizers
```
"""

# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""

# docstyle-ignore
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip:
`pip install scipy`
"""

REQUIREMENTS_MAAPING = OrderedDict([
('protobuf', (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
('sentencepiece', (is_sentencepiece_available,
SENTENCEPIECE_IMPORT_ERROR)),
('sklearn', (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
('tf', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
('timm', (is_timm_available, TIMM_IMPORT_ERROR)),
('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)),
('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)),
])


def requires(obj, requirements):
if not isinstance(requirements, (list, tuple)):
requirements = [requirements]
if isinstance(obj, str):
name = obj
else:
name = obj.__name__ if hasattr(obj,
'__name__') else obj.__class__.__name__
checks = (REQUIREMENTS_MAAPING[req] for req in requirements)
failed = [msg.format(name) for available, msg in checks if not available()]
if failed:
raise ImportError(''.join(failed))


def torch_required(func):
# Chose a different decorator name than in tests so it's clear they are not the same.
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_torch_available():
return func(*args, **kwargs)
else:
raise ImportError(f'Method `{func.__name__}` requires PyTorch.')

return wrapper


def tf_required(func):
# Chose a different decorator name than in tests so it's clear they are not the same.
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_tf_available():
return func(*args, **kwargs)
else:
raise ImportError(f'Method `{func.__name__}` requires TF.')

return wrapper
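
`requires` and the `*_required` decorators give two ways to assert optional dependencies; a small usage sketch with illustrative names (only the requirement keys come from `REQUIREMENTS_MAAPING` above):

```python
from modelscope.utils.import_utils import requires, torch_required

# Fail fast with a readable message if torch or tokenizers is missing.
requires('MyTextClassifier', ('torch', 'tokenizers'))


@torch_required
def run_inference(inputs):
    # Only reachable when PyTorch is importable.
    import torch
    return torch.tensor(inputs)
```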

+0 -90 modelscope/utils/pymod.py

@@ -1,90 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import ast
import os
import os.path as osp
import sys
import types
from importlib import import_module

from modelscope.utils.logger import get_logger

logger = get_logger()


def import_modules_from_file(py_file: str):
""" Import module from a certrain file

Args:
py_file: path to a python file to be imported

Return:

"""
dirname, basefile = os.path.split(py_file)
if dirname == '':
dirname == './'
module_name = osp.splitext(basefile)[0]
sys.path.insert(0, dirname)
validate_py_syntax(py_file)
mod = import_module(module_name)
sys.path.pop(0)
return module_name, mod


def import_modules(imports, allow_failed_imports=False):
"""Import modules from the given list of strings.

Args:
imports (list | str | None): The given module names to be imported.
allow_failed_imports (bool): If True, the failed imports will return
None. Otherwise, an ImportError is raise. Default: False.

Returns:
list[module] | module | None: The imported modules.

Examples:
>>> osp, sys = import_modules(
... ['os.path', 'sys'])
>>> import os.path as osp_
>>> import sys as sys_
>>> assert osp == osp_
>>> assert sys == sys_
"""
if not imports:
return
single_import = False
if isinstance(imports, str):
single_import = True
imports = [imports]
if not isinstance(imports, list):
raise TypeError(
f'custom_imports must be a list but got type {type(imports)}')
imported = []
for imp in imports:
if not isinstance(imp, str):
raise TypeError(
f'{imp} is of type {type(imp)} and cannot be imported.')
try:
imported_tmp = import_module(imp)
except ImportError:
if allow_failed_imports:
logger.warning(f'{imp} failed to import and is ignored.')
imported_tmp = None
else:
raise ImportError
imported.append(imported_tmp)
if single_import:
imported = imported[0]
return imported


def validate_py_syntax(filename):
with open(filename, 'r', encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
content = f.read()
try:
ast.parse(content)
except SyntaxError as e:
raise SyntaxError('There are syntax errors in config '
f'file {filename}: {e}')

+15 -5 modelscope/utils/registry.py

@@ -1,7 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 import inspect
+from typing import List, Tuple, Union

+from modelscope.utils.import_utils import requires
 from modelscope.utils.logger import get_logger

 default_group = 'default'
@@ -52,9 +54,14 @@ class Registry(object):
     def _register_module(self,
                          group_key=default_group,
                          module_name=None,
-                         module_cls=None):
+                         module_cls=None,
+                         requirements=None):
         assert isinstance(group_key,
                           str), 'group_key is required and must be str'

+        if requirements is not None:
+            requires(module_cls, requirements)
+
         if group_key not in self._modules:
             self._modules[group_key] = dict()

@@ -86,7 +93,8 @@ class Registry(object):
     def register_module(self,
                         group_key: str = default_group,
                         module_name: str = None,
-                        module_cls: type = None):
+                        module_cls: type = None,
+                        requirements: Union[List, Tuple] = None):
         """ Register module

         Example:
@@ -110,17 +118,18 @@ class Registry(object):
                 default group name is 'default'
             module_name: Module name
             module_cls: Module class object
+            requirements: Module necessary requirements

         """
         if not (module_name is None or isinstance(module_name, str)):
             raise TypeError(f'module_name must be either of None, str,'
                             f'got {type(module_name)}')

         if module_cls is not None:
             self._register_module(
                 group_key=group_key,
                 module_name=module_name,
-                module_cls=module_cls)
+                module_cls=module_cls,
+                requirements=requirements)
             return module_cls

         # if module_cls is None, should return a decorator function
@@ -128,7 +137,8 @@ class Registry(object):
             self._register_module(
                 group_key=group_key,
                 module_name=module_name,
-                module_cls=module_cls)
+                module_cls=module_cls,
+                requirements=requirements)
             return module_cls

         return _register
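
With the new `requirements` argument, a registration can declare the packages it needs up front; the registry then calls `requires()` at registration time. A hypothetical registration (class and module names are illustrative):

```python
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Requirements


@MODELS.register_module(
    'nlp', module_name='my-sbert', requirements=(Requirements.torch, ))
class MySbertModel:
    ...
```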


+0 -5 requirements.txt

@@ -1,6 +1 @@
 -r requirements/runtime.txt
--r requirements/pipeline.txt
--r requirements/multi-modal.txt
--r requirements/nlp.txt
--r requirements/audio.txt
--r requirements/cv.txt

+2 -5 requirements/audio.txt

@@ -1,10 +1,5 @@
 #tts
 h5py
-https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/pytorch_wavelets-1.3.0-py3-none-any.whl
-https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.2-cp36-cp36m-linux_x86_64.whl; python_version=='3.6'
-https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.2-cp37-cp37m-linux_x86_64.whl; python_version=='3.7'
-https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.2-cp38-cp38-linux_x86_64.whl; python_version=='3.8'
-https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.2-cp39-cp39-linux_x86_64.whl; python_version=='3.9'
 inflect
 keras
 librosa
@@ -14,6 +9,7 @@ nara_wpe
 numpy
 protobuf>3,<=3.20
 ptflops
+pytorch_wavelets==1.3.0
 PyWavelets>=1.0.0
 scikit-learn
 SoundFile>0.10
@@ -24,4 +20,5 @@ torch
 torchaudio
 torchvision
 tqdm
+ttsfrd==0.0.2
 unidecode

+2 -4 requirements/multi-modal.txt

@@ -1,8 +1,6 @@
-datasets
-einops
+fairseq==maas
 ftfy>=6.0.3
-https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/fairseq-maas-py3-none-any.whl
-https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/ofa-0.0.2-py3-none-any.whl
+ofa==0.0.2
 pycocoevalcap>=1.2
 pycocotools>=2.0.4
 rouge_score


+1 -1 requirements/nlp.txt

@@ -1 +1 @@
-https://alinlp.alibaba-inc.com/pypi/sofa-1.0.4.2-py3-none-any.whl
+sofa==1.0.4.2

+0 -6 requirements/pipeline.txt

@@ -1,6 +0,0 @@
#https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.4-py2.py3-none-any.whl
# tensorflow
#--find-links https://download.pytorch.org/whl/torch_stable.html
# torch<1.10,>=1.8.0
# torchaudio
# torchvision

+6 -4 requirements/runtime.txt

@@ -1,16 +1,18 @@
 addict
 datasets
 easydict
+einops
 filelock>=3.3.0
 numpy
-opencv-python-headless
+opencv-python
 Pillow>=6.2.0
+protobuf>3,<=3.20
 pyyaml
 requests
-requests==2.27.1
 scipy
-setuptools==58.0.4
+setuptools
 tokenizers<=0.10.3
+torch
 tqdm>=4.64.0
-transformers<=4.16.2
+transformers<=4.16.2,>=4.10.3
 yapf

+13 -0 setup.py

@@ -5,6 +5,8 @@ import shutil
 import subprocess
 from setuptools import find_packages, setup

+from modelscope.utils.constant import Fields
+

 def readme():
     with open('README.md', encoding='utf-8') as f:
@@ -169,6 +171,16 @@ if __name__ == '__main__':
     pack_resource()
     os.chdir('package')
     install_requires, deps_link = parse_requirements('requirements.txt')
+    extra_requires = {}
+    all_requires = []
+    for field in dir(Fields):
+        if field.startswith('_'):
+            continue
+        extra_requires[field], _ = parse_requirements(
+            f'requirements/{field}.txt')
+        all_requires.append(extra_requires[field])
+    extra_requires['all'] = all_requires
+
     setup(
         name='model-scope',
         version=get_version(),
@@ -193,5 +205,6 @@ if __name__ == '__main__':
         license='Apache License 2.0',
         tests_require=parse_requirements('requirements/tests.txt'),
         install_requires=install_requires,
+        extras_require=extra_requires,
         dependency_links=deps_link,
         zip_safe=False)

+22 -0 tests/utils/test_check_requirements.py

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest
from typing import List, Union

from modelscope.utils.check_requirements import NLPModuleNotFoundError, get_msg
from modelscope.utils.constant import Fields


class ImportUtilsTest(unittest.TestCase):

    def test_type_module_not_found(self):
        with self.assertRaises(NLPModuleNotFoundError) as ctx:
            try:
                import not_found
            except ModuleNotFoundError as e:
                raise NLPModuleNotFoundError(e)
        self.assertTrue(get_msg(Fields.nlp) in ctx.exception.msg.msg)


if __name__ == '__main__':
    unittest.main()
