
add finetune & merge master

master · 翎航 · 2 years ago
parent commit eb82ba9c6f
45 changed files with 1001 additions and 309 deletions
  1. +2    -0    .dev_scripts/dockerci.sh
  2. +18   -8    README.md
  3. +2    -3    modelscope/exporters/torch_model_exporter.py
  4. +28   -4    modelscope/hub/api.py
  5. +1    -0    modelscope/hub/constants.py
  6. +15   -0    modelscope/hub/utils/utils.py
  7. +9    -2    modelscope/models/audio/kws/farfield/model.py
  8. +2    -0    modelscope/models/base/base_model.py
  9. +4    -2    modelscope/models/multi_modal/clip/model.py
  10. +13   -0    modelscope/models/multi_modal/ofa/configuration_ofa.py
  11. +203  -141  modelscope/models/multi_modal/ofa/modeling_ofa.py
  12. +40   -0    modelscope/models/multi_modal/ofa/utils/utils.py
  13. +155  -0    modelscope/models/multi_modal/ofa/vit.py
  14. +22   -3    modelscope/models/multi_modal/ofa_for_all_tasks.py
  15. +3    -0    modelscope/models/science/unifold/dataset.py
  16. +3    -0    modelscope/models/science/unifold/model.py
  17. +3    -0    modelscope/models/science/unifold/modules/__init__.py
  18. +2    -0    modelscope/msdatasets/ms_dataset.py
  19. +19   -15   modelscope/outputs/outputs.py
  20. +4    -1    modelscope/pipelines/base.py
  21. +2    -3    modelscope/pipelines/builder.py
  22. +3    -3    modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
  23. +205  -7    modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py
  24. +2    -2    modelscope/pipelines/nlp/token_classification_pipeline.py
  25. +3    -3    modelscope/pipelines/nlp/word_segmentation_pipeline.py
  26. +1    -1    modelscope/preprocessors/multi_modal.py
  27. +12   -5    modelscope/preprocessors/nlp/nlp_base.py
  28. +82   -69   modelscope/preprocessors/nlp/token_classification_preprocessor.py
  29. +2    -2    modelscope/preprocessors/ofa/ocr_recognition.py
  30. +5    -2    modelscope/trainers/audio/kws_farfield_trainer.py
  31. +11   -11   modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
  32. +1    -1    modelscope/trainers/nlp/text_generation_trainer.py
  33. +4    -2    modelscope/trainers/nlp_trainer.py
  34. +6    -1    modelscope/trainers/trainer.py
  35. +8    -0    modelscope/utils/constant.py
  36. +65   -0    modelscope/utils/cv/image_utils.py
  37. +2    -1    requirements/framework.txt
  38. +2    -0    requirements/multi-modal.txt
  39. +2    -0    requirements/science.txt
  40. +6    -2    tests/msdatasets/test_dataset_upload.py
  41. +2    -1    tests/outputs/test_model_outputs.py
  42. +17   -12   tests/pipelines/test_face_2d_keypoints.py
  43. +8    -0    tests/pipelines/test_named_entity_recognition.py
  44. +1    -1    tests/pipelines/test_unifold.py
  45. +1    -1    tests/trainers/test_finetune_token_classificatin.py

+2 -0  .dev_scripts/dockerci.sh

@@ -37,6 +37,7 @@ do
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
-e TEST_LEVEL=$TEST_LEVEL \
-e MODELSCOPE_ENVIRONMENT='ci' \
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
-e MODEL_TAG_URL=$MODEL_TAG_URL \
--workdir=$CODE_DIR_IN_CONTAINER \
@@ -59,6 +60,7 @@ do
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
-e TEST_LEVEL=$TEST_LEVEL \
-e MODELSCOPE_ENVIRONMENT='ci' \
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
-e MODEL_TAG_URL=$MODEL_TAG_URL \
--workdir=$CODE_DIR_IN_CONTAINER \


+18 -8  README.md

@@ -1,16 +1,26 @@
# Introduction

ModelScope library is targeted to support training, evaluation and inference for the state of the art models provided by Mind and further support third-party models provided by users outside alibaba.
[ModelScope](https://www.modelscope.cn) is a “Model-as-a-Service” (MaaS) platform that seeks to bring together the most advanced machine learning models from the AI community, and to streamline the process of leveraging and applying AI models. The core ModelScope library enables developers to perform model inference, training and evaluation, through rich layers of API designs that facilitate a unified experience across state-of-the-art models from different AI domains.

# Design doc
The Python library offers the layered APIs necessary for model contributors to integrate models from CV, NLP, Speech, Multi-Modality, as well as Scientific computation, into the ModelScope ecosystem. Implementations of all these different models are encapsulated within the library in a way that allows easy and unified access. With such integration, model inference, finetuning, and evaluation can be done with only a few lines of code. At the same time, flexibility is provided so that different components in the model applications can be customized where necessary.

Please refer to alidoc [link](https://alidocs.dingtalk.com/i/nodes/OBldywvrKxo89xmAO05yJQk2ngpNbLz4?nav=spaces&navQuery=spaceId%3Dnb9XJNlZxbgrOXyA&iframeQuery=utm_source%3Dportal%26utm_medium%3Dportal_space_file_tree)
Apart from harboring implementations of various models, the ModelScope library also enables the necessary interactions with the backend services of ModelScope, particularly the Model-Hub and Dataset-Hub. Such interactions allow management of various entities (models and datasets) to be performed seamlessly under the hood, including entity lookup, version control, and cache management.

# Development doc
# Installation

Please refer to [develop.md](docs/source/develop.md)
Please refer to [installation](https://modelscope.cn/docs/%E7%8E%AF%E5%A2%83%E5%AE%89%E8%A3%85).

# ChangeLog
* 20/05/2022 First release version
# Get Started

Refer to [change_log.md](docs/source/change_log.md) for more details
You can refer to [quick_start](https://modelscope.cn/docs/%E5%BF%AB%E9%80%9F%E5%BC%80%E5%A7%8B) for a quick start.

We also provide other documentation, including:
* [Introduction to tasks](https://modelscope.cn/docs/%E4%BB%BB%E5%8A%A1%E7%9A%84%E4%BB%8B%E7%BB%8D)
* [Use pipeline for model inference](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E6%8E%A8%E7%90%86Pipeline)
* [Finetune example](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AE%AD%E7%BB%83Train)
* [Preprocessing of data](https://modelscope.cn/docs/%E6%95%B0%E6%8D%AE%E7%9A%84%E9%A2%84%E5%A4%84%E7%90%86)
* [Evaluation metrics](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AF%84%E4%BC%B0)

# License

This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
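
As a quick illustration of the pipeline-based inference the Get Started docs describe, here is a minimal sketch, assuming a default model is registered for the chosen task (the example sentence mirrors the word segmentation sample in outputs.py below):

```python
# Minimal sketch of ModelScope pipeline inference; the library resolves a default
# model for the task (see DEFAULT_MODEL_FOR_PIPELINE in pipelines/builder.py).
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

word_segmentation = pipeline(Tasks.word_segmentation)
result = word_segmentation('今天天气不错,适合出去游玩')
print(result)  # expected to contain {'output': '今天 天气 不错 , 适合 出去 游玩'}
```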

+2 -3  modelscope/exporters/torch_model_exporter.py

@@ -128,7 +128,7 @@ class TorchModelExporter(Exporter):
args_list = list(args)
else:
args_list = [args]
if isinstance(args_list[-1], dict):
if isinstance(args_list[-1], Mapping):
args_dict = args_list[-1]
args_list = args_list[:-1]
n_nonkeyword = len(args_list)
@@ -284,9 +284,8 @@ class TorchModelExporter(Exporter):
'Model property dummy_inputs must be set.')
dummy_inputs = collate_fn(dummy_inputs, device)
if isinstance(dummy_inputs, Mapping):
dummy_inputs = self._decide_input_format(model, dummy_inputs)
dummy_inputs_filter = []
for _input in dummy_inputs:
for _input in self._decide_input_format(model, dummy_inputs):
if _input is not None:
dummy_inputs_filter.append(_input)
else:


+28 -4  modelscope/hub/api.py

@@ -23,7 +23,8 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
API_RESPONSE_FIELD_MESSAGE,
API_RESPONSE_FIELD_USERNAME,
DEFAULT_CREDENTIALS_PATH,
MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS,
MODELSCOPE_ENVIRONMENT,
MODELSCOPE_USERNAME, ONE_YEAR_SECONDS,
Licenses, ModelVisibility)
from modelscope.hub.errors import (InvalidParameter, NotExistError,
NotLoginException, NoValidRevisionError,
@@ -38,8 +39,8 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
DEFAULT_REPOSITORY_REVISION,
MASTER_MODEL_BRANCH, DatasetFormations,
DatasetMetaFormats, DownloadMode,
ModelFile)
DatasetMetaFormats, DownloadChannel,
DownloadMode, ModelFile)
from modelscope.utils.logger import get_logger
from .utils.utils import (get_endpoint, get_release_datetime,
model_id_to_group_owner_name)
@@ -645,6 +646,25 @@ class HubApi:
def check_local_cookies(self, use_cookies) -> CookieJar:
return self._check_cookie(use_cookies=use_cookies)

def dataset_download_uv(self, dataset_name: str, namespace: str):
if not dataset_name or not namespace:
raise ValueError('dataset_name or namespace cannot be empty!')

# get channel and user_name
channel = DownloadChannel.LOCAL.value
user_name = ''
if MODELSCOPE_ENVIRONMENT in os.environ:
channel = os.environ[MODELSCOPE_ENVIRONMENT]
if MODELSCOPE_USERNAME in os.environ:
user_name = os.environ[MODELSCOPE_USERNAME]

url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/{channel}?user={user_name}'
cookies = ModelScopeConfig.get_cookies()
r = requests.post(url, cookies=cookies, headers=self.headers)
resp = r.json()
raise_on_error(resp)
return resp['Message']


class ModelScopeConfig:
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
@@ -760,14 +780,18 @@ class ModelScopeConfig:
env = 'custom'
if MODELSCOPE_ENVIRONMENT in os.environ:
env = os.environ[MODELSCOPE_ENVIRONMENT]
user_name = 'unknown'
if MODELSCOPE_USERNAME in os.environ:
user_name = os.environ[MODELSCOPE_USERNAME]

ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % (
ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % (
__version__,
platform.python_version(),
ModelScopeConfig.get_user_session_id(),
platform.platform(),
platform.processor(),
env,
user_name,
)
if isinstance(user_agent, dict):
ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items())
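
The new `dataset_download_uv` hook reports a dataset download "unique visit" to the hub; a minimal usage sketch, with placeholder dataset name and namespace (the real call site is the ms_dataset.py change later in this commit):

```python
# Sketch: reporting a dataset download UV. Channel and user are resolved from the
# MODELSCOPE_ENVIRONMENT / MODELSCOPE_USERNAME environment variables, falling back
# to DownloadChannel.LOCAL and an empty user name.
from modelscope.hub.api import HubApi

api = HubApi()
try:
    api.dataset_download_uv(dataset_name='my_dataset', namespace='my_namespace')  # placeholders
except Exception as e:
    # mirrors ms_dataset.py: statistics reporting is best-effort and only logged on failure
    print(e)
```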


+1 -0  modelscope/hub/constants.py

@@ -18,6 +18,7 @@ API_RESPONSE_FIELD_EMAIL = 'Email'
API_RESPONSE_FIELD_MESSAGE = 'Message'
MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'
MODELSCOPE_USERNAME = 'MODELSCOPE_USERNAME'
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60




+15 -0  modelscope/hub/utils/utils.py

@@ -5,6 +5,8 @@ import os
from datetime import datetime
from typing import Optional

import requests

from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
DEFAULT_MODELSCOPE_GROUP,
MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG,
@@ -85,3 +87,16 @@ def file_integrity_validation(file_path, expected_sha256):
msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
logger.error(msg)
raise FileIntegrityError(msg)


def create_library_statistics(method: str, name: str, cn_name: Optional[str]):
try:
from modelscope.hub.api import ModelScopeConfig
path = f'{get_endpoint()}/api/v1/statistics/library'
headers = {'user-agent': ModelScopeConfig.get_user_agent()}
params = {'Method': method, 'Name': name, 'CnName': cn_name}
r = requests.post(path, params=params, headers=headers)
r.raise_for_status()
except Exception:
pass
return
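
A short sketch of how this fire-and-forget helper is invoked (pipelines/base.py below calls it with the model name); the model id here is hypothetical:

```python
# The helper swallows any network error, so callers can invoke it unconditionally.
from modelscope.hub.utils.utils import create_library_statistics

create_library_statistics(method='pipeline', name='damo/some-model-id', cn_name=None)  # hypothetical id
```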

+9 -2  modelscope/models/audio/kws/farfield/model.py

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import tempfile
from typing import Dict, Optional

from modelscope.metainfo import Models
@@ -36,12 +37,15 @@ class FSMNSeleNetV2Decorator(TorchModel):
else:
sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
self.tmp_dir = tempfile.TemporaryDirectory()
new_config_file = os.path.join(self.tmp_dir.name, self.SC_CONFIG)

self._sc = None
if os.path.exists(model_txt_file):
conf_dict = dict(mode=56542, kws_model=model_txt_file)
update_conf(sc_config_file, sc_config_file, conf_dict)
update_conf(sc_config_file, new_config_file, conf_dict)
import py_sound_connect
self._sc = py_sound_connect.SoundConnect(sc_config_file)
self._sc = py_sound_connect.SoundConnect(new_config_file)
self.size_in = self._sc.bytesPerBlockIn()
self.size_out = self._sc.bytesPerBlockOut()
else:
@@ -49,6 +53,9 @@ class FSMNSeleNetV2Decorator(TorchModel):
f'Invalid model directory! Failed to load model file: {model_txt_file}.'
)

def __del__(self):
self.tmp_dir.cleanup()

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
return self.model.forward(input)



+2 -0  modelscope/models/base/base_model.py

@@ -131,6 +131,8 @@ class Model(ABC):

if not hasattr(model, 'cfg'):
model.cfg = cfg

model.name = model_name_or_path
return model

def save_pretrained(self,


+4 -2  modelscope/models/multi_modal/clip/model.py

@@ -349,11 +349,13 @@ class CLIP(nn.Module):
text_num_hidden_layers: int,
text_type_vocab_size: int,
tokenizer: FullTokenizer,
# vision_head_width, added this param for ViT-H
vision_head_width: int = 64,
):
super().__init__()

if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
vision_heads = vision_width * 32 // vision_head_width
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
@@ -361,7 +363,7 @@ class CLIP(nn.Module):
input_resolution=image_resolution,
width=vision_width)
else:
vision_heads = vision_width // 64
vision_heads = vision_width // vision_head_width
self.visual = VisualTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
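
The new `vision_head_width` parameter fixes the per-head dimension instead of hard-coding 64; a tiny arithmetic sketch with assumed ViT-H-like numbers:

```python
# Illustrative only: the values are assumptions, not taken from a released config.
vision_width = 1280       # hypothetical ViT-H-like width
vision_head_width = 80    # per-head dim; the original configs keep the default of 64
vision_heads = vision_width // vision_head_width
print(vision_heads)       # 16
```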


+13 -0  modelscope/models/multi_modal/ofa/configuration_ofa.py

@@ -136,6 +136,12 @@ class OFAConfig(PretrainedConfig):
entangle_position_embedding=False,
interpolate_position=False,
orig_patch_image_size=224,
share_attn_bias=False,
use_image_feature=True,
disable_entangle=False,
use_ofasys=False,
vit_type='vit_base',
vit_drop_path_rate=0.0,
**kwargs):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
@@ -178,6 +184,13 @@ class OFAConfig(PretrainedConfig):
self.interpolate_position = interpolate_position
self.orig_patch_image_size = orig_patch_image_size

self.share_attn_bias = share_attn_bias
self.use_image_feature = use_image_feature
self.disable_entangle = disable_entangle
self.use_ofasys = use_ofasys
self.vit_type = vit_type
self.vit_drop_path_rate = vit_drop_path_rate

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
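
A minimal sketch constructing an `OFAConfig` with the switches added here; the values are illustrative, not taken from any released model configuration:

```python
from modelscope.models.multi_modal.ofa.configuration_ofa import OFAConfig

config = OFAConfig(
    share_attn_bias=True,      # share one relative-position bias table across layers
    use_image_feature=True,    # keep the image branch (type/positional embeddings, image proj)
    use_ofasys=True,           # enable the OFASys-style encoder/decoder code paths
    vit_type='vit_large',      # one of vit_base / vit_large / vit_large_336 / vit_huge
    vit_drop_path_rate=0.1,    # stochastic depth inside the ViT backbone
)
```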


+203 -141  modelscope/models/multi_modal/ofa/modeling_ofa.py

@@ -35,6 +35,8 @@ from transformers.utils import logging
from .configuration_ofa import OFAConfig
from .generate import utils
from .resnet import ResNet
from .utils.utils import DropPath
from .vit import vit_base, vit_huge, vit_large, vit_large_336

logger = logging.get_logger(__name__)

@@ -249,45 +251,6 @@ class LayerDropModuleList(nn.ModuleList):
yield m


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Args:
x (`nn.Modules`): input nn layers.
drop_prob (`float`): drop path ratio.
training (`bool`): whether is training or inference.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (1, x.shape[1], 1)
random_tensor = keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output


class DropPath(nn.Module):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Args:
drop_prob: drop path ratio.
"""

def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob

def forward(self, x):
return drop_path(x, self.drop_prob, self.training)

def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)


class OFAAttention(nn.Module):
r"""
Multi-headed attention, with additional implementation for NormFormer.
@@ -898,31 +861,49 @@ class OFAEncoder(OFAPreTrainedModel):
self.padding_idx)

if config.add_type_embedding:
self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
if config.use_image_feature:
self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
else:
self.type_embedding = Embedding(1, embed_dim, padding_idx=None)
else:
self.type_embedding = None

if config.resnet_type == 'resnet18':
self.embed_images = ResNet(
[2, 2, 2], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet34':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet50':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet101':
self.embed_images = ResNet(
[3, 4, 23], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet152':
self.embed_images = ResNet(
[3, 8, 36], drop_path_rate=config.resnet_drop_path_rate)
else:
raise NotImplementedError
if config.use_image_feature:
if config.use_ofasys:
vit_backbone = {
'vit_base': vit_base,
'vit_large': vit_large,
'vit_large_336': vit_large_336,
'vit_huge': vit_huge,
}[config.vit_type]
self.embed_images = vit_backbone(config.vit_drop_path_rate)

self.image_proj = Linear(1024, embed_dim)
self.image_proj = Linear(self.embed_images.width, embed_dim)

if config.resnet_model_path:
else:
if config.resnet_type == 'resnet18':
self.embed_images = ResNet(
[2, 2, 2], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet34':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet50':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet101':
self.embed_images = ResNet(
[3, 4, 23],
drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet152':
self.embed_images = ResNet(
[3, 8, 36],
drop_path_rate=config.resnet_drop_path_rate)
else:
raise NotImplementedError

self.image_proj = Linear(1024, embed_dim)

if not config.use_ofasys and config.resnet_model_path:
print('load resnet {}'.format(config.resnet_model_path))
resnet_state_dict = torch.load(config.resnet_model_path)
self.embed_images.load_state_dict(resnet_state_dict)
@@ -933,14 +914,21 @@ class OFAEncoder(OFAPreTrainedModel):

self.embed_positions = Embedding(self.max_source_positions + 2,
embed_dim)
self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1,
embed_dim)
self.pos_ln = LayerNorm(embed_dim)
self.image_pos_ln = LayerNorm(embed_dim)

if config.use_image_feature:
self.embed_image_positions = Embedding(
config.image_bucket_size**2 + 1, embed_dim)
if not config.use_ofasys:
self.pos_ln = LayerNorm(embed_dim)

if config.use_image_feature:
self.image_pos_ln = LayerNorm(embed_dim)
self.pos_scaling = float(embed_dim / self.num_attention_heads
* config.attn_scale_factor)**-0.5
self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
self.pos_k_linear = nn.Linear(embed_dim, embed_dim)

if not (config.use_ofasys and config.entangle_position_embedding):
self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
self.pos_k_linear = nn.Linear(embed_dim, embed_dim)

if self.encoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
@@ -965,22 +953,28 @@ class OFAEncoder(OFAPreTrainedModel):
self.token_bucket_size = config.token_bucket_size
token_num_rel_dis = 2 * config.token_bucket_size - 1
token_rp_bucket = make_token_bucket_position(config.token_bucket_size)
self.share_attn_bias = config.share_attn_bias
num_rel_pos_tables = 1 if config.share_attn_bias else config.encoder_layers
self.token_rel_pos_table_list = nn.ModuleList([
Embedding(
token_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.encoder_layers)
for _ in range(num_rel_pos_tables)
])

self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size
- 1) * (2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(config.image_bucket_size,
image_num_rel_dis)
self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.encoder_layers)
])
if config.use_image_feature:
self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size
- 1) * (2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(
config.image_bucket_size, image_num_rel_dis)
self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis,
self.num_attention_heads,
zero_init=True) for _ in range(num_rel_pos_tables)
])

self.register_buffer('image_rp_bucket', image_rp_bucket)

if config.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim)
@@ -988,12 +982,12 @@ class OFAEncoder(OFAPreTrainedModel):
self.layernorm_embedding = None

self.register_buffer('token_rp_bucket', token_rp_bucket)
self.register_buffer('image_rp_bucket', image_rp_bucket)
self.entangle_position_embedding = config.entangle_position_embedding

self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
self.use_ofasys = config.use_ofasys

def get_input_embeddings(self):
r"""
@@ -1305,21 +1299,41 @@ class OFAEncoder(OFAPreTrainedModel):
if has_pads:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

pos_embed = self.pos_ln(pos_embed)
if patch_images is not None:
image_pos_embed = self.image_pos_ln(image_pos_embed)
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
if patch_images_2 is not None:
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
if self.use_ofasys:
if patch_images is not None:
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
if patch_images_2 is not None:
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
else:
pos_embed = self.pos_ln(pos_embed)
if patch_images is not None:
image_pos_embed = self.image_pos_ln(image_pos_embed)
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
if patch_images_2 is not None:
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)

def build_abs_pos_bias(pos_embed):
batch_size, seq_length = pos_embed.size(0), pos_embed.size(1)
if not (self.use_ofasys and self.entangle_position_embedding):
pos_q = self.pos_q_linear(pos_embed).view(
batch_size, seq_length, self.num_attention_heads,
-1).transpose(1, 2) * self.pos_scaling
pos_k = self.pos_k_linear(pos_embed).view(
batch_size, seq_length, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
else:
abs_pos_bias = torch.zeros(
batch_size,
self.num_attention_heads,
seq_length,
seq_length,
dtype=pos_embed.dtype,
device=pos_embed.device)
return abs_pos_bias

pos_q = self.pos_q_linear(pos_embed).view(
x.size(0), x.size(1), self.num_attention_heads, -1).transpose(
1, 2) * self.pos_scaling
pos_k = self.pos_k_linear(pos_embed).view(
x.size(0), x.size(1), self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
abs_pos_bias = build_abs_pos_bias(pos_embed)

# expand attention_mask
if has_pads:
@@ -1334,19 +1348,22 @@ class OFAEncoder(OFAPreTrainedModel):
if output_hidden_states:
encoder_states += (x, )
self_attn_bias = abs_pos_bias.clone()

real_idx = 0 if self.share_attn_bias else idx

self_attn_bias[:, :, -input_ids.size(1):,
-input_ids.size(1):] += self.get_rel_pos_bias(
input_ids, idx)
input_ids, real_idx)
if patch_images_2 is not None:
self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \
self.get_image_rel_pos_bias(image_position_ids_2, idx)
self.get_image_rel_pos_bias(image_position_ids_2, real_idx)
self_attn_bias[:, :,
image_num_patches_2:image_num_patches_2 + image_num_patches, # noqa
image_num_patches_2:image_num_patches_2 + image_num_patches] += \
self.get_image_rel_pos_bias(image_position_ids, idx) # noqa
self.get_image_rel_pos_bias(image_position_ids, real_idx) # noqa
elif patch_images is not None:
self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \
self.get_image_rel_pos_bias(image_position_ids, idx)
self.get_image_rel_pos_bias(image_position_ids, real_idx)
self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1))

hidden_outputs = layer(
@@ -1398,6 +1415,8 @@ class OFADecoder(OFAPreTrainedModel):
self._future_mask = torch.empty(0)
self.share_input_output_embed = config.share_decoder_input_output_embed
self.num_attention_heads = config.decoder_attention_heads
self.use_ofasys = config.use_ofasys
self.disable_entangle = config.disable_entangle

if embed_tokens is not None:
self.embed_tokens = embed_tokens
@@ -1415,18 +1434,31 @@ class OFADecoder(OFAPreTrainedModel):
else:
self.layernorm_embedding = None

if config.use_ofasys:
if config.add_type_embedding:
self.type_embedding = Embedding(
1, self.embed_dim, padding_idx=None)
else:
self.type_embedding = None

self.window_size = config.code_image_size // 8

self.embed_positions = Embedding(self.max_target_positions + 2,
self.embed_dim)
self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1,
self.embed_dim)
self.pos_ln = LayerNorm(self.embed_dim)
self.image_pos_ln = LayerNorm(self.embed_dim)

if not config.use_ofasys:
self.embed_image_positions = Embedding(
config.image_bucket_size**2 + 1, self.embed_dim)
if not config.use_ofasys:
self.pos_ln = LayerNorm(self.embed_dim)
self.image_pos_ln = LayerNorm(self.embed_dim)
self.pos_scaling = float(self.embed_dim / self.num_attention_heads
* config.attn_scale_factor)**-0.5
self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)

if not (config.use_ofasys and config.entangle_position_embedding):
self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)

self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)

@@ -1463,33 +1495,41 @@ class OFADecoder(OFAPreTrainedModel):
self.token_bucket_size = config.token_bucket_size
token_num_rel_dis = 2 * config.token_bucket_size - 1
token_rp_bucket = make_token_bucket_position(config.token_bucket_size)

self.share_attn_bias = config.share_attn_bias
num_rel_pos_tables = 1 if config.share_attn_bias else config.decoder_layers
self.token_rel_pos_table_list = nn.ModuleList([
Embedding(
token_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.decoder_layers)
for _ in range(num_rel_pos_tables)
])

self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size
- 1) * (2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(config.image_bucket_size,
image_num_rel_dis)
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa
image_position_idx = torch.cat(
[torch.tensor([0]), image_position_idx.view(-1)])
image_position_idx = torch.cat(
[image_position_idx,
torch.tensor([1024] * 768)])
self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.decoder_layers)
])
if config.use_image_feature:
if not config.use_ofasys:
self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size - 1) * (
2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(
config.image_bucket_size, image_num_rel_dis)
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa
image_position_idx = torch.cat(
[torch.tensor([0]),
image_position_idx.view(-1)])
image_position_idx = torch.cat(
[image_position_idx,
torch.tensor([1024] * 768)])
self.register_buffer('image_position_idx', image_position_idx)

self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis,
self.num_attention_heads,
zero_init=True) for _ in range(num_rel_pos_tables)
])
self.register_buffer('image_rp_bucket', image_rp_bucket)

self.register_buffer('token_rp_bucket', token_rp_bucket)
self.register_buffer('image_rp_bucket', image_rp_bucket)
self.register_buffer('image_position_idx', image_position_idx)
self.entangle_position_embedding = config.entangle_position_embedding

self.gradient_checkpointing = False
@@ -1556,26 +1596,46 @@ class OFADecoder(OFAPreTrainedModel):

batch_size = tgt_pos_embed.size(0)
tgt_len = tgt_pos_embed.size(1)
tgt_pos_embed = self.image_pos_ln(
tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)
if not self.use_ofasys:
tgt_pos_embed = self.image_pos_ln(
tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)

if src_pos_embed is not None:
src_len = src_pos_embed.size(1)
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads, -1).transpose(
1, 2) * self.pos_scaling
pos_k = self.cross_pos_k_linear(src_pos_embed).view(
batch_size, src_len, self.num_attention_heads,
-1).transpose(1, 2)
if not (self.entangle_position_embedding and self.use_ofasys):
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads,
-1).transpose(1, 2) * self.pos_scaling
pos_k = self.cross_pos_k_linear(src_pos_embed).view(
batch_size, src_len, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
else:
abs_pos_bias = torch.zeros(
batch_size,
self.num_attention_heads,
tgt_len,
src_len,
dtype=tgt_pos_embed.dtype,
device=tgt_pos_embed.device)
else:
src_len = tgt_pos_embed.size(1)
pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads, -1).transpose(
1, 2) * self.pos_scaling
pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
batch_size, src_len, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
# batch_size, seq_length = tgt_pos_embed.size(0), tgt_pos_embed.size(1)
if not (self.entangle_position_embedding and self.use_ofasys):
pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads,
-1).transpose(1, 2) * self.pos_scaling
pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
else:
abs_pos_bias = torch.zeros(
batch_size,
self.num_attention_heads,
tgt_len,
tgt_len,
dtype=tgt_pos_embed.dtype,
device=tgt_pos_embed.device)

return abs_pos_bias

@@ -1809,17 +1869,18 @@ class OFADecoder(OFAPreTrainedModel):
past_key_values) > 0 else None

self_attn_bias = self_abs_pos_bias.clone()
real_idx = 0 if self.share_attn_bias else idx
if code_masks is None or not code_masks.any():
self_attn_bias += self.get_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
elif code_masks is not None and code_masks.all():
self_attn_bias += self.get_image_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
else:
self_attn_bias[~code_masks] += self.get_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
self_attn_bias[code_masks] += self.get_image_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
self_attn_bias = self_attn_bias.reshape(
-1,
*self_attn_bias.size()[-2:])
@@ -1892,6 +1953,7 @@ class OFAModel(OFAPreTrainedModel):

self.encoder = OFAEncoder(config, shared)
self.decoder = OFADecoder(config, shared)
self.use_ofasys = config.use_ofasys

# Initialize weights and apply final processing
self.post_init()
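
The core of the `share_attn_bias` change is that a single relative-position table can serve every layer; a standalone sketch of the indexing convention used above:

```python
# When the bias is shared, one table is built and every layer reads index 0;
# otherwise there is one table per layer and each layer reads its own index.
share_attn_bias = True
encoder_layers = 12

num_rel_pos_tables = 1 if share_attn_bias else encoder_layers
table_indices = [0 if share_attn_bias else idx for idx in range(encoder_layers)]
# -> [0, 0, ..., 0] when shared, [0, 1, ..., 11] otherwise;
# get_rel_pos_bias(input_ids, real_idx) then reads token_rel_pos_table_list[real_idx].
```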


+40 -0  modelscope/models/multi_modal/ofa/utils/utils.py

@@ -2,6 +2,7 @@
from typing import Optional

import torch
import torch.nn as nn


def expand_mask(mask: torch.Tensor,
@@ -17,3 +18,42 @@ def expand_mask(mask: torch.Tensor,
src_len).to(dtype)
return expanded_mask.masked_fill(expanded_mask.bool(),
torch.finfo(dtype).min)


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Args:
x (`nn.Modules`): input nn layers.
drop_prob (`float`): drop path ratio.
training (`bool`): whether is training or inference.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (1, x.shape[1], 1)
random_tensor = keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output


class DropPath(nn.Module):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Args:
drop_prob: drop path ratio.
"""

def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob

def forward(self, x):
return drop_path(x, self.drop_prob, self.training)

def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
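
`drop_path`/`DropPath` moved here from modeling_ofa.py so that both the OFA transformer and the new ViT backbone can import them; a minimal usage sketch of the residual-branch pattern (the linear layer stands in for an attention or MLP sub-layer):

```python
import torch
import torch.nn as nn

from modelscope.models.multi_modal.ofa.utils.utils import DropPath

block = nn.Linear(16, 16)            # stand-in for an attention/MLP sub-layer
drop_path = DropPath(drop_prob=0.1)  # active only in training mode

# The binary mask has shape (1, x.shape[1], 1), i.e. whole samples are dropped
# along dim 1, matching the length-first (L, N, D) layout used in vit.py.
x = torch.randn(8, 4, 16)
y = x + drop_path(block(x))
```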

+155 -0  modelscope/models/multi_modal/ofa/vit.py

@@ -0,0 +1,155 @@
from collections import OrderedDict

import torch
import torch.nn.functional as F
from fairseq.modules import LayerNorm
from torch import nn

from .utils.utils import DropPath

__all__ = [
'vit_base',
'vit_large',
'vit_large_336',
'vit_huge',
]


class QuickGELU(nn.Module):

def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

def __init__(self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None,
drop_path_rate=0.0):
super().__init__()

self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([
('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model)),
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.drop_path = DropPath(drop_path_rate)

def attention(self, x: torch.Tensor):
self.attn_mask = (
self.attn_mask.to(dtype=x.dtype, device=x.device)
if self.attn_mask is not None else None)
return self.attn(
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

def forward(self, x: torch.Tensor):
x = x + self.drop_path(self.attention(self.ln_1(x)))
x = x + self.drop_path(self.mlp(self.ln_2(x)))
return x


class Transformer(nn.Module):

def __init__(
self,
width: int,
layers: int,
heads: int,
attn_mask: torch.Tensor = None,
drop_path_rate: float = 0.0,
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[
ResidualAttentionBlock(width, heads, attn_mask, drop_path_rate)
for _ in range(layers)
])

def forward(self, x: torch.Tensor):
return self.resblocks(x)


class VisionTransformer(nn.Module):

def __init__(
self,
input_resolution: int,
patch_size: int,
width: int,
layers: int,
heads: int,
drop_path_rate: float = 0.0,
):
super().__init__()
self.input_resolution = input_resolution
self.patch_size = patch_size
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)

scale = width**-0.5
self.width = width
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(
width, layers, heads, drop_path_rate=drop_path_rate)

def forward(self, x: torch.Tensor):
resolution = x.shape[-2]
height, width = x.shape[-2] // self.patch_size, x.shape[
-1] // self.patch_size
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]

if resolution != self.input_resolution:
old_pe = self.positional_embedding[1:]
patch_num = self.input_resolution // self.patch_size
old_pe = old_pe.reshape(1, patch_num, patch_num,
-1).permute(0, 3, 1, 2)
new_pe = F.interpolate(
old_pe, size=(height, width), mode='bilinear')
new_pe = new_pe.permute(0, 2, 3, 1).reshape(height * width, -1)
x = x + new_pe.to(x.dtype)
else:
x = x + self.positional_embedding[1:].to(x.dtype)
x = self.ln_pre(x)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD

bz, seq, hidden = x.shape
x = x.transpose(1, 2).reshape(bz, hidden, height, width)

return x


def vit_base(drop_path_rate: float = 0.0):
return VisionTransformer(224, 16, 768, 9, 12, drop_path_rate)


def vit_large(drop_path_rate: float = 0.0):
return VisionTransformer(224, 14, 1024, 18, 16, drop_path_rate)


def vit_large_336(drop_path_rate: float = 0.0):
return VisionTransformer(336, 14, 1024, 18, 16, drop_path_rate)


def vit_huge(drop_path_rate: float = 0.0):
return VisionTransformer(224, 14, 1280, 24, 16, drop_path_rate)
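
A quick shape check of the new backbone, assuming fairseq is installed (vit.py imports LayerNorm from fairseq.modules); weights are randomly initialized here, and OFAEncoder later projects the output with `Linear(self.embed_images.width, embed_dim)`:

```python
import torch

from modelscope.models.multi_modal.ofa.vit import vit_base

model = vit_base(drop_path_rate=0.0).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 768, 14, 14]): width 768, 224/16 = 14 patches per side
```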

+22 -3  modelscope/models/multi_modal/ofa_for_all_tasks.py

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import re
import string
from functools import partial
from os import path as osp
@@ -53,8 +54,11 @@ class OfaForAllTasks(TorchModel):
raise NotImplementedError
# there is some diff between here and our ofa code,
# there will be no need to use param: use_bpe
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
if not model.use_ofasys:
self.tokenizer.add_tokens(
['<code_{}>'.format(i) for i in range(8192)])
self.tokenizer.add_tokens(
['<bin_{}>'.format(i) for i in range(1000)])
self.cfg.update({'num_bins': 1000, 'num_codes': 8192})
self.batch_size = self.cfg.model.get('batch_size', 1)
self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
@@ -107,6 +111,8 @@ class OfaForAllTasks(TorchModel):
Tasks.text_classification: inference_d[self.gen_type],
Tasks.image_classification: inference_d[self.gen_type],
}
pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))'
self.pattern = re.compile(pattern_str)

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
input = move_to_device(input, self.model.device)
@@ -132,8 +138,18 @@ class OfaForAllTasks(TorchModel):
caption = input[OutputKeys.CAPTION]
result_l = list()
for cap in caption:
result_l.append(cap.translate(self.transtab).strip())
if self.language == 'en':
result_l.append(cap.translate(self.transtab).strip())
else:
result_l.append(cap)
input[OutputKeys.CAPTION] = result_l
if self.gen_type == 'generation' and self.language in [
'zh', 'cn'
] and self.cfg.task != Tasks.visual_grounding:
ret_l = list()
for text in input[OFA_TASK_KEY_MAPPING[self.cfg.task]]:
ret_l.append(self.detokenizer(text))
input[OFA_TASK_KEY_MAPPING[self.cfg.task]] = ret_l
return input

def _text_gen_inference(self, input):
@@ -311,3 +327,6 @@ class OfaForAllTasks(TorchModel):
save_function=partial(save_function, with_meta=False),
config=config,
**kwargs)

def detokenizer(self, text):
return self.pattern.sub('', text)
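
The new detokenizer strips the spaces that generation inserts around characters outside `[ a-zA-Z0-9.,:!?]` (e.g. CJK text) while leaving Latin text alone; a small illustration of the regex:

```python
import re

pattern = re.compile('((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))')
print(pattern.sub('', '今天 天气 不错 , 适合 出去 游玩'))  # -> 今天天气不错,适合出去游玩
print(pattern.sub('', 'an image caption text'))             # Latin text is left untouched
```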

+3 -0  modelscope/models/science/unifold/dataset.py

@@ -1,3 +1,6 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import copy
import logging
import os


+3 -0  modelscope/models/science/unifold/model.py

@@ -1,3 +1,6 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import argparse
import os
from typing import Any


+3 -0  modelscope/models/science/unifold/modules/__init__.py

@@ -0,0 +1,3 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.
"""Unifold Modules."""

+2 -0  modelscope/msdatasets/ms_dataset.py

@@ -274,6 +274,8 @@ class MsDataset:
try:
api.on_dataset_download(
dataset_name=download_dataset, namespace=namespace)
api.dataset_download_uv(
dataset_name=download_dataset, namespace=namespace)
except Exception as e:
logger.error(e)



+19 -15  modelscope/outputs/outputs.py

@@ -69,11 +69,23 @@ TASK_OUTPUTS = {
# face 2d keypoint result for single sample
# {
# "keypoints": [
# [x1, y1]*106
# [[x, y]*106],
# [[x, y]*106],
# [[x, y]*106],
# ],
# "poses": [pitch, roll, yaw]
# "poses": [
# [pitch, roll, yaw],
# [pitch, roll, yaw],
# [pitch, roll, yaw],
# ],
# "boxes": [
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# ]
# }
Tasks.face_2d_keypoints: [OutputKeys.KEYPOINTS, OutputKeys.POSES],
Tasks.face_2d_keypoints:
[OutputKeys.KEYPOINTS, OutputKeys.POSES, OutputKeys.BOXES],

# face detection result for single sample
# {
@@ -479,17 +491,8 @@ TASK_OUTPUTS = {
# word segmentation result for single sample
# {
# "output": "今天 天气 不错 , 适合 出去 游玩"
# "labels": [
# {'word': '今天', 'label': 'PROPN'},
# {'word': '天气', 'label': 'PROPN'},
# {'word': '不错', 'label': 'VERB'},
# {'word': ',', 'label': 'NUM'},
# {'word': '适合', 'label': 'NOUN'},
# {'word': '出去', 'label': 'PART'},
# {'word': '游玩', 'label': 'ADV'},
# ]
# }
Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS],
Tasks.word_segmentation: [OutputKeys.OUTPUT],

# TODO @wenmeng.zwm support list of result check
# named entity recognition result for single sample
@@ -699,8 +702,9 @@ TASK_OUTPUTS = {
# "text_embedding": np.array with shape [1, D],
# "caption": "this is an image caption text."
# }
Tasks.generative_multi_modal_embedding:
[OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION],
Tasks.generative_multi_modal_embedding: [
OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION
],

# multi-modal similarity result for single sample
# {


+4 -1  modelscope/pipelines/base.py

@@ -10,6 +10,7 @@ from typing import Any, Dict, Generator, List, Mapping, Union

import numpy as np

from modelscope.hub.utils.utils import create_library_statistics
from modelscope.models.base import Model
from modelscope.msdatasets import MsDataset
from modelscope.outputs import TASK_OUTPUTS
@@ -151,7 +152,9 @@ class Pipeline(ABC):
**kwargs) -> Union[Dict[str, Any], Generator]:
# model provider should leave it as it is
# modelscope library developer will handle this function

for single_model in self.models:
if hasattr(single_model, 'name'):
create_library_statistics('pipeline', single_model.name, None)
# place model to cpu or gpu
if (self.model or (self.has_multiple_models and self.models[0])):
if not self._model_prepare:


+2 -3  modelscope/pipelines/builder.py

@@ -93,9 +93,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_resnet50_live-category'),
Tasks.video_category: (Pipelines.video_category,
'damo/cv_resnet50_video-category'),
Tasks.multi_modal_embedding:
(Pipelines.multi_modal_embedding,
'damo/multi-modal_clip-vit-large-patch14_zh'),
Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding,
'damo/multi-modal_clip-vit-base-patch16_zh'),
Tasks.generative_multi_modal_embedding:
(Pipelines.generative_multi_modal_embedding,
'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'


+3 -3  modelscope/pipelines/cv/body_3d_keypoints_pipeline.py

@@ -132,8 +132,8 @@ class Body3DKeypointsPipeline(Pipeline):
device='gpu' if torch.cuda.is_available() else 'cpu')

def preprocess(self, input: Input) -> Dict[str, Any]:
video_url = input
video_frames = self.read_video_frames(video_url)
self.video_url = input
video_frames = self.read_video_frames(self.video_url)
if 0 == len(video_frames):
res = {'success': False, 'msg': 'get video frame failed.'}
return res
@@ -198,7 +198,7 @@ class Body3DKeypointsPipeline(Pipeline):
}

if not input['success']:
pass
res[OutputKeys.OUTPUT_VIDEO] = self.video_url
else:
poses = input[KeypointsTypes.POSES_CAMERA]
pred_3d_pose = poses.data.cpu().numpy()[


+205 -7  modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py

@@ -1,12 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math
from typing import Any

import cv2
import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .base import EasyCVPipeline

logger = get_logger()


@PIPELINES.register_module(
Tasks.face_2d_keypoints, module_name=Pipelines.face_2d_keypoints)
@@ -29,18 +39,206 @@ class Face2DKeypointsPipeline(EasyCVPipeline):
*args,
**kwargs)

# face detect pipeline
det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps'
self.face_detection = pipeline(
Tasks.face_detection, model=det_model_id)

def show_result(self, img, points, scale=2, save_path=None):
return self.predict_op.show_result(img, points, scale, save_path)

def _choose_face(self, det_result, min_face=10):
"""
choose face with maximum area
Args:
det_result: output of face detection pipeline
min_face: minimum size of valid face w/h
"""
bboxes = np.array(det_result[OutputKeys.BOXES])
landmarks = np.array(det_result[OutputKeys.KEYPOINTS])
if bboxes.shape[0] == 0:
logger.warn('No face detected!')
return None
# face idx with enough size
face_idx = []
for i in range(bboxes.shape[0]):
box = bboxes[i]
if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face:
face_idx += [i]
if len(face_idx) == 0:
logger.warn(
f'Face size not enough, less than {min_face}x{min_face}!')
return None
bboxes = bboxes[face_idx]
landmarks = landmarks[face_idx]

return bboxes, landmarks

def expend_box(self, box, w, h, scalex=0.3, scaley=0.5):
x1 = box[0]
y1 = box[1]
wb = box[2] - x1
hb = box[3] - y1
deltax = int(wb * scalex)
deltay1 = int(hb * scaley)
deltay2 = int(hb * scalex)
x1 = x1 - deltax
y1 = y1 - deltay1
if x1 < 0:
deltax = deltax + x1
x1 = 0
if y1 < 0:
deltay1 = deltay1 + y1
y1 = 0
x2 = x1 + wb + 2 * deltax
y2 = y1 + hb + deltay1 + deltay2
x2 = np.clip(x2, 0, w - 1)
y2 = np.clip(y2, 0, h - 1)
return [x1, y1, x2, y2]

def rotate_point(self, angle, center, landmark):
rad = angle * np.pi / 180.0
alpha = np.cos(rad)
beta = np.sin(rad)
M = np.zeros((2, 3), dtype=np.float32)
M[0, 0] = alpha
M[0, 1] = beta
M[0, 2] = (1 - alpha) * center[0] - beta * center[1]
M[1, 0] = -beta
M[1, 1] = alpha
M[1, 2] = beta * center[0] + (1 - alpha) * center[1]

landmark_ = np.asarray([(M[0, 0] * x + M[0, 1] * y + M[0, 2],
M[1, 0] * x + M[1, 1] * y + M[1, 2])
for (x, y) in landmark])
return M, landmark_

def rotate_crop_img(self, img, pts, M):
imgT = cv2.warpAffine(img, M, (int(img.shape[1]), int(img.shape[0])))

x1 = pts[5][0]
x2 = pts[5][0]
y1 = pts[5][1]
y2 = pts[5][1]
for i in range(0, 9):
x1 = min(x1, pts[i][0])
x2 = max(x2, pts[i][0])
y1 = min(y1, pts[i][1])
y2 = max(y2, pts[i][1])

height, width, _ = imgT.shape
x1 = min(max(0, int(x1)), width)
y1 = min(max(0, int(y1)), height)
x2 = min(max(0, int(x2)), width)
y2 = min(max(0, int(y2)), height)
sub_imgT = imgT[y1:y2, x1:x2]

return sub_imgT, imgT, [x1, y1, x2, y2]

def crop_img(self, imgT, pts):
enlarge_ratio = 1.1

x1 = np.min(pts[:, 0])
x2 = np.max(pts[:, 0])
y1 = np.min(pts[:, 1])
y2 = np.max(pts[:, 1])
w = x2 - x1 + 1
h = y2 - y1 + 1
x1 = int(x1 - (enlarge_ratio - 1.0) / 2.0 * w)
y1 = int(y1 - (enlarge_ratio - 1.0) / 2.0 * h)
x1 = max(0, x1)
y1 = max(0, y1)

new_w = int(enlarge_ratio * w)
new_h = int(enlarge_ratio * h)
new_x1 = x1
new_y1 = y1
new_x2 = new_x1 + new_w
new_y2 = new_y1 + new_h

height, width, _ = imgT.shape

new_x1 = min(max(0, new_x1), width)
new_y1 = min(max(0, new_y1), height)
new_x2 = max(min(width, new_x2), 0)
new_y2 = max(min(height, new_y2), 0)

sub_imgT = imgT[new_y1:new_y2, new_x1:new_x2]

return sub_imgT, [new_x1, new_y1, new_x2, new_y2]

def __call__(self, inputs) -> Any:
outputs = self.predict_op(inputs)
img = LoadImage.convert_to_ndarray(inputs)
h, w, c = img.shape
img_rgb = copy.deepcopy(img)
img_rgb = img_rgb[:, :, ::-1]
det_result = self.face_detection(img_rgb)

bboxes = np.array(det_result[OutputKeys.BOXES])
if bboxes.shape[0] == 0:
logger.warn('No face detected!')
results = {
OutputKeys.KEYPOINTS: [],
OutputKeys.POSES: [],
OutputKeys.BOXES: []
}
return results

boxes, keypoints = self._choose_face(det_result)

output_boxes = []
output_keypoints = []
output_poses = []
for index, box_ori in enumerate(boxes):
box = self.expend_box(box_ori, w, h, scalex=0.1, scaley=0.1)
y0 = int(box[1])
y1 = int(box[3])
x0 = int(box[0])
x1 = int(box[2])
sub_img = img[y0:y1, x0:x1]

keypoint = keypoints[index]
pts = [[keypoint[0], keypoint[1]], [keypoint[2], keypoint[3]],
[keypoint[4], keypoint[5]], [keypoint[6], keypoint[7]],
[keypoint[8], keypoint[9]], [box[0], box[1]],
[box[2], box[1]], [box[0], box[3]], [box[2], box[3]]]
# radian
angle = math.atan2((pts[1][1] - pts[0][1]),
(pts[1][0] - pts[0][0]))
# angle
theta = angle * (180 / np.pi)

center = [w // 2, h // 2]
cx, cy = center
M, landmark_ = self.rotate_point(theta, (cx, cy), pts)
sub_imgT, imgT, bbox = self.rotate_crop_img(img, landmark_, M)

outputs = self.predict_op([sub_imgT])[0]
tmp_keypoints = outputs['point']

for idx in range(0, len(tmp_keypoints)):
tmp_keypoints[idx][0] += bbox[0]
tmp_keypoints[idx][1] += bbox[1]

for idx in range(0, 6):
sub_img, bbox = self.crop_img(imgT, tmp_keypoints)
outputs = self.predict_op([sub_img])[0]
tmp_keypoints = outputs['point']
for idx in range(0, len(tmp_keypoints)):
tmp_keypoints[idx][0] += bbox[0]
tmp_keypoints[idx][1] += bbox[1]

M2, tmp_keypoints = self.rotate_point(-theta, (cx, cy),
tmp_keypoints)

results = [{
OutputKeys.KEYPOINTS: output['point'],
OutputKeys.POSES: output['pose']
} for output in outputs]
output_keypoints.append(np.array(tmp_keypoints))
output_poses.append(np.array(outputs['pose']))
output_boxes.append(np.array(box_ori))

if self._is_single_inputs(inputs):
results = results[0]
results = {
OutputKeys.KEYPOINTS: output_keypoints,
OutputKeys.POSES: output_poses,
OutputKeys.BOXES: output_boxes
}

return results
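
The pipeline now runs face detection first and returns one entry per detected face, matching the updated `TASK_OUTPUTS` entry in outputs.py; a usage sketch, assuming a default model is registered for the task and using a placeholder image path:

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_2d_keypoints = pipeline(Tasks.face_2d_keypoints)
result = face_2d_keypoints('path/to/face_image.jpg')  # placeholder path

# One entry per detected face: 106 [x, y] keypoints, [pitch, roll, yaw], and a box.
print(len(result[OutputKeys.KEYPOINTS]),
      len(result[OutputKeys.POSES]),
      len(result[OutputKeys.BOXES]))
```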

+2 -2  modelscope/pipelines/nlp/token_classification_pipeline.py

@@ -109,13 +109,13 @@ class TokenClassificationPipeline(Pipeline):
chunk['span'] = text[chunk['start']:chunk['end']]
chunks.append(chunk)

# for cws output
# for cws outputs
if len(chunks) > 0 and chunks[0]['type'] == 'cws':
spans = [
chunk['span'] for chunk in chunks if chunk['span'].strip()
]
seg_result = ' '.join(spans)
outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []}
outputs = {OutputKeys.OUTPUT: seg_result}

# for ner outputs
else:


+3 -3  modelscope/pipelines/nlp/word_segmentation_pipeline.py

@@ -115,15 +115,15 @@ class WordSegmentationPipeline(Pipeline):
chunk['span'] = text[chunk['start']:chunk['end']]
chunks.append(chunk)

# for cws output
# for cws outputs
if len(chunks) > 0 and chunks[0]['type'] == 'cws':
spans = [
chunk['span'] for chunk in chunks if chunk['span'].strip()
]
seg_result = ' '.join(spans)
outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []}
outputs = {OutputKeys.OUTPUT: seg_result}

# for ner outpus
# for ner output
else:
outputs = {OutputKeys.OUTPUT: chunks}
return outputs

+1 -1  modelscope/preprocessors/multi_modal.py

@@ -77,7 +77,7 @@ class OfaPreprocessor(Preprocessor):
data[key] = item
return data

def _ofa_input_compatibility_conversion(self, data):
def _ofa_input_compatibility_conversion(self, data): # fake
if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
if isinstance(data['image'], str):
image = load_image(data['image'])


+12 -5  modelscope/preprocessors/nlp/nlp_base.py

@@ -34,6 +34,7 @@ class NLPBasePreprocessor(Preprocessor, ABC):
label=None,
label2id=None,
mode=ModeKeys.INFERENCE,
use_fast=None,
**kwargs):
"""The NLP preprocessor base class.

@@ -45,14 +46,18 @@ class NLPBasePreprocessor(Preprocessor, ABC):
label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping
if this mapping is not supplied.
mode: Run this preprocessor in either 'train'/'eval'/'inference' mode
use_fast: use the fast version of tokenizer

"""
self.model_dir = model_dir
self.first_sequence = first_sequence
self.second_sequence = second_sequence
self.label = label

self.use_fast = kwargs.pop('use_fast', None)
if self.use_fast is None and os.path.isfile(
self.use_fast = use_fast
if self.use_fast is None and model_dir is None:
self.use_fast = False
elif self.use_fast is None and os.path.isfile(
os.path.join(model_dir, 'tokenizer_config.json')):
with open(os.path.join(model_dir, 'tokenizer_config.json'),
'r') as f:
@@ -61,8 +66,8 @@ class NLPBasePreprocessor(Preprocessor, ABC):
self.use_fast = False if self.use_fast is None else self.use_fast

self.label2id = label2id
if self.label2id is None:
self.label2id = parse_label_mapping(self.model_dir)
if self.label2id is None and model_dir is not None:
self.label2id = parse_label_mapping(model_dir)
super().__init__(mode, **kwargs)

@property
@@ -106,6 +111,7 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor):
label: str = 'label',
label2id: dict = None,
mode: str = ModeKeys.INFERENCE,
use_fast: bool = None,
**kwargs):
"""The NLP tokenizer preprocessor base class.

@@ -122,11 +128,12 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor):
- config.json label2id/id2label
- label_mapping.json
mode: Run this preprocessor in either 'train'/'eval'/'inference' mode, the behavior may be different.
use_fast: use the fast version of tokenizer
kwargs: These kwargs will be directly fed into the tokenizer.
"""

super().__init__(model_dir, first_sequence, second_sequence, label,
label2id, mode)
label2id, mode, use_fast, **kwargs)
self.model_dir = model_dir
self.tokenize_kwargs = kwargs
self.tokenizer = self.build_tokenizer(model_dir)
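
A standalone restatement of the `use_fast` resolution order implemented above (explicit argument, then tokenizer_config.json when a model_dir exists, then False); the config key read from tokenizer_config.json is assumed here since the hunk truncates before that line:

```python
import json
import os


def resolve_use_fast(use_fast, model_dir):
    # explicit argument wins
    if use_fast is not None:
        return use_fast
    # no model directory: default to the slow tokenizer
    if model_dir is None:
        return False
    cfg_file = os.path.join(model_dir, 'tokenizer_config.json')
    if os.path.isfile(cfg_file):
        with open(cfg_file, 'r') as f:
            use_fast = json.load(f).get('use_fast')  # assumed key
    return use_fast if use_fast is not None else False
```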


+82 -69  modelscope/preprocessors/nlp/token_classification_preprocessor.py

@@ -2,6 +2,7 @@

from typing import Any, Dict, Tuple, Union

import numpy as np
import torch

from modelscope.metainfo import Preprocessors
@@ -20,9 +21,7 @@ class WordSegmentationBlankSetToLabelPreprocessor(NLPBasePreprocessor):
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.first_sequence: str = kwargs.pop('first_sequence', 'tokens')
self.label = kwargs.pop('label', OutputKeys.LABELS)

def __call__(self, data: str) -> Union[Dict[str, Any], Tuple]:
@@ -80,10 +79,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
'is_split_into_words', False)
if 'label2id' in kwargs:
kwargs.pop('label2id')
self.tokenize_kwargs = kwargs

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
@type_assert(object, (str, dict))
def __call__(self, data: Union[dict, str]) -> Dict[str, Any]:
"""process the raw input data

Args:
@@ -99,18 +97,24 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
text = None
labels_list = None
if isinstance(data, str):
# for inference inputs without label
text = data
self.tokenize_kwargs['add_special_tokens'] = False
elif isinstance(data, dict):
# for finetune inputs with label
text = data.get(self.first_sequence)
labels_list = data.get(self.label)
if isinstance(text, list):
self.tokenize_kwargs['is_split_into_words'] = True

input_ids = []
label_mask = []
offset_mapping = []
if self.is_split_into_words:
for offset, token in enumerate(list(data)):
subtoken_ids = self.tokenizer.encode(
token, add_special_tokens=False)
token_type_ids = []
if self.is_split_into_words and self._mode == ModeKeys.INFERENCE:
for offset, token in enumerate(list(text)):
subtoken_ids = self.tokenizer.encode(token,
**self.tokenize_kwargs)
if len(subtoken_ids) == 0:
subtoken_ids = [self.tokenizer.unk_token_id]
input_ids.extend(subtoken_ids)
@@ -119,10 +123,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
else:
if self.tokenizer.is_fast:
encodings = self.tokenizer(
text,
add_special_tokens=False,
return_offsets_mapping=True,
**self.tokenize_kwargs)
text, return_offsets_mapping=True, **self.tokenize_kwargs)
attention_mask = encodings['attention_mask']
token_type_ids = encodings['token_type_ids']
input_ids = encodings['input_ids']
word_ids = encodings.word_ids()
for i in range(len(word_ids)):
@@ -137,75 +140,85 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
label_mask.append(1)
offset_mapping.append(encodings['offset_mapping'][i])
else:
encodings = self.tokenizer(
text, add_special_tokens=False, **self.tokenize_kwargs)
encodings = self.tokenizer(text, **self.tokenize_kwargs)
input_ids = encodings['input_ids']
label_mask, offset_mapping = self.get_label_mask_and_offset_mapping(
text)

if len(input_ids) >= self.sequence_length - 2:
input_ids = input_ids[:self.sequence_length - 2]
label_mask = label_mask[:self.sequence_length - 2]
input_ids = [self.tokenizer.cls_token_id
] + input_ids + [self.tokenizer.sep_token_id]
label_mask = [0] + label_mask + [0]
attention_mask = [1] * len(input_ids)
offset_mapping = offset_mapping[:sum(label_mask)]
if self._mode == ModeKeys.INFERENCE:
if len(input_ids) >= self.sequence_length - 2:
input_ids = input_ids[:self.sequence_length - 2]
label_mask = label_mask[:self.sequence_length - 2]
input_ids = [self.tokenizer.cls_token_id
] + input_ids + [self.tokenizer.sep_token_id]
label_mask = [0] + label_mask + [0]
attention_mask = [1] * len(input_ids)
offset_mapping = offset_mapping[:sum(label_mask)]

if not self.is_transformer_based_model:
input_ids = input_ids[1:-1]
attention_mask = attention_mask[1:-1]
label_mask = label_mask[1:-1]
if not self.is_transformer_based_model:
input_ids = input_ids[1:-1]
attention_mask = attention_mask[1:-1]
label_mask = label_mask[1:-1]

if self._mode == ModeKeys.INFERENCE:
input_ids = torch.tensor(input_ids).unsqueeze(0)
attention_mask = torch.tensor(attention_mask).unsqueeze(0)
label_mask = torch.tensor(
label_mask, dtype=torch.bool).unsqueeze(0)

# outputs for token classification
output = {
'text': text,
'input_ids': input_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
'offset_mapping': offset_mapping
}

# align the labels with tokenized text
if labels_list is not None:
assert self.label2id is not None
# Map that sends B-Xxx label to its I-Xxx counterpart
b_to_i_label = []
label_enumerate_values = [
k for k, v in sorted(
self.label2id.items(), key=lambda item: item[1])
]
for idx, label in enumerate(label_enumerate_values):
if label.startswith('B-') and label.replace(
'B-', 'I-') in label_enumerate_values:
b_to_i_label.append(
label_enumerate_values.index(
label.replace('B-', 'I-')))
else:
b_to_i_label.append(idx)
# outputs for token classification
output = {
'text': text,
'input_ids': input_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
'offset_mapping': offset_mapping
}
else:
output = {
'input_ids': input_ids,
'token_type_ids': token_type_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
}

label_row = [self.label2id[lb] for lb in labels_list]
previous_word_idx = None
label_ids = []
for word_idx in word_ids:
if word_idx is None:
label_ids.append(-100)
elif word_idx != previous_word_idx:
label_ids.append(label_row[word_idx])
else:
if self.label_all_tokens:
label_ids.append(b_to_i_label[label_row[word_idx]])
# align the labels with tokenized text
if labels_list is not None:
assert self.label2id is not None
# Map that sends B-Xxx label to its I-Xxx counterpart
b_to_i_label = []
label_enumerate_values = [
k for k, v in sorted(
self.label2id.items(), key=lambda item: item[1])
]
for idx, label in enumerate(label_enumerate_values):
if label.startswith('B-') and label.replace(
'B-', 'I-') in label_enumerate_values:
b_to_i_label.append(
label_enumerate_values.index(
label.replace('B-', 'I-')))
else:
b_to_i_label.append(idx)

label_row = [self.label2id[lb] for lb in labels_list]
previous_word_idx = None
label_ids = []
for word_idx in word_ids:
if word_idx is None:
label_ids.append(-100)
previous_word_idx = word_idx
labels = label_ids
output['labels'] = labels
elif word_idx != previous_word_idx:
label_ids.append(label_row[word_idx])
else:
if self.label_all_tokens:
label_ids.append(b_to_i_label[label_row[word_idx]])
else:
label_ids.append(-100)
previous_word_idx = word_idx
labels = label_ids
output['labels'] = labels
output = {
k: np.array(v) if isinstance(v, list) else v
for k, v in output.items()
}
return output

def get_tokenizer_class(self):
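Taken together, the rewritten __call__ above distinguishes inference (a plain string) from finetuning (a dict with pre-split tokens and labels). A hedged usage sketch, assuming the default 'tokens'/'labels' field names used elsewhere in this commit; the model directory is a placeholder and the construction details depend on the model's configuration:

from modelscope.preprocessors import Preprocessor

infer_pp = Preprocessor.from_pretrained('/path/to/model_dir', mode='inference')
train_pp = Preprocessor.from_pretrained('/path/to/model_dir', mode='train', use_fast=True)

# inference: a raw string yields batched tensors (input_ids, attention_mask, label_mask)
features = infer_pp('今天天气不错')

# finetuning: a dict with pre-split tokens and labels yields numpy arrays, including 'labels'
features = train_pp({
    'tokens': ['EU', 'rejects', 'German', 'call'],
    'labels': ['B-ORG', 'O', 'B-MISC', 'O'],
})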


+ 2
- 2
modelscope/preprocessors/ofa/ocr_recognition.py View File

@@ -74,8 +74,8 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
self.patch_resize_transform = transforms.Compose([
lambda image: ocr_resize(
image,
self.cfg.model.patch_image_size,
is_document=self.cfg.model.is_document),
self.patch_image_size,
is_document=self.cfg.model.get('is_document', False)),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])


+ 5
- 2
modelscope/trainers/audio/kws_farfield_trainer.py View File

@@ -69,11 +69,14 @@ class KWSFarfieldTrainer(BaseTrainer):

super().__init__(cfg_file, arg_parse_fn)

self.model = self.build_model()
self.work_dir = work_dir
# the number of model output dimensions;
# the config should be updated outside the trainer if the user needs more wake words
num_syn = kwargs.get('num_syn', None)
if num_syn:
self.cfg.model.num_syn = num_syn
self._num_classes = self.cfg.model.num_syn
self.model = self.build_model()
self.work_dir = work_dir

if kwargs.get('launcher', None) is not None:
init_dist(kwargs['launcher'])
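The reordering above ensures num_syn is written into the config before build_model() runs, so a wake-word count passed through trainer kwargs actually changes the output dimension. A hedged sketch of that call path (the trainer name and model id below are placeholders, not taken from this commit):

from modelscope.trainers import build_trainer

kwargs = dict(
    model='damo/speech_kws_placeholder_model',  # hypothetical model id
    work_dir='./kws_work_dir',
    num_syn=4)  # written into cfg.model.num_syn before the model is built
trainer = build_trainer('kws-farfield-trainer', default_args=kwargs)  # hypothetical trainer name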


+ 11
- 11
modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py View File

@@ -103,20 +103,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):

def __init__(self, args):
super().__init__()
self.sentence_avg = args.sentence_avg
self.eps = args.label_smoothing
self.ignore_prefix_size = args.ignore_prefix_size
self.ignore_eos = args.ignore_eos
self.report_accuracy = args.report_accuracy
self.drop_worst_ratio = args.drop_worst_ratio
self.drop_worst_after = args.drop_worst_after
self.use_rdrop = args.use_rdrop
self.reg_alpha = args.reg_alpha
self.sample_patch_num = args.sample_patch_num
self.sentence_avg = args.get('sentence_avg', False)
self.eps = args.get('label_smoothing', 0.1)
self.ignore_prefix_size = args.get('ignore_prefix_size', 0)
self.ignore_eos = args.get('ignore_eos', False)
self.report_accuracy = args.get('report_accuracy', False)
self.drop_worst_ratio = args.get('drop_worst_ratio', 0.0)
self.drop_worst_after = args.get('drop_worst_after', 0)
self.use_rdrop = args.get('use_rdrop', False)
self.reg_alpha = args.get('reg_alpha', 1.0)
self.sample_patch_num = args.get('sample_patch_num', 196)

self.constraint_start = None
self.constraint_end = None
if args.constraint_range:
if args.get('constraint_range', None):
constraint_start, constraint_end = args.constraint_range.split(',')
self.constraint_start = int(constraint_start)
self.constraint_end = int(constraint_end)
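Switching every field to args.get(...) above makes the criterion tolerant of sparse configs; a minimal sketch of constructing it with only the fields you care about (assuming the modelscope Config wrapper, which supports both attribute access and .get):

from modelscope.utils.config import Config
from modelscope.trainers.multi_modal.ofa.ofa_trainer_utils import \
    AdjustLabelSmoothedCrossEntropyCriterion

args = Config(dict(label_smoothing=0.1, report_accuracy=True))
criterion = AdjustLabelSmoothedCrossEntropyCriterion(args)
# every omitted field falls back to its .get(...) default, e.g. drop_worst_ratio -> 0.0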


+ 1
- 1
modelscope/trainers/nlp/text_generation_trainer.py View File

@@ -18,7 +18,7 @@ class TextGenerationTrainer(NlpEpochBasedTrainer):
return tokenizer.decode(tokens.tolist(), skip_special_tokens=True)

def evaluation_step(self, data):
model = self.model
model = self.model.module if self._dist else self.model
model.eval()

with torch.no_grad():


+ 4
- 2
modelscope/trainers/nlp_trainer.py View File

@@ -586,14 +586,16 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
preprocessor_mode=ModeKeys.TRAIN,
**model_args,
**self.train_keys,
mode=ModeKeys.TRAIN)
mode=ModeKeys.TRAIN,
use_fast=True)
eval_preprocessor = Preprocessor.from_pretrained(
self.model_dir,
cfg_dict=self.cfg,
preprocessor_mode=ModeKeys.EVAL,
**model_args,
**self.eval_keys,
mode=ModeKeys.EVAL)
mode=ModeKeys.EVAL,
use_fast=True)
return train_preprocessor, eval_preprocessor




+ 6
- 1
modelscope/trainers/trainer.py View File

@@ -15,6 +15,7 @@ from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.hub.utils.utils import create_library_statistics
from modelscope.metainfo import Trainers
from modelscope.metrics import build_metric, task_default_metrics
from modelscope.models.base import Model, TorchModel
@@ -436,6 +437,8 @@ class EpochBasedTrainer(BaseTrainer):

def train(self, checkpoint_path=None, *args, **kwargs):
self._mode = ModeKeys.TRAIN
if hasattr(self.model, 'name'):
create_library_statistics('train', self.model.name, None)

if self.train_dataset is None:
self.train_dataloader = self.get_train_dataloader()
@@ -456,6 +459,8 @@ class EpochBasedTrainer(BaseTrainer):
self.train_loop(self.train_dataloader)

def evaluate(self, checkpoint_path=None):
if hasattr(self.model, 'name'):
create_library_statistics('evaluate', self.model.name, None)
if checkpoint_path is not None and os.path.isfile(checkpoint_path):
from modelscope.trainers.hooks import CheckpointHook
CheckpointHook.load_checkpoint(checkpoint_path, self)
@@ -876,7 +881,7 @@ class EpochBasedTrainer(BaseTrainer):
Subclass and override to inject custom behavior.

"""
model = self.model
model = self.model.module if self._dist else self.model
model.eval()

if is_parallel(model):
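Both evaluation hunks in this commit (here and in text_generation_trainer.py above) unwrap the distributed wrapper before switching to eval mode; the generic PyTorch pattern they rely on, as a standalone sketch:

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

def unwrap_model(model: nn.Module) -> nn.Module:
    # DistributedDataParallel keeps the user's model on .module;
    # plain (non-distributed) models are returned unchanged.
    return model.module if isinstance(model, DistributedDataParallel) else model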


+ 8
- 0
modelscope/utils/constant.py View File

@@ -238,6 +238,14 @@ class DownloadMode(enum.Enum):
FORCE_REDOWNLOAD = 'force_redownload'


class DownloadChannel(enum.Enum):
""" Channels of datasets downloading for uv/pv counting.
"""
LOCAL = 'local'
DSW = 'dsw'
EAIS = 'eais'


class UploadMode(enum.Enum):
""" How to upload object to remote.
"""


+ 65
- 0
modelscope/utils/cv/image_utils.py View File

@@ -91,6 +91,71 @@ def draw_keypoints(output, original_image):
return image


def draw_106face_keypoints(in_path,
keypoints,
boxes,
scale=4.0,
save_path=None):
face_contour_point_index = [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32
]
left_eye_brow_point_index = [33, 34, 35, 36, 37, 38, 39, 40, 41, 33]
right_eye_brow_point_index = [42, 43, 44, 45, 46, 47, 48, 49, 50, 42]
left_eye_point_index = [66, 67, 68, 69, 70, 71, 72, 73, 66]
right_eye_point_index = [75, 76, 77, 78, 79, 80, 81, 82, 75]
nose_bridge_point_index = [51, 52, 53, 54]
nose_contour_point_index = [55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
mouth_outer_point_index = [
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 84
]
mouth_inter_point_index = [96, 97, 98, 99, 100, 101, 102, 103, 96]

img = cv2.imread(in_path)

for i in range(len(boxes)):
draw_box(img, np.array(boxes[i]))

image = cv2.resize(img, dsize=None, fx=scale, fy=scale)

def draw_line(point_index, image, point):
for i in range(len(point_index) - 1):
cur_index = point_index[i]
next_index = point_index[i + 1]
cur_pt = (int(point[cur_index][0] * scale),
int(point[cur_index][1] * scale))
next_pt = (int(point[next_index][0] * scale),
int(point[next_index][1] * scale))
cv2.line(image, cur_pt, next_pt, (0, 0, 255), thickness=2)

for i in range(len(keypoints)):
points = keypoints[i]

draw_line(face_contour_point_index, image, points)
draw_line(left_eye_brow_point_index, image, points)
draw_line(right_eye_brow_point_index, image, points)
draw_line(left_eye_point_index, image, points)
draw_line(right_eye_point_index, image, points)
draw_line(nose_bridge_point_index, image, points)
draw_line(nose_contour_point_index, image, points)
draw_line(mouth_outer_point_index, image, points)
draw_line(mouth_inter_point_index, image, points)

size = len(points)
for i in range(size):
x = int(points[i][0])
y = int(points[i][1])
cv2.putText(image, str(i), (int(x * scale), int(y * scale)),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
cv2.circle(image, (int(x * scale), int(y * scale)), 2, (0, 255, 0),
cv2.FILLED)

if save_path is not None:
cv2.imwrite(save_path, image)

return image


def draw_face_detection_no_lm_result(img_path, detection_result):
bboxes = np.array(detection_result[OutputKeys.BOXES])
scores = np.array(detection_result[OutputKeys.SCORES])
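A minimal call sketch for the new helper; the input path and arrays below are placeholders, standing in for the keypoints and boxes returned by the face-2d-keypoints pipeline exactly as exercised in tests/pipelines/test_face_2d_keypoints.py further down:

import numpy as np

from modelscope.utils.cv.image_utils import draw_106face_keypoints

keypoints = [np.zeros((106, 2))]          # one (106, 2) landmark array per detected face (placeholder)
boxes = [[0, 0, 100, 100]]                # one 4-element box per detected face (placeholder)
image = draw_106face_keypoints(
    'face.jpg',                           # placeholder input image path
    keypoints,
    boxes,
    scale=2,
    save_path='face_keypoints.jpg')       # the annotated image is also returned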


+ 2
- 1
requirements/framework.txt View File

@@ -1,6 +1,7 @@
addict
attrs
datasets
# versions beyond 2.5.2 introduce a compatibility issue that is being resolved
datasets<=2.5.2
easydict
einops
filelock>=3.3.0


+ 2
- 0
requirements/multi-modal.txt View File

@@ -2,6 +2,8 @@ ftfy>=6.0.3
ofa>=0.0.2
pycocoevalcap>=1.2
pycocotools>=2.0.4
# pinned for compatibility with taming-transformers-rom1504
pytorch_lightning<=1.7.7
# rouge-score was recently updated from 0.0.4 to 0.0.7,
# which introduced compatibility issues that are being investigated
rouge_score<=0.0.4


+ 2
- 0
requirements/science.txt View File

@@ -1,4 +1,6 @@
biopython
iopath
ipdb
lmdb
ml_collections
scipy


+ 6
- 2
tests/msdatasets/test_dataset_upload.py View File

@@ -8,7 +8,8 @@ import zipfile
from modelscope.msdatasets import MsDataset
from modelscope.msdatasets.utils.dataset_utils import list_dataset_objects
from modelscope.utils import logger as logging
from modelscope.utils.constant import DEFAULT_DATASET_REVISION, ModelFile
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DownloadMode,
ModelFile)
from modelscope.utils.test_utils import test_level

logger = logging.get_logger(__name__)
@@ -104,7 +105,10 @@ class DatasetUploadTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ds_download_dir(self):
test_ds = MsDataset.load(self.dataset_name, self.namespace)
test_ds = MsDataset.load(
self.dataset_name,
namespace=self.namespace,
download_mode=DownloadMode.FORCE_REDOWNLOAD)
assert test_ds.config_kwargs['split_config'].values()

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')


+ 2
- 1
tests/outputs/test_model_outputs.py View File

@@ -21,9 +21,10 @@ class TestModelOutput(unittest.TestCase):
self.assertEqual(outputs['logits'], torch.Tensor([1]))
self.assertEqual(outputs[0], torch.Tensor([1]))
self.assertEqual(outputs.logits, torch.Tensor([1]))
outputs.loss = torch.Tensor([2])
logits, loss = outputs
self.assertEqual(logits, torch.Tensor([1]))
self.assertTrue(loss is None)
self.assertTrue(loss is not None)


if __name__ == '__main__':


+ 17
- 12
tests/pipelines/test_face_2d_keypoints.py View File

@@ -1,11 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_106face_keypoints
from modelscope.utils.test_utils import test_level


@@ -13,7 +12,7 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_face_2d_keypoints(self):
img_path = 'data/test/images/keypoints_detect/test_img_face_2d_keypoints.png'
img_path = 'data/test/images/face_detection.png'
model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment'

face_2d_keypoints_align = pipeline(
@@ -21,15 +20,21 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase):
output = face_2d_keypoints_align(img_path)

output_keypoints = output[OutputKeys.KEYPOINTS]
output_pose = output[OutputKeys.POSES]

img = cv2.imread(img_path)
img = face_2d_keypoints_align.show_result(
img, output_keypoints, scale=2, save_path='face_keypoints.jpg')

self.assertEqual(output_keypoints.shape[0], 106)
self.assertEqual(output_keypoints.shape[1], 2)
self.assertEqual(output_pose.shape[0], 3)
output_poses = output[OutputKeys.POSES]
output_boxes = output[OutputKeys.BOXES]

draw_106face_keypoints(
img_path,
output_keypoints,
output_boxes,
scale=2,
save_path='face_keypoints.jpg')

for idx in range(len(output_keypoints)):
self.assertEqual(output_keypoints[idx].shape[0], 106)
self.assertEqual(output_keypoints[idx].shape[1], 2)
self.assertEqual(output_poses[idx].shape[0], 3)
self.assertEqual(output_boxes[idx].shape[0], 4)


if __name__ == '__main__':


+ 8
- 0
tests/pipelines/test_named_entity_recognition.py View File

@@ -19,9 +19,11 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
self.task = Tasks.named_entity_recognition
self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'

english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom'
tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
sentence = '这与温岭市新河镇的一个神秘的传说有关。'
sentence_en = 'pizza shovel'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_tcrf_by_direct_model_download(self):
@@ -89,6 +91,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
task=Tasks.named_entity_recognition, model=self.lcrf_model_id)
print(pipeline_ins(input=self.sentence))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_english_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.named_entity_recognition, model=self.english_model_id)
print(pipeline_ins(input='pizza shovel'))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.named_entity_recognition)


+ 1
- 1
tests/pipelines/test_unifold.py View File

@@ -19,7 +19,7 @@ class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck):
self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \
'NIAALKNHIDKIKPIAMQIYKKYSKNIP'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_by_direct_model_download(self):
model_dir = snapshot_download(self.model_id)
mono_pipeline_ins = pipeline(task=self.task, model=model_dir)


+ 1
- 1
tests/trainers/test_finetune_token_classificatin.py View File

@@ -87,7 +87,7 @@ class TestFinetuneTokenClassification(unittest.TestCase):
cfg['dataset'] = {
'train': {
'labels': label_enumerate_values,
'first_sequence': 'first_sequence',
'first_sequence': 'tokens',
'label': 'labels',
}
}

