
Merge pull request #9 from modelscope/merge_master_internal_1102

Post-test hang observed, but all tests have passed.
Branch: master
Yingda Chen (via GitHub), 2 years ago
parent commit 42011e48d3
66 changed files with 2050 additions and 740 deletions
  1. README.md (+3, -3)
  2. modelscope/exporters/torch_model_exporter.py (+77, -8)
  3. modelscope/hub/api.py (+36, -7)
  4. modelscope/hub/constants.py (+1, -0)
  5. modelscope/hub/utils/utils.py (+15, -0)
  6. modelscope/metainfo.py (+6, -0)
  7. modelscope/metrics/builder.py (+1, -0)
  8. modelscope/metrics/inbatch_recall_metric.py (+55, -0)
  9. modelscope/models/base/base_model.py (+9, -0)
  10. modelscope/models/multi_modal/clip/model.py (+38, -124)
  11. modelscope/models/multi_modal/ofa/configuration_ofa.py (+13, -0)
  12. modelscope/models/multi_modal/ofa/modeling_ofa.py (+203, -141)
  13. modelscope/models/multi_modal/ofa/utils/utils.py (+40, -0)
  14. modelscope/models/multi_modal/ofa/vit.py (+155, -0)
  15. modelscope/models/multi_modal/ofa_for_all_tasks.py (+22, -3)
  16. modelscope/models/nlp/bert/text_ranking.py (+1, -0)
  17. modelscope/models/nlp/structbert/text_classification.py (+1, -0)
  18. modelscope/models/science/unifold/dataset.py (+3, -0)
  19. modelscope/models/science/unifold/model.py (+3, -0)
  20. modelscope/models/science/unifold/modules/__init__.py (+3, -0)
  21. modelscope/msdatasets/ms_dataset.py (+2, -0)
  22. modelscope/outputs/outputs.py (+19, -15)
  23. modelscope/pipelines/base.py (+4, -1)
  24. modelscope/pipelines/builder.py (+2, -3)
  25. modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py (+205, -7)
  26. modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py (+15, -7)
  27. modelscope/pipelines/nlp/token_classification_pipeline.py (+2, -2)
  28. modelscope/pipelines/nlp/word_segmentation_pipeline.py (+3, -3)
  29. modelscope/preprocessors/multi_modal.py (+178, -2)
  30. modelscope/preprocessors/nlp/nlp_base.py (+12, -5)
  31. modelscope/preprocessors/nlp/token_classification_preprocessor.py (+82, -69)
  32. modelscope/preprocessors/ofa/ocr_recognition.py (+9, -15)
  33. modelscope/trainers/hooks/clip_clamp_logit_scale_hook.py (+18, -0)
  34. modelscope/trainers/hooks/logger/text_logger_hook.py (+1, -1)
  35. modelscope/trainers/multi_modal/clip/clip_trainer.py (+191, -154)
  36. modelscope/trainers/multi_modal/clip/clip_trainer_utils.py (+121, -90)
  37. modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py (+11, -11)
  38. modelscope/trainers/nlp/text_generation_trainer.py (+1, -1)
  39. modelscope/trainers/nlp_trainer.py (+4, -2)
  40. modelscope/trainers/trainer.py (+9, -4)
  41. modelscope/utils/constant.py (+8, -0)
  42. modelscope/utils/cv/image_utils.py (+65, -0)
  43. modelscope/utils/demo_utils.py (+1, -16)
  44. modelscope/utils/regress_test_utils.py (+2, -13)
  45. modelscope/utils/service_utils.py (+179, -0)
  46. modelscope/version.py (+1, -1)
  47. requirements/framework.txt (+2, -1)
  48. requirements/multi-modal.txt (+4, -0)
  49. requirements/nlp.txt (+0, -1)
  50. requirements/science.txt (+2, -0)
  51. tests/export/test_export_sbert_sequence_classification.py (+1, -1)
  52. tests/hub/test_hub_revision_release_mode.py (+82, -2)
  53. tests/msdatasets/test_dataset_upload.py (+6, -2)
  54. tests/outputs/test_model_outputs.py (+2, -1)
  55. tests/pipelines/test_face_2d_keypoints.py (+17, -12)
  56. tests/pipelines/test_face_emotion.py (+2, -2)
  57. tests/pipelines/test_facial_expression_recognition.py (+1, -1)
  58. tests/pipelines/test_multi_modal_embedding.py (+3, -3)
  59. tests/pipelines/test_named_entity_recognition.py (+8, -0)
  60. tests/pipelines/test_unifold.py (+1, -1)
  61. tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py (+2, -1)
  62. tests/trainers/test_clip_trainer.py (+83, -0)
  63. tests/trainers/test_finetune_sequence_classification.py (+1, -1)
  64. tests/trainers/test_finetune_token_classificatin.py (+1, -1)
  65. tests/trainers/test_ofa_trainer.py (+1, -1)
  66. tests/trainers/test_trainer_with_nlp.py (+1, -1)

README.md (+3, -3)

@@ -1,10 +1,10 @@
# Introduction

ModelScope library is targeted to support training, evaluation and inference for the state of the art models provided by Mind and further support third-party models provided by users outside alibaba.
[ModelScope](https://www.modelscope.cn) is a “Model-as-a-Service” (MaaS) platform that seeks to bring together the most advanced machine learning models from the AI community, and to streamline the process of leveraging and applying AI models. The core ModelScope library enables developers to perform model inference, training and evaluation, through rich layers of API designs that facilitate a unified experience across state-of-the-art models from different AI domains.

In order to enable ModelScope users to use the various models provided by ModelScope quickly and conveniently, we provide a set of complete Python library, which includes the implementation of ModelScope official models, inference, finetuning and evaluation support for those models such as preprocessor and evaluation metrics. We also provide easy-to-use APIs and rich usage examples. By calling the library, users can write just a few lines of code to complete tasks such as model inference, training, and evaluation, and can also quickly carry out secondary development on this basis to realize their own innovative ideas.
The Python library offers the layered APIs necessary for model contributors to integrate models from CV, NLP, Speech, Multi-Modality, as well as Scientific computation, into the ModelScope ecosystem. Implementations for all these different models are encapsulated within the library in a way that allows easy and unified access. With such integration, model inference, finetuning, and evaluation can be done within only a few lines of code. In the meantime, flexibility is provided so that different components in the model applications can be customized where necessary.

At present, the algorithm models provided by library cover four main AI fields of image, natural language processing, speech, and multi-modality, and dozens of application scenarios and tasks.
Apart from harboring implementations of various models, the ModelScope library also enables the necessary interactions with the backend services of ModelScope, particularly with the Model-Hub and Dataset-Hub. Such interactions allow the management of various entities (models and datasets) to be performed seamlessly under the hood, including entity lookup, version control, and cache management.

# Installation



modelscope/exporters/torch_model_exporter.py (+77, -8)

@@ -7,9 +7,9 @@ from typing import Any, Dict, Mapping
import torch
from torch import nn
from torch.onnx import export as onnx_export
from torch.onnx.utils import _decide_input_format

from modelscope.models import TorchModel
from modelscope.outputs import ModelOutputBase
from modelscope.pipelines.base import collate_fn
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
@@ -102,6 +102,53 @@ class TorchModelExporter(Exporter):
"""
return None

@staticmethod
def _decide_input_format(model, args):
import inspect

def _signature(model) -> inspect.Signature:
should_be_callable = getattr(model, 'forward', model)
if callable(should_be_callable):
return inspect.signature(should_be_callable)
raise ValueError('model has no forward method and is not callable')

try:
sig = _signature(model)
except ValueError as e:
logger.warn('%s, skipping _decide_input_format' % e)
return args
try:
ordered_list_keys = list(sig.parameters.keys())
if ordered_list_keys[0] == 'self':
ordered_list_keys = ordered_list_keys[1:]
args_dict: Dict = {}
if isinstance(args, list):
args_list = args
elif isinstance(args, tuple):
args_list = list(args)
else:
args_list = [args]
if isinstance(args_list[-1], Mapping):
args_dict = args_list[-1]
args_list = args_list[:-1]
n_nonkeyword = len(args_list)
for optional_arg in ordered_list_keys[n_nonkeyword:]:
if optional_arg in args_dict:
args_list.append(args_dict[optional_arg])
# Check if this arg has a default value
else:
param = sig.parameters[optional_arg]
if param.default != param.empty:
args_list.append(param.default)
args = args_list if isinstance(args, list) else tuple(args_list)
# Cases of models with no input args
except IndexError:
logger.warn('No input args, skipping _decide_input_format')
except Exception as e:
logger.warn('Skipping _decide_input_format\n {}'.format(e.args[0]))

return args

def _torch_export_onnx(self,
model: nn.Module,
output: str,
@@ -179,16 +226,21 @@ class TorchModelExporter(Exporter):
with torch.no_grad():
model.eval()
outputs_origin = model.forward(
*_decide_input_format(model, dummy_inputs))
if isinstance(outputs_origin, Mapping):
outputs_origin = numpify_tensor_nested(
list(outputs_origin.values()))
*self._decide_input_format(model, dummy_inputs))
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
outputs_origin = list(
numpify_tensor_nested(outputs_origin).values())
elif isinstance(outputs_origin, (tuple, list)):
outputs_origin = numpify_tensor_nested(outputs_origin)
outputs_origin = list(numpify_tensor_nested(outputs_origin))
outputs = ort_session.run(
onnx_outputs,
numpify_tensor_nested(dummy_inputs),
)
outputs = numpify_tensor_nested(outputs)
if isinstance(outputs, dict):
outputs = list(outputs.values())
elif isinstance(outputs, tuple):
outputs = list(outputs)

tols = {}
if rtol is not None:
@@ -232,12 +284,25 @@ class TorchModelExporter(Exporter):
'Model property dummy_inputs must be set.')
dummy_inputs = collate_fn(dummy_inputs, device)
if isinstance(dummy_inputs, Mapping):
dummy_inputs = tuple(dummy_inputs.values())
dummy_inputs_filter = []
for _input in self._decide_input_format(model, dummy_inputs):
if _input is not None:
dummy_inputs_filter.append(_input)
else:
break

if len(dummy_inputs) != len(dummy_inputs_filter):
logger.warn(
f'Dummy inputs is not continuous in the forward method, '
f'origin length: {len(dummy_inputs)}, '
f'the length after filtering: {len(dummy_inputs_filter)}')
dummy_inputs = dummy_inputs_filter

with torch.no_grad():
model.eval()
with replace_call():
traced_model = torch.jit.trace(
model, dummy_inputs, strict=strict)
model, tuple(dummy_inputs), strict=strict)
torch.jit.save(traced_model, output)

if validation:
@@ -249,6 +314,10 @@ class TorchModelExporter(Exporter):
outputs = numpify_tensor_nested(outputs)
outputs_origin = model.forward(*dummy_inputs)
outputs_origin = numpify_tensor_nested(outputs_origin)
if isinstance(outputs, dict):
outputs = list(outputs.values())
if isinstance(outputs_origin, dict):
outputs_origin = list(outputs_origin.values())
tols = {}
if rtol is not None:
tols['rtol'] = rtol
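A minimal sketch of what the copied-in _decide_input_format achieves, using a hypothetical ToyModel and a simplified order_args helper (illustrative only, not part of the diff): a trailing mapping of keyword inputs is flattened into positional order according to forward()'s signature, with missing optional arguments filled from their defaults.

import inspect

import torch
from torch import nn


class ToyModel(nn.Module):
    # hypothetical model, used only to illustrate the ordering behavior
    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        return input_ids


def order_args(model, args):
    # simplified re-implementation of the ordering logic above
    sig = inspect.signature(model.forward)
    keys = list(sig.parameters.keys())
    args_list, args_dict = list(args[:-1]), args[-1]  # trailing Mapping holds keyword inputs
    for name in keys[len(args_list):]:
        if name in args_dict:
            args_list.append(args_dict[name])
        else:
            param = sig.parameters[name]
            if param.default is not param.empty:
                args_list.append(param.default)
    return tuple(args_list)


ids = torch.ones(1, 4, dtype=torch.long)
mask = torch.ones(1, 4, dtype=torch.long)
print(order_args(ToyModel(), (ids, {'token_type_ids': mask})))
# -> (ids, None, mask): keyword inputs become positional, in signature order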


modelscope/hub/api.py (+36, -7)

@@ -23,7 +23,8 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
API_RESPONSE_FIELD_MESSAGE,
API_RESPONSE_FIELD_USERNAME,
DEFAULT_CREDENTIALS_PATH,
MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS,
MODELSCOPE_ENVIRONMENT,
MODELSCOPE_USERNAME, ONE_YEAR_SECONDS,
Licenses, ModelVisibility)
from modelscope.hub.errors import (InvalidParameter, NotExistError,
NotLoginException, NoValidRevisionError,
@@ -38,8 +39,8 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
DEFAULT_REPOSITORY_REVISION,
MASTER_MODEL_BRANCH, DatasetFormations,
DatasetMetaFormats, DownloadMode,
ModelFile)
DatasetMetaFormats, DownloadChannel,
DownloadMode, ModelFile)
from modelscope.utils.logger import get_logger
from .utils.utils import (get_endpoint, get_release_datetime,
model_id_to_group_owner_name)
@@ -382,10 +383,11 @@ class HubApi:
logger.info('Model revision not specified, use default: %s in development mode' % revision)
if revision not in branches and revision not in tags:
raise NotExistError('The model: %s has no branch or tag : %s .' % revision)
logger.info('Development mode use revision: %s' % revision)
else:
revisions = self.list_model_revisions(
model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies)
if revision is None:
if revision is None: # user not specified revision, use latest revision before release time
revisions = self.list_model_revisions(
model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies)
if len(revisions) == 0:
raise NoValidRevisionError('The model: %s has no valid revision!' % model_id)
# tags (revisions) returned from backend are guaranteed to be ordered by create-time
@@ -393,9 +395,13 @@ class HubApi:
revision = revisions[0]
logger.info('Model revision not specified, use the latest revision: %s' % revision)
else:
# use user-specified revision
revisions = self.list_model_revisions(
model_id, cutoff_timestamp=current_timestamp, use_cookies=False if cookies is None else cookies)
if revision not in revisions:
raise NotExistError(
'The model: %s has no revision: %s !' % (model_id, revision))
logger.info('Use user-specified model revision: %s' % revision)
return revision

def get_model_branches_and_tags(
@@ -640,6 +646,25 @@ class HubApi:
def check_local_cookies(self, use_cookies) -> CookieJar:
return self._check_cookie(use_cookies=use_cookies)

def dataset_download_uv(self, dataset_name: str, namespace: str):
if not dataset_name or not namespace:
raise ValueError('dataset_name or namespace cannot be empty!')

# get channel and user_name
channel = DownloadChannel.LOCAL.value
user_name = ''
if MODELSCOPE_ENVIRONMENT in os.environ:
channel = os.environ[MODELSCOPE_ENVIRONMENT]
if MODELSCOPE_USERNAME in os.environ:
user_name = os.environ[MODELSCOPE_USERNAME]

url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/{channel}?user={user_name}'
cookies = ModelScopeConfig.get_cookies()
r = requests.post(url, cookies=cookies, headers=self.headers)
resp = r.json()
raise_on_error(resp)
return resp['Message']


class ModelScopeConfig:
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
@@ -755,14 +780,18 @@ class ModelScopeConfig:
env = 'custom'
if MODELSCOPE_ENVIRONMENT in os.environ:
env = os.environ[MODELSCOPE_ENVIRONMENT]
user_name = 'unknown'
if MODELSCOPE_USERNAME in os.environ:
user_name = os.environ[MODELSCOPE_USERNAME]

ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % (
ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % (
__version__,
platform.python_version(),
ModelScopeConfig.get_user_session_id(),
platform.platform(),
platform.processor(),
env,
user_name,
)
if isinstance(user_agent, dict):
ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items())
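For orientation, the revision-selection branch above boils down to the following sketch (pick_revision and list_revisions are illustrative stand-ins for the HubApi internals):

def pick_revision(revision, list_revisions, release_ts, current_ts):
    if revision is None:
        # revision not specified: take the newest revision created before the
        # library release time (backend returns revisions newest-first)
        revisions = list_revisions(cutoff_timestamp=release_ts)
        if not revisions:
            raise ValueError('the model has no valid revision')
        return revisions[0]
    # user-specified revision: it only needs to exist before the current time
    revisions = list_revisions(cutoff_timestamp=current_ts)
    if revision not in revisions:
        raise ValueError('the model has no revision %s' % revision)
    return revision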


modelscope/hub/constants.py (+1, -0)

@@ -18,6 +18,7 @@ API_RESPONSE_FIELD_EMAIL = 'Email'
API_RESPONSE_FIELD_MESSAGE = 'Message'
MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'
MODELSCOPE_USERNAME = 'MODELSCOPE_USERNAME'
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60




modelscope/hub/utils/utils.py (+15, -0)

@@ -5,6 +5,8 @@ import os
from datetime import datetime
from typing import Optional

import requests

from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
DEFAULT_MODELSCOPE_GROUP,
MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG,
@@ -85,3 +87,16 @@ def file_integrity_validation(file_path, expected_sha256):
msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
logger.error(msg)
raise FileIntegrityError(msg)


def create_library_statistics(method: str, name: str, cn_name: Optional[str]):
try:
from modelscope.hub.api import ModelScopeConfig
path = f'{get_endpoint()}/api/v1/statistics/library'
headers = {'user-agent': ModelScopeConfig.get_user_agent()}
params = {'Method': method, 'Name': name, 'CnName': cn_name}
r = requests.post(path, params=params, headers=headers)
r.raise_for_status()
except Exception:
pass
return
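Usage is fire-and-forget, as seen later in this PR in pipelines/base.py; any request failure is swallowed inside the helper (the model id below is just an example):

from modelscope.hub.utils.utils import create_library_statistics

# reports a 'pipeline' event for the given model name; never raises
create_library_statistics('pipeline', 'damo/multi-modal_clip-vit-base-patch16_zh', None)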

modelscope/metainfo.py (+6, -0)

@@ -389,6 +389,7 @@ class Preprocessors(object):

# multi-modal preprocessor
ofa_tasks_preprocessor = 'ofa-tasks-preprocessor'
clip_preprocessor = 'clip-preprocessor'
mplug_tasks_preprocessor = 'mplug-tasks-preprocessor'

# science preprocessor
@@ -428,6 +429,8 @@ class Metrics(object):
image_inpainting_metric = 'image-inpainting-metric'
# metric for ocr
NED = 'ned'
# metric for cross-modal retrieval
inbatch_recall = 'inbatch_recall'
# metric for referring-video-object-segmentation task
referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric'

@@ -474,6 +477,9 @@ class Hooks(object):
# Compression
SparsityHook = 'SparsityHook'

# CLIP logit_scale clamp
ClipClampLogitScaleHook = 'ClipClampLogitScaleHook'


class LR_Schedulers(object):
"""learning rate scheduler is defined here


modelscope/metrics/builder.py (+1, -0)

@@ -24,6 +24,7 @@ class MetricKeys(object):
ROUGE_1 = 'rouge-1'
ROUGE_L = 'rouge-l'
NED = 'ned' # ocr metric
BatchAcc = 'inbatch_t2i_recall_at_1'


task_default_metrics = {


modelscope/metrics/inbatch_recall_metric.py (+55, -0)

@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Dict

import numpy as np
import torch

from modelscope.metainfo import Metrics
from modelscope.outputs import OutputKeys
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys


@METRICS.register_module(
group_key=default_group, module_name=Metrics.inbatch_recall)
class InbatchRecallMetric(Metric):
"""The metric computation class for in-batch retrieval classes.

This metric class calculates in-batch image recall@1 for each input batch.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.inbatch_t2i_hitcnts = []
self.batch_sizes = []

def add(self, outputs: Dict, inputs: Dict):
image_features = outputs[OutputKeys.IMG_EMBEDDING]
text_features = outputs[OutputKeys.TEXT_EMBEDDING]

assert type(image_features) == torch.Tensor and type(
text_features) == torch.Tensor

with torch.no_grad():
logits_per_image = image_features @ text_features.t()
logits_per_text = logits_per_image.t()
batch_size = logits_per_image.shape[0]

ground_truth = torch.arange(batch_size).long()
ground_truth = ground_truth.to(image_features.device)

inbatch_t2i_hitcnt = (logits_per_text.argmax(-1) == ground_truth
).sum().float().item()

self.inbatch_t2i_hitcnts.append(inbatch_t2i_hitcnt)
self.batch_sizes.append(batch_size)

def evaluate(self):
assert len(self.inbatch_t2i_hitcnts) == len(
self.batch_sizes) and len(self.batch_sizes) > 0
return {
MetricKeys.BatchAcc:
sum(self.inbatch_t2i_hitcnts) / sum(self.batch_sizes)
}
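A toy reproduction of the recall@1 computation above (standalone; identical image and text embeddings should yield a perfect score):

import torch

image_features = torch.nn.functional.normalize(torch.randn(8, 512), dim=-1)
text_features = image_features.clone()  # pretend text i exactly matches image i
logits_per_text = (image_features @ text_features.t()).t()
ground_truth = torch.arange(8)
hits = (logits_per_text.argmax(-1) == ground_truth).sum().float().item()
print(hits / 8)  # 1.0 -> inbatch_t2i_recall_at_1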

modelscope/models/base/base_model.py (+9, -0)

@@ -131,6 +131,8 @@ class Model(ABC):

if not hasattr(model, 'cfg'):
model.cfg = cfg

model.name = model_name_or_path
return model

def save_pretrained(self,
@@ -161,5 +163,12 @@ class Model(ABC):
assert config is not None, 'Cannot save the model because the model config is empty.'
if isinstance(config, Config):
config = config.to_dict()
if 'preprocessor' in config and config['preprocessor'] is not None:
if 'mode' in config['preprocessor']:
config['preprocessor']['mode'] = 'inference'
elif 'val' in config['preprocessor'] and 'mode' in config[
'preprocessor']['val']:
config['preprocessor']['val']['mode'] = 'inference'

save_pretrained(self, target_folder, save_checkpoint_names,
save_function, config, **kwargs)

modelscope/models/multi_modal/clip/model.py (+38, -124)

@@ -15,15 +15,13 @@

import os
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Tuple, Union
from typing import Any, Dict, Tuple, Union

import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from modelscope.metainfo import Models
from modelscope.models import TorchModel
@@ -351,11 +349,13 @@ class CLIP(nn.Module):
text_num_hidden_layers: int,
text_type_vocab_size: int,
tokenizer: FullTokenizer,
# vision_head_width, added this param for ViT-H
vision_head_width: int = 64,
):
super().__init__()

if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
vision_heads = vision_width * 32 // vision_head_width
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
@@ -363,7 +363,7 @@ class CLIP(nn.Module):
input_resolution=image_resolution,
width=vision_width)
else:
vision_heads = vision_width // 64
vision_heads = vision_width // vision_head_width
self.visual = VisualTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
@@ -506,21 +506,6 @@ def convert_weights(model: nn.Module):
model.apply(_convert_weights_to_fp16)


def _convert_to_rgb(image):
return image.convert('RGB')


def image_transform(image_size=224):
transform = Compose([
_convert_to_rgb,
Resize((image_size, image_size)),
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
return transform


@MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip)
class CLIPForMultiModalEmbedding(TorchModel):

@@ -540,72 +525,40 @@ class CLIPForMultiModalEmbedding(TorchModel):

with open(vision_model_config_file,
'r') as fv, open(text_model_config_file, 'r') as ft:
model_info = json.load(fv)
self.model_info = json.load(fv)
for k, v in json.load(ft).items():
model_info[k] = v

# image preprocess
self.img_preprocess = image_transform(model_info['image_resolution'])
self.model_info[k] = v

# text tokenizer
vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}'
self.tokenizer = FullTokenizer(vocab_file=vocab_file)

# initialize the model
self.clip_model = CLIP(**model_info, tokenizer=self.tokenizer)
self.clip_model = CLIP(**self.model_info, tokenizer=self.tokenizer)
convert_weights(self.clip_model)

# restore the pretrained weight
checkpoint = torch.load(
f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}', 'cpu')
sd = checkpoint['state_dict']
sd = checkpoint[
'state_dict'] if 'state_dict' in checkpoint else checkpoint
if next(iter(sd.items()))[0].startswith('module'):
sd = {k[len('module.'):]: v for k, v in sd.items()}
# support the finetuned model
if next(iter(sd.items()))[0].startswith('clip_model'):
sd = {k[len('clip_model.'):]: v for k, v in sd.items()}
self.clip_model.load_state_dict(sd)
self.clip_model.eval()

# place the model
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if self.device == 'cuda':
self.device = 'cuda:{}'.format(int(os.environ.get(
'LOCAL_RANK', 0))) if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
self.clip_model.to(self.device)
logger.info('Use GPU for inference')
logger.info('Use GPU {} for finetuning & inference'.format(
int(os.environ.get('LOCAL_RANK', 0))))
else:
self.clip_model.float()
logger.info('Use CPU for inference')

def tokenize(self,
texts: Union[str, List[str]],
context_length: int = 52) -> torch.LongTensor:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all baseline models use 24 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]

all_tokens = []
for text in texts:
all_tokens.append(
[self.tokenizer.vocab['[CLS]']]
+ self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize(text))[:context_length - 2]
+ [self.tokenizer.vocab['[SEP]']])

result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

for i, tokens in enumerate(all_tokens):
assert len(tokens) <= context_length
result[i, :len(tokens)] = torch.tensor(tokens)

return result
logger.info('Use CPU for finetuning & inference')

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
from modelscope.outputs import OutputKeys
@@ -613,75 +566,36 @@ class CLIPForMultiModalEmbedding(TorchModel):
OutputKeys.IMG_EMBEDDING: None,
OutputKeys.TEXT_EMBEDDING: None
}
if 'img' in input and input['img'] is not None:
image_input = input['img']

# single image input
if isinstance(image_input, Image.Image):
image_tensor = self.img_preprocess(image_input).unsqueeze(0)
# multi images input
elif isinstance(image_input, list):
if all([isinstance(elem, Image.Image)
for elem in image_input]):
image_tensor = torch.stack(
[self.img_preprocess(elem) for elem in image_input],
dim=0)
else:
unsupported_elem_type = [
type(elem) for elem in image_input
if not isinstance(elem, Image.Image)
][0]
raise TypeError(
f'img should be PIL.Image or List[PIL.Image], \
but got a List containing one {unsupported_elem_type}'
)
# others
else:
raise TypeError(
f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}'
)

image_tensor = image_tensor.to(self.device)

with torch.no_grad():
mode = input.get('mode', ModeKeys.INFERENCE)

# encode the image
if 'img' in input and isinstance(input['img'], torch.Tensor):
image_tensor = input['img'].to(self.device)
if image_tensor.dim() == 5 and image_tensor.shape[1] == 1:
image_tensor = image_tensor.squeeze(1)

with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
image_features = self.clip_model.encode_image(image_tensor)
image_features /= image_features.norm(
dim=-1, keepdim=True) # l2-normalize

output[OutputKeys.IMG_EMBEDDING] = image_features

if 'text' in input and input['text'] is not None:
text_input = input['text']

# single text input
if isinstance(text_input, str):
text_tensor = self.tokenize(text_input)
# multi texts input
elif isinstance(text_input, list):
if all([isinstance(elem, str) for elem in text_input]):
text_tensor = self.tokenize(text_input)
else:
unsupported_elem_type = [
type(elem) for elem in text_input
if not isinstance(elem, str)
][0]
raise TypeError(
f'text should be str or List[str], but got a List containing one {unsupported_elem_type}'
)
# others
else:
raise TypeError(
f'text should be str or List[str], but got {type(text_input)}'
)

text_tensor = text_tensor.to(self.device)

with torch.no_grad():
if 'text' in input and isinstance(input['text'], torch.Tensor):
text_tensor = input['text'].to(self.device)
if text_tensor.dim() == 3 and text_tensor.shape[1] == 1:
text_tensor = text_tensor.squeeze(1)

with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
text_features = self.clip_model.encode_text(text_tensor)
text_features /= text_features.norm(
dim=-1, keepdim=True) # l2-normalize
output[OutputKeys.TEXT_EMBEDDING] = text_features

if mode == ModeKeys.TRAIN:
output['logit_scale'] = (self.clip_model.logit_scale
* 1.0).exp().mean()

return output

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
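Since forward() now takes preprocessor-produced tensors and returns l2-normalized embeddings, downstream similarity is a plain dot product. A sketch with placeholder tensors (the embedding width depends on the loaded model):

import torch

img_emb = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1)  # stands in for OutputKeys.IMG_EMBEDDING
txt_emb = torch.nn.functional.normalize(torch.randn(4, 768), dim=-1)  # stands in for OutputKeys.TEXT_EMBEDDING
similarities = img_emb @ txt_emb.t()  # cosine similarities, shape [1, 4]
best_match = similarities.argmax(-1)  # index of the best-matching text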


modelscope/models/multi_modal/ofa/configuration_ofa.py (+13, -0)

@@ -136,6 +136,12 @@ class OFAConfig(PretrainedConfig):
entangle_position_embedding=False,
interpolate_position=False,
orig_patch_image_size=224,
share_attn_bias=False,
use_image_feature=True,
disable_entangle=False,
use_ofasys=False,
vit_type='vit_base',
vit_drop_path_rate=0.0,
**kwargs):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
@@ -178,6 +184,13 @@ class OFAConfig(PretrainedConfig):
self.interpolate_position = interpolate_position
self.orig_patch_image_size = orig_patch_image_size

self.share_attn_bias = share_attn_bias
self.use_image_feature = use_image_feature
self.disable_entangle = disable_entangle
self.use_ofasys = use_ofasys
self.vit_type = vit_type
self.vit_drop_path_rate = vit_drop_path_rate

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,


modelscope/models/multi_modal/ofa/modeling_ofa.py (+203, -141)

@@ -35,6 +35,8 @@ from transformers.utils import logging
from .configuration_ofa import OFAConfig
from .generate import utils
from .resnet import ResNet
from .utils.utils import DropPath
from .vit import vit_base, vit_huge, vit_large, vit_large_336

logger = logging.get_logger(__name__)

@@ -249,45 +251,6 @@ class LayerDropModuleList(nn.ModuleList):
yield m


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Args:
x (`nn.Modules`): input nn layers.
drop_prob (`float`): drop path ratio.
training (`bool`): whether is training or inference.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (1, x.shape[1], 1)
random_tensor = keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output


class DropPath(nn.Module):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Args:
drop_prob: drop path ratio.
"""

def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob

def forward(self, x):
return drop_path(x, self.drop_prob, self.training)

def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)


class OFAAttention(nn.Module):
r"""
Multi-headed attention, with additional implementation for NormFormer.
@@ -898,31 +861,49 @@ class OFAEncoder(OFAPreTrainedModel):
self.padding_idx)

if config.add_type_embedding:
self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
if config.use_image_feature:
self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
else:
self.type_embedding = Embedding(1, embed_dim, padding_idx=None)
else:
self.type_embedding = None

if config.resnet_type == 'resnet18':
self.embed_images = ResNet(
[2, 2, 2], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet34':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet50':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet101':
self.embed_images = ResNet(
[3, 4, 23], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet152':
self.embed_images = ResNet(
[3, 8, 36], drop_path_rate=config.resnet_drop_path_rate)
else:
raise NotImplementedError
if config.use_image_feature:
if config.use_ofasys:
vit_backbone = {
'vit_base': vit_base,
'vit_large': vit_large,
'vit_large_336': vit_large_336,
'vit_huge': vit_huge,
}[config.vit_type]
self.embed_images = vit_backbone(config.vit_drop_path_rate)

self.image_proj = Linear(1024, embed_dim)
self.image_proj = Linear(self.embed_images.width, embed_dim)

if config.resnet_model_path:
else:
if config.resnet_type == 'resnet18':
self.embed_images = ResNet(
[2, 2, 2], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet34':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet50':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet101':
self.embed_images = ResNet(
[3, 4, 23],
drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet152':
self.embed_images = ResNet(
[3, 8, 36],
drop_path_rate=config.resnet_drop_path_rate)
else:
raise NotImplementedError

self.image_proj = Linear(1024, embed_dim)

if not config.use_ofasys and config.resnet_model_path:
print('load resnet {}'.format(config.resnet_model_path))
resnet_state_dict = torch.load(config.resnet_model_path)
self.embed_images.load_state_dict(resnet_state_dict)
@@ -933,14 +914,21 @@ class OFAEncoder(OFAPreTrainedModel):

self.embed_positions = Embedding(self.max_source_positions + 2,
embed_dim)
self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1,
embed_dim)
self.pos_ln = LayerNorm(embed_dim)
self.image_pos_ln = LayerNorm(embed_dim)

if config.use_image_feature:
self.embed_image_positions = Embedding(
config.image_bucket_size**2 + 1, embed_dim)
if not config.use_ofasys:
self.pos_ln = LayerNorm(embed_dim)

if config.use_image_feature:
self.image_pos_ln = LayerNorm(embed_dim)
self.pos_scaling = float(embed_dim / self.num_attention_heads
* config.attn_scale_factor)**-0.5
self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
self.pos_k_linear = nn.Linear(embed_dim, embed_dim)

if not (config.use_ofasys and config.entangle_position_embedding):
self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
self.pos_k_linear = nn.Linear(embed_dim, embed_dim)

if self.encoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
@@ -965,22 +953,28 @@ class OFAEncoder(OFAPreTrainedModel):
self.token_bucket_size = config.token_bucket_size
token_num_rel_dis = 2 * config.token_bucket_size - 1
token_rp_bucket = make_token_bucket_position(config.token_bucket_size)
self.share_attn_bias = config.share_attn_bias
num_rel_pos_tables = 1 if config.share_attn_bias else config.encoder_layers
self.token_rel_pos_table_list = nn.ModuleList([
Embedding(
token_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.encoder_layers)
for _ in range(num_rel_pos_tables)
])

self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size
- 1) * (2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(config.image_bucket_size,
image_num_rel_dis)
self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.encoder_layers)
])
if config.use_image_feature:
self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size
- 1) * (2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(
config.image_bucket_size, image_num_rel_dis)
self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis,
self.num_attention_heads,
zero_init=True) for _ in range(num_rel_pos_tables)
])

self.register_buffer('image_rp_bucket', image_rp_bucket)

if config.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim)
@@ -988,12 +982,12 @@ class OFAEncoder(OFAPreTrainedModel):
self.layernorm_embedding = None

self.register_buffer('token_rp_bucket', token_rp_bucket)
self.register_buffer('image_rp_bucket', image_rp_bucket)
self.entangle_position_embedding = config.entangle_position_embedding

self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
self.use_ofasys = config.use_ofasys

def get_input_embeddings(self):
r"""
@@ -1305,21 +1299,41 @@ class OFAEncoder(OFAPreTrainedModel):
if has_pads:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

pos_embed = self.pos_ln(pos_embed)
if patch_images is not None:
image_pos_embed = self.image_pos_ln(image_pos_embed)
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
if patch_images_2 is not None:
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
if self.use_ofasys:
if patch_images is not None:
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
if patch_images_2 is not None:
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
else:
pos_embed = self.pos_ln(pos_embed)
if patch_images is not None:
image_pos_embed = self.image_pos_ln(image_pos_embed)
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
if patch_images_2 is not None:
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)

def build_abs_pos_bias(pos_embed):
batch_size, seq_length = pos_embed.size(0), pos_embed.size(1)
if not (self.use_ofasys and self.entangle_position_embedding):
pos_q = self.pos_q_linear(pos_embed).view(
batch_size, seq_length, self.num_attention_heads,
-1).transpose(1, 2) * self.pos_scaling
pos_k = self.pos_k_linear(pos_embed).view(
batch_size, seq_length, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
else:
abs_pos_bias = torch.zeros(
batch_size,
self.num_attention_heads,
seq_length,
seq_length,
dtype=pos_embed.dtype,
device=pos_embed.device)
return abs_pos_bias

pos_q = self.pos_q_linear(pos_embed).view(
x.size(0), x.size(1), self.num_attention_heads, -1).transpose(
1, 2) * self.pos_scaling
pos_k = self.pos_k_linear(pos_embed).view(
x.size(0), x.size(1), self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
abs_pos_bias = build_abs_pos_bias(pos_embed)

# expand attention_mask
if has_pads:
@@ -1334,19 +1348,22 @@ class OFAEncoder(OFAPreTrainedModel):
if output_hidden_states:
encoder_states += (x, )
self_attn_bias = abs_pos_bias.clone()

real_idx = 0 if self.share_attn_bias else idx

self_attn_bias[:, :, -input_ids.size(1):,
-input_ids.size(1):] += self.get_rel_pos_bias(
input_ids, idx)
input_ids, real_idx)
if patch_images_2 is not None:
self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \
self.get_image_rel_pos_bias(image_position_ids_2, idx)
self.get_image_rel_pos_bias(image_position_ids_2, real_idx)
self_attn_bias[:, :,
image_num_patches_2:image_num_patches_2 + image_num_patches, # noqa
image_num_patches_2:image_num_patches_2 + image_num_patches] += \
self.get_image_rel_pos_bias(image_position_ids, idx) # noqa
self.get_image_rel_pos_bias(image_position_ids, real_idx) # noqa
elif patch_images is not None:
self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \
self.get_image_rel_pos_bias(image_position_ids, idx)
self.get_image_rel_pos_bias(image_position_ids, real_idx)
self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1))

hidden_outputs = layer(
@@ -1398,6 +1415,8 @@ class OFADecoder(OFAPreTrainedModel):
self._future_mask = torch.empty(0)
self.share_input_output_embed = config.share_decoder_input_output_embed
self.num_attention_heads = config.decoder_attention_heads
self.use_ofasys = config.use_ofasys
self.disable_entangle = config.disable_entangle

if embed_tokens is not None:
self.embed_tokens = embed_tokens
@@ -1415,18 +1434,31 @@ class OFADecoder(OFAPreTrainedModel):
else:
self.layernorm_embedding = None

if config.use_ofasys:
if config.add_type_embedding:
self.type_embedding = Embedding(
1, self.embed_dim, padding_idx=None)
else:
self.type_embedding = None

self.window_size = config.code_image_size // 8

self.embed_positions = Embedding(self.max_target_positions + 2,
self.embed_dim)
self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1,
self.embed_dim)
self.pos_ln = LayerNorm(self.embed_dim)
self.image_pos_ln = LayerNorm(self.embed_dim)

if not config.use_ofasys:
self.embed_image_positions = Embedding(
config.image_bucket_size**2 + 1, self.embed_dim)
if not config.use_ofasys:
self.pos_ln = LayerNorm(self.embed_dim)
self.image_pos_ln = LayerNorm(self.embed_dim)
self.pos_scaling = float(self.embed_dim / self.num_attention_heads
* config.attn_scale_factor)**-0.5
self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)

if not (config.use_ofasys and config.entangle_position_embedding):
self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)

self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)

@@ -1463,33 +1495,41 @@ class OFADecoder(OFAPreTrainedModel):
self.token_bucket_size = config.token_bucket_size
token_num_rel_dis = 2 * config.token_bucket_size - 1
token_rp_bucket = make_token_bucket_position(config.token_bucket_size)

self.share_attn_bias = config.share_attn_bias
num_rel_pos_tables = 1 if config.share_attn_bias else config.decoder_layers
self.token_rel_pos_table_list = nn.ModuleList([
Embedding(
token_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.decoder_layers)
for _ in range(num_rel_pos_tables)
])

self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size
- 1) * (2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(config.image_bucket_size,
image_num_rel_dis)
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa
image_position_idx = torch.cat(
[torch.tensor([0]), image_position_idx.view(-1)])
image_position_idx = torch.cat(
[image_position_idx,
torch.tensor([1024] * 768)])
self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.decoder_layers)
])
if config.use_image_feature:
if not config.use_ofasys:
self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size - 1) * (
2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(
config.image_bucket_size, image_num_rel_dis)
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa
image_position_idx = torch.cat(
[torch.tensor([0]),
image_position_idx.view(-1)])
image_position_idx = torch.cat(
[image_position_idx,
torch.tensor([1024] * 768)])
self.register_buffer('image_position_idx', image_position_idx)

self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis,
self.num_attention_heads,
zero_init=True) for _ in range(num_rel_pos_tables)
])
self.register_buffer('image_rp_bucket', image_rp_bucket)

self.register_buffer('token_rp_bucket', token_rp_bucket)
self.register_buffer('image_rp_bucket', image_rp_bucket)
self.register_buffer('image_position_idx', image_position_idx)
self.entangle_position_embedding = config.entangle_position_embedding

self.gradient_checkpointing = False
@@ -1556,26 +1596,46 @@ class OFADecoder(OFAPreTrainedModel):

batch_size = tgt_pos_embed.size(0)
tgt_len = tgt_pos_embed.size(1)
tgt_pos_embed = self.image_pos_ln(
tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)
if not self.use_ofasys:
tgt_pos_embed = self.image_pos_ln(
tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)

if src_pos_embed is not None:
src_len = src_pos_embed.size(1)
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads, -1).transpose(
1, 2) * self.pos_scaling
pos_k = self.cross_pos_k_linear(src_pos_embed).view(
batch_size, src_len, self.num_attention_heads,
-1).transpose(1, 2)
if not (self.entangle_position_embedding and self.use_ofasys):
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads,
-1).transpose(1, 2) * self.pos_scaling
pos_k = self.cross_pos_k_linear(src_pos_embed).view(
batch_size, src_len, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
else:
abs_pos_bias = torch.zeros(
batch_size,
self.num_attention_heads,
tgt_len,
src_len,
dtype=tgt_pos_embed.dtype,
device=tgt_pos_embed.device)
else:
src_len = tgt_pos_embed.size(1)
pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads, -1).transpose(
1, 2) * self.pos_scaling
pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
batch_size, src_len, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
# batch_size, seq_length = tgt_pos_embed.size(0), tgt_pos_embed.size(1)
if not (self.entangle_position_embedding and self.use_ofasys):
pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads,
-1).transpose(1, 2) * self.pos_scaling
pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
else:
abs_pos_bias = torch.zeros(
batch_size,
self.num_attention_heads,
tgt_len,
tgt_len,
dtype=tgt_pos_embed.dtype,
device=tgt_pos_embed.device)

return abs_pos_bias

@@ -1809,17 +1869,18 @@ class OFADecoder(OFAPreTrainedModel):
past_key_values) > 0 else None

self_attn_bias = self_abs_pos_bias.clone()
real_idx = 0 if self.share_attn_bias else idx
if code_masks is None or not code_masks.any():
self_attn_bias += self.get_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
elif code_masks is not None and code_masks.all():
self_attn_bias += self.get_image_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
else:
self_attn_bias[~code_masks] += self.get_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
self_attn_bias[code_masks] += self.get_image_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
self_attn_bias = self_attn_bias.reshape(
-1,
*self_attn_bias.size()[-2:])
@@ -1892,6 +1953,7 @@ class OFAModel(OFAPreTrainedModel):

self.encoder = OFAEncoder(config, shared)
self.decoder = OFADecoder(config, shared)
self.use_ofasys = config.use_ofasys

# Initialize weights and apply final processing
self.post_init()
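The recurring `real_idx = 0 if self.share_attn_bias else idx` pattern means that with share_attn_bias=True a single relative-position table is built and reused by every layer, instead of one table per layer. A toy sketch (names and sizes are illustrative only):

import torch.nn as nn

share_attn_bias, num_layers, heads, num_rel_dis = True, 12, 16, 63
num_tables = 1 if share_attn_bias else num_layers
rel_pos_tables = nn.ModuleList(
    nn.Embedding(num_rel_dis, heads) for _ in range(num_tables))

for idx in range(num_layers):
    real_idx = 0 if share_attn_bias else idx
    table = rel_pos_tables[real_idx]  # every layer hits tables[0] when shared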


modelscope/models/multi_modal/ofa/utils/utils.py (+40, -0)

@@ -2,6 +2,7 @@
from typing import Optional

import torch
import torch.nn as nn


def expand_mask(mask: torch.Tensor,
@@ -17,3 +18,42 @@ def expand_mask(mask: torch.Tensor,
src_len).to(dtype)
return expanded_mask.masked_fill(expanded_mask.bool(),
torch.finfo(dtype).min)


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Args:
x (`nn.Modules`): input nn layers.
drop_prob (`float`): drop path ratio.
training (`bool`): whether is training or inference.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (1, x.shape[1], 1)
random_tensor = keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output


class DropPath(nn.Module):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Args:
drop_prob: drop path ratio.
"""

def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob

def forward(self, x):
return drop_path(x, self.drop_prob, self.training)

def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
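A quick check of the relocated stochastic-depth helper: in eval mode the input passes through untouched, while in training mode whole slices along dim 1 are randomly zeroed and the survivors are rescaled by 1/keep_prob:

import torch

from modelscope.models.multi_modal.ofa.utils.utils import drop_path

x = torch.ones(4, 3, 8)  # (seq_len, batch, dim)-style input
print(drop_path(x, drop_prob=0.5, training=False).equal(x))  # True: no-op at inference
y = drop_path(x, drop_prob=0.5, training=True)
print(y.unique().tolist())  # values drawn from {0.0, 2.0}: dropped, or rescaled by 1/0.5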

modelscope/models/multi_modal/ofa/vit.py (+155, -0)

@@ -0,0 +1,155 @@
from collections import OrderedDict

import torch
import torch.nn.functional as F
from fairseq.modules import LayerNorm
from torch import nn

from .utils.utils import DropPath

__all__ = [
'vit_base',
'vit_large',
'vit_large_336',
'vit_huge',
]


class QuickGELU(nn.Module):

def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

def __init__(self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None,
drop_path_rate=0.0):
super().__init__()

self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([
('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model)),
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.drop_path = DropPath(drop_path_rate)

def attention(self, x: torch.Tensor):
self.attn_mask = (
self.attn_mask.to(dtype=x.dtype, device=x.device)
if self.attn_mask is not None else None)
return self.attn(
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

def forward(self, x: torch.Tensor):
x = x + self.drop_path(self.attention(self.ln_1(x)))
x = x + self.drop_path(self.mlp(self.ln_2(x)))
return x


class Transformer(nn.Module):

def __init__(
self,
width: int,
layers: int,
heads: int,
attn_mask: torch.Tensor = None,
drop_path_rate: float = 0.0,
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[
ResidualAttentionBlock(width, heads, attn_mask, drop_path_rate)
for _ in range(layers)
])

def forward(self, x: torch.Tensor):
return self.resblocks(x)


class VisionTransformer(nn.Module):

def __init__(
self,
input_resolution: int,
patch_size: int,
width: int,
layers: int,
heads: int,
drop_path_rate: float = 0.0,
):
super().__init__()
self.input_resolution = input_resolution
self.patch_size = patch_size
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)

scale = width**-0.5
self.width = width
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(
width, layers, heads, drop_path_rate=drop_path_rate)

def forward(self, x: torch.Tensor):
resolution = x.shape[-2]
height, width = x.shape[-2] // self.patch_size, x.shape[
-1] // self.patch_size
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]

if resolution != self.input_resolution:
old_pe = self.positional_embedding[1:]
patch_num = self.input_resolution // self.patch_size
old_pe = old_pe.reshape(1, patch_num, patch_num,
-1).permute(0, 3, 1, 2)
new_pe = F.interpolate(
old_pe, size=(height, width), mode='bilinear')
new_pe = new_pe.permute(0, 2, 3, 1).reshape(height * width, -1)
x = x + new_pe.to(x.dtype)
else:
x = x + self.positional_embedding[1:].to(x.dtype)
x = self.ln_pre(x)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD

bz, seq, hidden = x.shape
x = x.transpose(1, 2).reshape(bz, hidden, height, width)

return x


def vit_base(drop_path_rate: float = 0.0):
return VisionTransformer(224, 16, 768, 9, 12, drop_path_rate)


def vit_large(drop_path_rate: float = 0.0):
return VisionTransformer(224, 14, 1024, 18, 16, drop_path_rate)


def vit_large_336(drop_path_rate: float = 0.0):
return VisionTransformer(336, 14, 1024, 18, 16, drop_path_rate)


def vit_huge(drop_path_rate: float = 0.0):
return VisionTransformer(224, 14, 1280, 24, 16, drop_path_rate)
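A rough shape check for the new backbone (requires torch and fairseq, since the module imports fairseq.modules.LayerNorm); the encoder returns a 2D feature map whose width feeds the image_proj Linear in the OFA encoder:

import torch

from modelscope.models.multi_modal.ofa.vit import vit_base

backbone = vit_base(drop_path_rate=0.0)
features = backbone(torch.randn(1, 3, 224, 224))
print(backbone.width, features.shape)  # 768 torch.Size([1, 768, 14, 14])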

modelscope/models/multi_modal/ofa_for_all_tasks.py (+22, -3)

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import re
import string
from functools import partial
from os import path as osp
@@ -53,8 +54,11 @@ class OfaForAllTasks(TorchModel):
raise NotImplementedError
# there is some diff between here and our ofa code,
# there will be no need to use param: use_bpe
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
if not model.use_ofasys:
self.tokenizer.add_tokens(
['<code_{}>'.format(i) for i in range(8192)])
self.tokenizer.add_tokens(
['<bin_{}>'.format(i) for i in range(1000)])
self.cfg.update({'num_bins': 1000, 'num_codes': 8192})
self.batch_size = self.cfg.model.get('batch_size', 1)
self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
@@ -107,6 +111,8 @@ class OfaForAllTasks(TorchModel):
Tasks.text_classification: inference_d[self.gen_type],
Tasks.image_classification: inference_d[self.gen_type],
}
pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))'
self.pattern = re.compile(pattern_str)

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
input = move_to_device(input, self.model.device)
@@ -132,8 +138,18 @@ class OfaForAllTasks(TorchModel):
caption = input[OutputKeys.CAPTION]
result_l = list()
for cap in caption:
result_l.append(cap.translate(self.transtab).strip())
if self.language == 'en':
result_l.append(cap.translate(self.transtab).strip())
else:
result_l.append(cap)
input[OutputKeys.CAPTION] = result_l
if self.gen_type == 'generation' and self.language in [
'zh', 'cn'
] and self.cfg.task != Tasks.visual_grounding:
ret_l = list()
for text in input[OFA_TASK_KEY_MAPPING[self.cfg.task]]:
ret_l.append(self.detokenizer(text))
input[OFA_TASK_KEY_MAPPING[self.cfg.task]] = ret_l
return input

def _text_gen_inference(self, input):
@@ -311,3 +327,6 @@ class OfaForAllTasks(TorchModel):
save_function=partial(save_function, with_meta=False),
config=config,
**kwargs)

def detokenizer(self, text):
return self.pattern.sub('', text)
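What the new detokenizer does: the regex removes any run of spaces that touches a non-ASCII word character (e.g. the spaces a Chinese tokenizer inserts between characters), while leaving spaces between plain ASCII words intact. A standalone check:

import re

pattern = re.compile('((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))')
print(pattern.sub('', '今天 天气 很 好'))    # '今天天气很好'
print(pattern.sub('', 'pure english text'))  # unchanged: ASCII-only spaces survive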

modelscope/models/nlp/bert/text_ranking.py (+1, -0)

@@ -36,6 +36,7 @@ class BertForTextRanking(BertForSequenceClassification):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
*args,
**kwargs) -> AttentionTextClassificationModelOutput:
outputs = self.base_model.forward(
input_ids=input_ids,


modelscope/models/nlp/structbert/text_classification.py (+1, -0)

@@ -109,6 +109,7 @@ class SbertForSequenceClassification(SbertPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
*args,
**kwargs):
r"""
Args:


modelscope/models/science/unifold/dataset.py (+3, -0)

@@ -1,3 +1,6 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import copy
import logging
import os


modelscope/models/science/unifold/model.py (+3, -0)

@@ -1,3 +1,6 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import argparse
import os
from typing import Any


modelscope/models/science/unifold/modules/__init__.py (+3, -0)

@@ -0,0 +1,3 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.
"""Unifold Modules."""

modelscope/msdatasets/ms_dataset.py (+2, -0)

@@ -274,6 +274,8 @@ class MsDataset:
try:
api.on_dataset_download(
dataset_name=download_dataset, namespace=namespace)
api.dataset_download_uv(
dataset_name=download_dataset, namespace=namespace)
except Exception as e:
logger.error(e)



modelscope/outputs/outputs.py (+19, -15)

@@ -69,11 +69,23 @@ TASK_OUTPUTS = {
# face 2d keypoint result for single sample
# {
# "keypoints": [
# [x1, y1]*106
# [[x, y]*106],
# [[x, y]*106],
# [[x, y]*106],
# ],
# "poses": [pitch, roll, yaw]
# "poses": [
# [pitch, roll, yaw],
# [pitch, roll, yaw],
# [pitch, roll, yaw],
# ],
# "boxes": [
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# ]
# }
Tasks.face_2d_keypoints: [OutputKeys.KEYPOINTS, OutputKeys.POSES],
Tasks.face_2d_keypoints:
[OutputKeys.KEYPOINTS, OutputKeys.POSES, OutputKeys.BOXES],

# face detection result for single sample
# {
@@ -479,17 +491,8 @@ TASK_OUTPUTS = {
# word segmentation result for single sample
# {
# "output": "今天 天气 不错 , 适合 出去 游玩"
# "labels": [
# {'word': '今天', 'label': 'PROPN'},
# {'word': '天气', 'label': 'PROPN'},
# {'word': '不错', 'label': 'VERB'},
# {'word': ',', 'label': 'NUM'},
# {'word': '适合', 'label': 'NOUN'},
# {'word': '出去', 'label': 'PART'},
# {'word': '游玩', 'label': 'ADV'},
# ]
# }
Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS],
Tasks.word_segmentation: [OutputKeys.OUTPUT],

# TODO @wenmeng.zwm support list of result check
# named entity recognition result for single sample
@@ -699,8 +702,9 @@ TASK_OUTPUTS = {
# "text_embedding": np.array with shape [1, D],
# "caption": "this is an image caption text."
# }
Tasks.generative_multi_modal_embedding:
[OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION],
Tasks.generative_multi_modal_embedding: [
OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION
],

# multi-modal similarity result for single sample
# {


modelscope/pipelines/base.py (+4, -1)

@@ -10,6 +10,7 @@ from typing import Any, Dict, Generator, List, Mapping, Union

import numpy as np

from modelscope.hub.utils.utils import create_library_statistics
from modelscope.models.base import Model
from modelscope.msdatasets import MsDataset
from modelscope.outputs import TASK_OUTPUTS
@@ -151,7 +152,9 @@ class Pipeline(ABC):
**kwargs) -> Union[Dict[str, Any], Generator]:
# model provider should leave it as it is
# modelscope library developer will handle this function

for single_model in self.models:
if hasattr(single_model, 'name'):
create_library_statistics('pipeline', single_model.name, None)
# place model to cpu or gpu
if (self.model or (self.has_multiple_models and self.models[0])):
if not self._model_prepare:


+ 2
- 3
modelscope/pipelines/builder.py View File

@@ -93,9 +93,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_resnet50_live-category'),
Tasks.video_category: (Pipelines.video_category,
'damo/cv_resnet50_video-category'),
Tasks.multi_modal_embedding:
(Pipelines.multi_modal_embedding,
'damo/multi-modal_clip-vit-large-patch14_zh'),
Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding,
'damo/multi-modal_clip-vit-base-patch16_zh'),
Tasks.generative_multi_modal_embedding:
(Pipelines.generative_multi_modal_embedding,
'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'


+ 205
- 7
modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py View File

@@ -1,12 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math
from typing import Any

import cv2
import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .base import EasyCVPipeline

logger = get_logger()


@PIPELINES.register_module(
Tasks.face_2d_keypoints, module_name=Pipelines.face_2d_keypoints)
@@ -29,18 +39,206 @@ class Face2DKeypointsPipeline(EasyCVPipeline):
*args,
**kwargs)

# face detect pipeline
det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps'
self.face_detection = pipeline(
Tasks.face_detection, model=det_model_id)

def show_result(self, img, points, scale=2, save_path=None):
return self.predict_op.show_result(img, points, scale, save_path)

def _choose_face(self, det_result, min_face=10):
"""
        Keep detected faces whose width and height are both at least min_face.
        Args:
            det_result: output of the face detection pipeline
            min_face: minimum width/height (in pixels) of a valid face
"""
bboxes = np.array(det_result[OutputKeys.BOXES])
landmarks = np.array(det_result[OutputKeys.KEYPOINTS])
if bboxes.shape[0] == 0:
logger.warn('No face detected!')
return None
# face idx with enough size
face_idx = []
for i in range(bboxes.shape[0]):
box = bboxes[i]
if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face:
face_idx += [i]
if len(face_idx) == 0:
logger.warn(
f'Face size not enough, less than {min_face}x{min_face}!')
return None
bboxes = bboxes[face_idx]
landmarks = landmarks[face_idx]

return bboxes, landmarks

def expend_box(self, box, w, h, scalex=0.3, scaley=0.5):
x1 = box[0]
y1 = box[1]
wb = box[2] - x1
hb = box[3] - y1
deltax = int(wb * scalex)
deltay1 = int(hb * scaley)
deltay2 = int(hb * scalex)
x1 = x1 - deltax
y1 = y1 - deltay1
if x1 < 0:
deltax = deltax + x1
x1 = 0
if y1 < 0:
deltay1 = deltay1 + y1
y1 = 0
x2 = x1 + wb + 2 * deltax
y2 = y1 + hb + deltay1 + deltay2
x2 = np.clip(x2, 0, w - 1)
y2 = np.clip(y2, 0, h - 1)
return [x1, y1, x2, y2]

def rotate_point(self, angle, center, landmark):
rad = angle * np.pi / 180.0
alpha = np.cos(rad)
beta = np.sin(rad)
M = np.zeros((2, 3), dtype=np.float32)
M[0, 0] = alpha
M[0, 1] = beta
M[0, 2] = (1 - alpha) * center[0] - beta * center[1]
M[1, 0] = -beta
M[1, 1] = alpha
M[1, 2] = beta * center[0] + (1 - alpha) * center[1]

landmark_ = np.asarray([(M[0, 0] * x + M[0, 1] * y + M[0, 2],
M[1, 0] * x + M[1, 1] * y + M[1, 2])
for (x, y) in landmark])
return M, landmark_

def rotate_crop_img(self, img, pts, M):
imgT = cv2.warpAffine(img, M, (int(img.shape[1]), int(img.shape[0])))

x1 = pts[5][0]
x2 = pts[5][0]
y1 = pts[5][1]
y2 = pts[5][1]
for i in range(0, 9):
x1 = min(x1, pts[i][0])
x2 = max(x2, pts[i][0])
y1 = min(y1, pts[i][1])
y2 = max(y2, pts[i][1])

height, width, _ = imgT.shape
x1 = min(max(0, int(x1)), width)
y1 = min(max(0, int(y1)), height)
x2 = min(max(0, int(x2)), width)
y2 = min(max(0, int(y2)), height)
sub_imgT = imgT[y1:y2, x1:x2]

return sub_imgT, imgT, [x1, y1, x2, y2]

def crop_img(self, imgT, pts):
enlarge_ratio = 1.1

x1 = np.min(pts[:, 0])
x2 = np.max(pts[:, 0])
y1 = np.min(pts[:, 1])
y2 = np.max(pts[:, 1])
w = x2 - x1 + 1
h = y2 - y1 + 1
x1 = int(x1 - (enlarge_ratio - 1.0) / 2.0 * w)
y1 = int(y1 - (enlarge_ratio - 1.0) / 2.0 * h)
x1 = max(0, x1)
y1 = max(0, y1)

new_w = int(enlarge_ratio * w)
new_h = int(enlarge_ratio * h)
new_x1 = x1
new_y1 = y1
new_x2 = new_x1 + new_w
new_y2 = new_y1 + new_h

height, width, _ = imgT.shape

new_x1 = min(max(0, new_x1), width)
new_y1 = min(max(0, new_y1), height)
new_x2 = max(min(width, new_x2), 0)
new_y2 = max(min(height, new_y2), 0)

sub_imgT = imgT[new_y1:new_y2, new_x1:new_x2]

return sub_imgT, [new_x1, new_y1, new_x2, new_y2]

def __call__(self, inputs) -> Any:
outputs = self.predict_op(inputs)
img = LoadImage.convert_to_ndarray(inputs)
h, w, c = img.shape
img_rgb = copy.deepcopy(img)
img_rgb = img_rgb[:, :, ::-1]
det_result = self.face_detection(img_rgb)

bboxes = np.array(det_result[OutputKeys.BOXES])
if bboxes.shape[0] == 0:
logger.warn('No face detected!')
results = {
OutputKeys.KEYPOINTS: [],
OutputKeys.POSES: [],
OutputKeys.BOXES: []
}
return results

boxes, keypoints = self._choose_face(det_result)

output_boxes = []
output_keypoints = []
output_poses = []
for index, box_ori in enumerate(boxes):
box = self.expend_box(box_ori, w, h, scalex=0.1, scaley=0.1)
y0 = int(box[1])
y1 = int(box[3])
x0 = int(box[0])
x1 = int(box[2])
sub_img = img[y0:y1, x0:x1]

keypoint = keypoints[index]
pts = [[keypoint[0], keypoint[1]], [keypoint[2], keypoint[3]],
[keypoint[4], keypoint[5]], [keypoint[6], keypoint[7]],
[keypoint[8], keypoint[9]], [box[0], box[1]],
[box[2], box[1]], [box[0], box[3]], [box[2], box[3]]]
            # angle in radians
angle = math.atan2((pts[1][1] - pts[0][1]),
(pts[1][0] - pts[0][0]))
            # convert to degrees
theta = angle * (180 / np.pi)

center = [w // 2, h // 2]
cx, cy = center
M, landmark_ = self.rotate_point(theta, (cx, cy), pts)
sub_imgT, imgT, bbox = self.rotate_crop_img(img, landmark_, M)

outputs = self.predict_op([sub_imgT])[0]
tmp_keypoints = outputs['point']

for idx in range(0, len(tmp_keypoints)):
tmp_keypoints[idx][0] += bbox[0]
tmp_keypoints[idx][1] += bbox[1]

for idx in range(0, 6):
sub_img, bbox = self.crop_img(imgT, tmp_keypoints)
outputs = self.predict_op([sub_img])[0]
tmp_keypoints = outputs['point']
for idx in range(0, len(tmp_keypoints)):
tmp_keypoints[idx][0] += bbox[0]
tmp_keypoints[idx][1] += bbox[1]

M2, tmp_keypoints = self.rotate_point(-theta, (cx, cy),
tmp_keypoints)

results = [{
OutputKeys.KEYPOINTS: output['point'],
OutputKeys.POSES: output['pose']
} for output in outputs]
output_keypoints.append(np.array(tmp_keypoints))
output_poses.append(np.array(outputs['pose']))
output_boxes.append(np.array(box_ori))

if self._is_single_inputs(inputs):
results = results[0]
results = {
OutputKeys.KEYPOINTS: output_keypoints,
OutputKeys.POSES: output_poses,
OutputKeys.BOXES: output_boxes
}

return results
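
For reference, rotate_point builds the same 2x3 affine matrix as OpenCV's cv2.getRotationMatrix2D with scale 1.0; a small self-contained check under arbitrary sample values:

    import cv2
    import numpy as np

    angle, center = 30.0, (64.0, 48.0)  # sample values: degrees and (cx, cy)
    rad = angle * np.pi / 180.0
    alpha, beta = np.cos(rad), np.sin(rad)
    M_manual = np.array(
        [[alpha, beta, (1 - alpha) * center[0] - beta * center[1]],
         [-beta, alpha, beta * center[0] + (1 - alpha) * center[1]]])
    M_cv = cv2.getRotationMatrix2D(center, angle, 1.0)
    assert np.allclose(M_manual, M_cv)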

+ 15
- 7
modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py View File

@@ -1,10 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict
from typing import Any, Dict, Optional, Union

from modelscope.metainfo import Pipelines
from modelscope.models.multi_modal.clip.model import CLIPForMultiModalEmbedding
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.multi_modal import CLIPPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

@@ -17,7 +19,10 @@ logger = get_logger()
Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding)
class MultiModalEmbeddingPipeline(Pipeline):

def __init__(self, model: str, device: str = 'gpu'):
def __init__(self,
model: Union[Model, str],
preprocessor: Optional[Preprocessor] = None,
**kwargs):
"""
        Use `model` and `preprocessor` to create a multi-modal embedding pipeline for prediction
Args:
@@ -29,14 +34,17 @@ class MultiModalEmbeddingPipeline(Pipeline):
pipe_model = model
else:
raise NotImplementedError('model must be a single str')
pipe_model.eval()
if preprocessor is None:
if isinstance(pipe_model, CLIPForMultiModalEmbedding):
preprocessor = CLIPPreprocessor(pipe_model.model_dir)
else:
raise NotImplementedError

super().__init__(model=pipe_model)

def preprocess(self, input: Input) -> Dict[str, Any]:
return input
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
return self.model(input)
return self.model(self.preprocess(input))

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

+ 2
- 2
modelscope/pipelines/nlp/token_classification_pipeline.py View File

@@ -109,13 +109,13 @@ class TokenClassificationPipeline(Pipeline):
chunk['span'] = text[chunk['start']:chunk['end']]
chunks.append(chunk)

# for cws output
# for cws outputs
if len(chunks) > 0 and chunks[0]['type'] == 'cws':
spans = [
chunk['span'] for chunk in chunks if chunk['span'].strip()
]
seg_result = ' '.join(spans)
outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []}
outputs = {OutputKeys.OUTPUT: seg_result}

# for ner outputs
else:


+ 3
- 3
modelscope/pipelines/nlp/word_segmentation_pipeline.py View File

@@ -115,15 +115,15 @@ class WordSegmentationPipeline(Pipeline):
chunk['span'] = text[chunk['start']:chunk['end']]
chunks.append(chunk)

# for cws output
# for cws outputs
if len(chunks) > 0 and chunks[0]['type'] == 'cws':
spans = [
chunk['span'] for chunk in chunks if chunk['span'].strip()
]
seg_result = ' '.join(spans)
outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []}
outputs = {OutputKeys.OUTPUT: seg_result}

# for ner outpus
# for ner output
else:
outputs = {OutputKeys.OUTPUT: chunks}
return outputs

+ 178
- 2
modelscope/preprocessors/multi_modal.py View File

@@ -3,8 +3,11 @@ import os.path as osp
from io import BytesIO
from typing import Any, Dict, List, Tuple, Union

import json
import torch
from PIL import Image
from timm.data import create_transform
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Preprocessors
@@ -74,7 +77,7 @@ class OfaPreprocessor(Preprocessor):
data[key] = item
return data

def _ofa_input_compatibility_conversion(self, data):
def _ofa_input_compatibility_conversion(self, data): # fake
if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
if isinstance(data['image'], str):
image = load_image(data['image'])
@@ -93,7 +96,6 @@ class OfaPreprocessor(Preprocessor):
data = input
else:
data = self._build_dict(input)
data = self._ofa_input_compatibility_conversion(data)
sample = self.preprocess(data)
str_data = dict()
for k, v in data.items():
@@ -107,6 +109,180 @@ class OfaPreprocessor(Preprocessor):
eos_idx=self.tokenizer.eos_token_id)


def _convert_to_rgb(image):
return image.convert('RGB')


@PREPROCESSORS.register_module(
Fields.multi_modal, module_name=Preprocessors.clip_preprocessor)
class CLIPPreprocessor(Preprocessor):

def __init__(self,
model_dir: str,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
model_dir (str): model path
mode: preprocessor mode (model mode)
"""
super().__init__(*args, **kwargs)
model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
model_dir)
self.mode = mode
# text tokenizer
from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer
if 'tokenizer' in kwargs and isinstance(kwargs['tokenizer'],
FullTokenizer):
self.tokenizer = kwargs['tokenizer']
else:
vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}'
self.tokenizer = FullTokenizer(vocab_file=vocab_file)
# image preprocessor
if 'resolution' in kwargs and isinstance(kwargs['resolution'], int):
self.image_resolution = kwargs['resolution']
else:
self.image_resolution = json.load(
open('{}/vision_model_config.json'.format(
model_dir)))['image_resolution']
self.img_preprocess = self._build_image_transform()
# key mapping
        # specify the input keys; training and inference data may use different key names
self.input_keys = {'img': 'img', 'text': 'text'}

def _build_image_transform(self):

if self.mode == ModeKeys.TRAIN:
transform = create_transform(
input_size=self.image_resolution,
scale=(0.9, 1.0),
is_training=True,
color_jitter=None,
auto_augment='original',
interpolation='bicubic',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
)
transform = Compose(transform.transforms[:-3] + [_convert_to_rgb]
+ transform.transforms[-3:])
else:
transform = Compose([
Resize((self.image_resolution, self.image_resolution),
interpolation=Image.BICUBIC),
_convert_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
return transform

def tokenize(self,
texts: Union[str, List[str]],
context_length: int = 52) -> torch.LongTensor:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
            The context length to use; the models here use 52 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]

all_tokens = []
for text in texts:
all_tokens.append(
[self.tokenizer.vocab['[CLS]']]
+ self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize(text))[:context_length - 2]
+ [self.tokenizer.vocab['[SEP]']])

result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

for i, tokens in enumerate(all_tokens):
assert len(tokens) <= context_length
result[i, :len(tokens)] = torch.tensor(tokens)

return result

def set_input_img_key(self, new_key: str):
self.input_keys['img'] = new_key

def set_input_text_key(self, new_key: str):
self.input_keys['text'] = new_key

def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args,
**kwargs) -> Dict[str, Any]:
output = {}
# preprocess the image input
input_img_key = self.input_keys['img']
if input_img_key in input and input[input_img_key] is not None:
image_input = input[input_img_key]

# single image input
if isinstance(image_input, Image.Image):
image_tensor = self.img_preprocess(image_input).unsqueeze(0)
# multi images input
elif isinstance(image_input, list):
if all([isinstance(elem, Image.Image)
for elem in image_input]):
image_tensor = torch.stack(
[self.img_preprocess(elem)
for elem in image_input], # noqa
dim=0) # noqa
else:
unsupported_elem_type = [
type(elem) for elem in image_input
if not isinstance(elem, Image.Image)
][0]
raise TypeError(
f'img should be PIL.Image or List[PIL.Image], \
but got a List containing one {unsupported_elem_type}'
)
# others
else:
raise TypeError(
f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}'
)
output['img'] = image_tensor

# preprocess the text input
input_text_key = self.input_keys['text']
if input_text_key in input and input[input_text_key] is not None:
text_input = input[input_text_key]

# single text input
if isinstance(text_input, str):
text_tensor = self.tokenize(text_input)
# multi texts input
elif isinstance(text_input, list):
if all([isinstance(elem, str) for elem in text_input]):
text_tensor = self.tokenize(text_input)
else:
unsupported_elem_type = [
type(elem) for elem in text_input
if not isinstance(elem, str)
][0]
raise TypeError(
f'text should be str or List[str], but got a List containing one {unsupported_elem_type}'
)
# others
else:
raise TypeError(
f'text should be str or List[str], but got {type(text_input)}'
)
output['text'] = text_tensor

return output


@PREPROCESSORS.register_module(
Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor)
class MPlugPreprocessor(Preprocessor):
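
A minimal usage sketch of the new CLIPPreprocessor, assuming the public Chinese CLIP base model id referenced elsewhere in this change and a local image 'demo.jpg' (both illustrative):

    from PIL import Image
    from modelscope.preprocessors.multi_modal import CLIPPreprocessor

    # model id and image path are illustrative
    preprocessor = CLIPPreprocessor('damo/multi-modal_clip-vit-base-patch16_zh')
    batch = preprocessor({'img': Image.open('demo.jpg'), 'text': ['皮卡丘', '一只猫']})
    print(batch['img'].shape)   # (1, 3, R, R), where R is the configured image_resolution
    print(batch['text'].shape)  # (2, 52), token ids padded to the context length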


+ 12
- 5
modelscope/preprocessors/nlp/nlp_base.py View File

@@ -34,6 +34,7 @@ class NLPBasePreprocessor(Preprocessor, ABC):
label=None,
label2id=None,
mode=ModeKeys.INFERENCE,
use_fast=None,
**kwargs):
"""The NLP preprocessor base class.

@@ -45,14 +46,18 @@ class NLPBasePreprocessor(Preprocessor, ABC):
label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping
if this mapping is not supplied.
mode: Run this preprocessor in either 'train'/'eval'/'inference' mode
            use_fast: whether to use the fast version of the tokenizer

"""
self.model_dir = model_dir
self.first_sequence = first_sequence
self.second_sequence = second_sequence
self.label = label

self.use_fast = kwargs.pop('use_fast', None)
if self.use_fast is None and os.path.isfile(
self.use_fast = use_fast
if self.use_fast is None and model_dir is None:
self.use_fast = False
elif self.use_fast is None and os.path.isfile(
os.path.join(model_dir, 'tokenizer_config.json')):
with open(os.path.join(model_dir, 'tokenizer_config.json'),
'r') as f:
@@ -61,8 +66,8 @@ class NLPBasePreprocessor(Preprocessor, ABC):
self.use_fast = False if self.use_fast is None else self.use_fast

self.label2id = label2id
if self.label2id is None:
self.label2id = parse_label_mapping(self.model_dir)
if self.label2id is None and model_dir is not None:
self.label2id = parse_label_mapping(model_dir)
super().__init__(mode, **kwargs)

@property
@@ -106,6 +111,7 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor):
label: str = 'label',
label2id: dict = None,
mode: str = ModeKeys.INFERENCE,
use_fast: bool = None,
**kwargs):
"""The NLP tokenizer preprocessor base class.

@@ -122,11 +128,12 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor):
- config.json label2id/id2label
- label_mapping.json
mode: Run this preprocessor in either 'train'/'eval'/'inference' mode, the behavior may be different.
            use_fast: whether to use the fast version of the tokenizer
kwargs: These kwargs will be directly fed into the tokenizer.
"""

super().__init__(model_dir, first_sequence, second_sequence, label,
label2id, mode)
label2id, mode, use_fast, **kwargs)
self.model_dir = model_dir
self.tokenize_kwargs = kwargs
self.tokenizer = self.build_tokenizer(model_dir)


+ 82
- 69
modelscope/preprocessors/nlp/token_classification_preprocessor.py View File

@@ -2,6 +2,7 @@

from typing import Any, Dict, Tuple, Union

import numpy as np
import torch

from modelscope.metainfo import Preprocessors
@@ -20,9 +21,7 @@ class WordSegmentationBlankSetToLabelPreprocessor(NLPBasePreprocessor):
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.first_sequence: str = kwargs.pop('first_sequence', 'tokens')
self.label = kwargs.pop('label', OutputKeys.LABELS)

def __call__(self, data: str) -> Union[Dict[str, Any], Tuple]:
@@ -80,10 +79,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
'is_split_into_words', False)
if 'label2id' in kwargs:
kwargs.pop('label2id')
self.tokenize_kwargs = kwargs

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
@type_assert(object, (str, dict))
def __call__(self, data: Union[dict, str]) -> Dict[str, Any]:
"""process the raw input data

Args:
@@ -99,18 +97,24 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
text = None
labels_list = None
if isinstance(data, str):
# for inference inputs without label
text = data
self.tokenize_kwargs['add_special_tokens'] = False
elif isinstance(data, dict):
# for finetune inputs with label
text = data.get(self.first_sequence)
labels_list = data.get(self.label)
if isinstance(text, list):
self.tokenize_kwargs['is_split_into_words'] = True

input_ids = []
label_mask = []
offset_mapping = []
if self.is_split_into_words:
for offset, token in enumerate(list(data)):
subtoken_ids = self.tokenizer.encode(
token, add_special_tokens=False)
token_type_ids = []
if self.is_split_into_words and self._mode == ModeKeys.INFERENCE:
for offset, token in enumerate(list(text)):
subtoken_ids = self.tokenizer.encode(token,
**self.tokenize_kwargs)
if len(subtoken_ids) == 0:
subtoken_ids = [self.tokenizer.unk_token_id]
input_ids.extend(subtoken_ids)
@@ -119,10 +123,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
else:
if self.tokenizer.is_fast:
encodings = self.tokenizer(
text,
add_special_tokens=False,
return_offsets_mapping=True,
**self.tokenize_kwargs)
text, return_offsets_mapping=True, **self.tokenize_kwargs)
attention_mask = encodings['attention_mask']
token_type_ids = encodings['token_type_ids']
input_ids = encodings['input_ids']
word_ids = encodings.word_ids()
for i in range(len(word_ids)):
@@ -137,75 +140,85 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
label_mask.append(1)
offset_mapping.append(encodings['offset_mapping'][i])
else:
encodings = self.tokenizer(
text, add_special_tokens=False, **self.tokenize_kwargs)
encodings = self.tokenizer(text, **self.tokenize_kwargs)
input_ids = encodings['input_ids']
label_mask, offset_mapping = self.get_label_mask_and_offset_mapping(
text)

if len(input_ids) >= self.sequence_length - 2:
input_ids = input_ids[:self.sequence_length - 2]
label_mask = label_mask[:self.sequence_length - 2]
input_ids = [self.tokenizer.cls_token_id
] + input_ids + [self.tokenizer.sep_token_id]
label_mask = [0] + label_mask + [0]
attention_mask = [1] * len(input_ids)
offset_mapping = offset_mapping[:sum(label_mask)]
if self._mode == ModeKeys.INFERENCE:
if len(input_ids) >= self.sequence_length - 2:
input_ids = input_ids[:self.sequence_length - 2]
label_mask = label_mask[:self.sequence_length - 2]
input_ids = [self.tokenizer.cls_token_id
] + input_ids + [self.tokenizer.sep_token_id]
label_mask = [0] + label_mask + [0]
attention_mask = [1] * len(input_ids)
offset_mapping = offset_mapping[:sum(label_mask)]

if not self.is_transformer_based_model:
input_ids = input_ids[1:-1]
attention_mask = attention_mask[1:-1]
label_mask = label_mask[1:-1]
if not self.is_transformer_based_model:
input_ids = input_ids[1:-1]
attention_mask = attention_mask[1:-1]
label_mask = label_mask[1:-1]

if self._mode == ModeKeys.INFERENCE:
input_ids = torch.tensor(input_ids).unsqueeze(0)
attention_mask = torch.tensor(attention_mask).unsqueeze(0)
label_mask = torch.tensor(
label_mask, dtype=torch.bool).unsqueeze(0)

# the token classification
output = {
'text': text,
'input_ids': input_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
'offset_mapping': offset_mapping
}

# align the labels with tokenized text
if labels_list is not None:
assert self.label2id is not None
# Map that sends B-Xxx label to its I-Xxx counterpart
b_to_i_label = []
label_enumerate_values = [
k for k, v in sorted(
self.label2id.items(), key=lambda item: item[1])
]
for idx, label in enumerate(label_enumerate_values):
if label.startswith('B-') and label.replace(
'B-', 'I-') in label_enumerate_values:
b_to_i_label.append(
label_enumerate_values.index(
label.replace('B-', 'I-')))
else:
b_to_i_label.append(idx)
# the token classification
output = {
'text': text,
'input_ids': input_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
'offset_mapping': offset_mapping
}
else:
output = {
'input_ids': input_ids,
'token_type_ids': token_type_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
}

label_row = [self.label2id[lb] for lb in labels_list]
previous_word_idx = None
label_ids = []
for word_idx in word_ids:
if word_idx is None:
label_ids.append(-100)
elif word_idx != previous_word_idx:
label_ids.append(label_row[word_idx])
else:
if self.label_all_tokens:
label_ids.append(b_to_i_label[label_row[word_idx]])
# align the labels with tokenized text
if labels_list is not None:
assert self.label2id is not None
# Map that sends B-Xxx label to its I-Xxx counterpart
b_to_i_label = []
label_enumerate_values = [
k for k, v in sorted(
self.label2id.items(), key=lambda item: item[1])
]
for idx, label in enumerate(label_enumerate_values):
if label.startswith('B-') and label.replace(
'B-', 'I-') in label_enumerate_values:
b_to_i_label.append(
label_enumerate_values.index(
label.replace('B-', 'I-')))
else:
b_to_i_label.append(idx)

label_row = [self.label2id[lb] for lb in labels_list]
previous_word_idx = None
label_ids = []
for word_idx in word_ids:
if word_idx is None:
label_ids.append(-100)
previous_word_idx = word_idx
labels = label_ids
output['labels'] = labels
elif word_idx != previous_word_idx:
label_ids.append(label_row[word_idx])
else:
if self.label_all_tokens:
label_ids.append(b_to_i_label[label_row[word_idx]])
else:
label_ids.append(-100)
previous_word_idx = word_idx
labels = label_ids
output['labels'] = labels
output = {
k: np.array(v) if isinstance(v, list) else v
for k, v in output.items()
}
return output

def get_tokenizer_class(self):
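
The finetune branch aligns labels to sub-tokens through word_ids(): the first sub-token of a word keeps the word's label and later sub-tokens get -100 (or the matching I- label when label_all_tokens is set). A small sketch of that mapping, assuming a fast HuggingFace tokenizer; 'bert-base-cased' and the tiny label2id are illustrative, the real preprocessor uses the model's own tokenizer and label mapping:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')  # illustrative fast tokenizer
    tokens, labels = ['New', 'York', 'skyline'], ['B-LOC', 'I-LOC', 'O']
    label2id = {'O': 0, 'B-LOC': 1, 'I-LOC': 2}

    encodings = tokenizer(tokens, is_split_into_words=True)
    label_ids, previous_word_idx = [], None
    for word_idx in encodings.word_ids():
        if word_idx is None:
            label_ids.append(-100)  # special tokens such as [CLS]/[SEP]
        elif word_idx != previous_word_idx:
            label_ids.append(label2id[labels[word_idx]])  # first sub-token of the word
        else:
            label_ids.append(-100)  # later sub-tokens (label_all_tokens=False)
        previous_word_idx = word_idx
    print(tokenizer.convert_ids_to_tokens(encodings['input_ids']))
    print(label_ids)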


+ 9
- 15
modelscope/preprocessors/ofa/ocr_recognition.py View File

@@ -2,12 +2,12 @@
from typing import Any, Dict

import torch
from PIL import Image
import unicodedata2
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
from zhconv import convert

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor

@@ -73,21 +73,14 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
"""
super(OfaOcrRecognitionPreprocessor,
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
if self.cfg.model.imagenet_default_mean_and_std:
mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD
else:
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

self.patch_resize_transform = transforms.Compose([
lambda image: ocr_resize(
image,
self.cfg.model.patch_image_size,
is_document=self.cfg.model.is_document),
self.patch_image_size,
is_document=self.cfg.model.get('is_document', False)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
transforms.Normalize(mean=self.mean, std=self.std),
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
@@ -98,8 +91,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = data[self.column_map['text']]
target = target.translate(self.transtab).strip()
target = sample['label']
target_token_list = target.strip().split()
target = ' '.join(target_token_list[:self.max_tgt_length])
sample['target'] = self.tokenize_text(target, add_bos=False)
@@ -119,5 +111,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
'patch_mask': torch.tensor([True])
}
if 'text' in self.column_map and self.column_map['text'] in data:
sample['label'] = data[self.column_map['text']]
target = data[self.column_map['text']]
target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
sample['label'] = target
return sample
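
The OCR labels are now converted to simplified Chinese and NFKC-normalized before tokenization; a tiny sketch with an illustrative string:

    import unicodedata2
    from zhconv import convert

    target = '測試ＡＢＣ１２３'  # illustrative: traditional characters plus full-width ASCII
    print(unicodedata2.normalize('NFKC', convert(target, 'zh-hans')))
    # traditional -> simplified, full-width 'ＡＢＣ１２３' -> half-width 'ABC123'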

+ 18
- 0
modelscope/trainers/hooks/clip_clamp_logit_scale_hook.py View File

@@ -0,0 +1,18 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch

from modelscope.metainfo import Hooks
from modelscope.trainers.multi_modal.clip.clip_trainer import CLIPTrainer
from .builder import HOOKS
from .hook import Hook


@HOOKS.register_module(module_name=Hooks.ClipClampLogitScaleHook)
class ClipClampLogitScaleHook(Hook):
"""ClipClampLogitScaleHook hook which performs clamp on CLIP logit scale parameter after update"""

def after_train_iter(self, trainer: CLIPTrainer):
"""Called after every training iter to evaluate the results."""
unwrapped_model = getattr(trainer.model, 'module', trainer.model)
logit_scale = unwrapped_model.clip_model.logit_scale
logit_scale.data = torch.clamp(logit_scale.data, 0, 4.6052)
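
The upper bound 4.6052 is ln(100), so after exp() the learned logit scale (the inverse softmax temperature) is capped at 100, as in the original CLIP recipe; a one-line check:

    import math
    assert abs(math.log(100.0) - 4.6052) < 1e-4  # 4.60517... rounded to 4.6052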

+ 1
- 1
modelscope/trainers/hooks/logger/text_logger_hook.py View File

@@ -61,7 +61,7 @@ class TextLoggerHook(LoggerHook):
self.json_log_path = osp.join(self.out_dir,
'{}.log.json'.format(trainer.timestamp))
if hasattr(trainer, 'meta') and trainer.meta is not None:
self._dump_log(trainer.meta, trainer)
self._dump_log(trainer.meta)

def _get_max_memory(self, trainer):
device = getattr(trainer.model, 'output_device', None)


+ 191
- 154
modelscope/trainers/multi_modal/clip/clip_trainer.py View File

@@ -1,169 +1,206 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import math
import os
from typing import Dict, Optional
from typing import Callable, Dict, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch import distributed as dist
from torch import nn
from torch.utils.data import Dataset

from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.trainers.base import BaseTrainer
from modelscope.models.base import Model, TorchModel
from modelscope.models.multi_modal.clip.model import convert_models_to_fp32
from modelscope.msdatasets.ms_dataset import MsDataset
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.multi_modal import CLIPPreprocessor
from modelscope.trainers import EpochBasedTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.optimizer.builder import build_optimizer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModeKeys
from modelscope.utils.logger import get_logger
from .clip_trainer_utils import ImageWithCaptionDataset, get_optimizer
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys,
ModeKeys)
from .clip_trainer_utils import get_loss, get_optimizer_params, get_schedule

logger = get_logger()

def exclude(n):
return 'bn' in n or 'ln' in n or 'bias' in n or 'logit_scale' in n


def include(n):
return not exclude(n)


@TRAINERS.register_module(module_name=Trainers.clip_multi_modal_embedding)
class CLIPTrainer(BaseTrainer):

def __init__(self, cfg_file: str, model: str, device_id: int, *args,
**kwargs):
super().__init__(cfg_file)

self.cfg = Config.from_file(cfg_file)
self.model = Model.from_pretrained(model)
self.device_id = device_id
self.total_epoch = self.cfg.train.epoch
self.train_batch_size = self.cfg.train.batch_size
self.val_batch_size = self.cfg.evaluation.batch_size
self.ckpt_dir = self.cfg.train.ckpt_dir

self.train_dataset = ImageWithCaptionDataset(
json_file='{}/{}'.format(self.cfg.dataset.root_dir,
self.cfg.dataset.train_set),
img_dir=self.cfg.dataset.root_dir,
phase=ModeKeys.TRAIN)
self.val_dataset = ImageWithCaptionDataset(
json_file='{}/{}'.format(self.cfg.dataset.root_dir,
self.cfg.dataset.val_set),
img_dir=self.cfg.dataset.root_dir,
phase=ModeKeys.EVAL)

def train(self, *args, **kwargs):
assert dist.is_initialized()

self.model.clip_model.train()
self.model.clip_model.to(self.device_id)
ddp_model = torch.nn.parallel.DistributedDataParallel(
self.model.clip_model, device_ids=[
self.device_id,
])

optimizer = get_optimizer(ddp_model)

for epoch in range(self.total_epoch):
train_sampler = DistributedSampler(
dataset=self.train_dataset, shuffle=True)
train_sampler.set_epoch(epoch)

train_params = {
'pin_memory': True,
'collate_fn': None,
'batch_size': self.train_batch_size,
'shuffle': False,
'drop_last': True,
'sampler': train_sampler,
'num_workers': 8
class CLIPTrainer(EpochBasedTrainer):

def __init__(
self,
model: Optional[Union[TorchModel, nn.Module, str]] = None,
cfg_file: Optional[str] = None,
arg_parse_fn: Optional[Callable] = None,
data_collator: Optional[Union[Callable, Dict[str,
Callable]]] = None,
train_dataset: Optional[Union[MsDataset, Dataset]] = None,
eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
preprocessor: Optional[Union[Preprocessor,
Dict[str, Preprocessor]]] = None,
optimizers: Tuple[torch.optim.Optimizer,
torch.optim.lr_scheduler._LRScheduler] = (None,
None),
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
seed: int = 42,
**kwargs):
model = Model.from_pretrained(model, revision=model_revision)
# for training & eval, we convert the model from FP16 back to FP32
        # to be compatible with modelscope amp training
convert_models_to_fp32(model)
cfg = Config.from_file(cfg_file)
if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0:
work_dir = cfg.train.work_dir
else:
work_dir = kwargs['work_dir']

# fetch the model name of CLIP model (base, large or large-336)
model_name = cfg.pretrained_model.model_name

# world size
world_size = int(os.environ.get('WORLD_SIZE', 1))

# train step, optimizer and lr_scheduler
epoch_steps = math.ceil(
len(train_dataset) / # noqa
(cfg.train.dataloader.batch_size_per_gpu * world_size)) # noqa
cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs

if optimizers[0] is None:
named_parameters = list(model.named_parameters())
gain_or_bias_params = [
p for n, p in named_parameters
if exclude(n) and p.requires_grad
]
rest_params = [
p for n, p in named_parameters
if include(n) and p.requires_grad
]
optimizer_hparams = get_optimizer_params(
model_name, cfg) # lr, wd, beta1, beta2, eps
optimizer_args = {
'params': [
{
'params': gain_or_bias_params,
'weight_decay': 0.
},
{
'params': rest_params,
'weight_decay': optimizer_hparams['weight_decay']
},
],
'lr':
optimizer_hparams['lr'],
'betas':
(optimizer_hparams['beta1'], optimizer_hparams['beta2']),
'eps':
optimizer_hparams['eps'],
}
optimizer = build_optimizer(
model, cfg=cfg.train.optimizer, default_args=optimizer_args)
else:
optimizer = optimizers[0]

if optimizers[1] is None:
lr_scheduler = get_schedule(optimizer, cfg.train.lr_scheduler)
else:
lr_scheduler = optimizers[1]
optimizers = (optimizer, lr_scheduler)

# loss module
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
self.loss_img = loss_img.cuda(int(os.environ.get('LOCAL_RANK', 0)))
self.loss_txt = loss_txt.cuda(int(os.environ.get('LOCAL_RANK', 0)))
self.loss_cfg = cfg.train.loss_cfg

# launcher and use_fp16
if 'launcher' not in kwargs and cfg.train.get('launcher', None):
kwargs['launcher'] = cfg.train.launcher
if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False):
kwargs['use_fp16'] = cfg.train.use_fp16

# preprocessor
if preprocessor is None:
preprocessor = {
ConfigKeys.train:
CLIPPreprocessor(
model_dir=work_dir,
mode=ModeKeys.TRAIN,
tokenizer=model.tokenizer,
resolution=model.model_info['image_resolution']),
ConfigKeys.val:
CLIPPreprocessor(
model_dir=work_dir,
mode=ModeKeys.EVAL,
tokenizer=model.tokenizer,
resolution=model.model_info['image_resolution']),
}

train_loader = DataLoader(self.train_dataset, **train_params)

for batch_idx, (img_tensor, text_str_list,
img_id_list) in enumerate(train_loader):
text_info_list = [
self.model.tokenize_text(tmp) for tmp in text_str_list
]
text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list],
dim=0)
text_masks_tensor = torch.cat(
[tmp[1] for tmp in text_info_list], dim=0)

img_tensor = img_tensor.to(self.device_id, non_blocking=True)
img_id_list = img_id_list.to(self.device_id, non_blocking=True)
text_ids_tensor = text_ids_tensor.to(
self.device_id, non_blocking=True)
text_masks_tensor = text_masks_tensor.to(
self.device_id, non_blocking=True)

loss = ddp_model((img_tensor, text_ids_tensor,
text_masks_tensor, img_id_list),
ModeKeys.TRAIN)

optimizer.zero_grad()
loss.backward()
optimizer.step()

if batch_idx % 10 == 0:
logger.info(
'epoch: {}, train batch {}/{}, loss={:.5f}, logit_scale={:.5f}'
.format(epoch, batch_idx, len(train_loader),
loss.item(),
ddp_model.module.logit_scale.exp().item()))
if dist.get_rank() == 0:
os.makedirs(self.ckpt_dir, exist_ok=True)
torch.save(ddp_model.module.state_dict(),
'{}/epoch{}.pth'.format(self.ckpt_dir, epoch))

def evaluate(self,
checkpoint_path: Optional[str] = None,
*args,
**kwargs) -> Dict[str, float]:
if checkpoint_path is not None:
checkpoint_params = torch.load(checkpoint_path, 'cpu')
self.model.clip_model.load_state_dict(checkpoint_params)
self.model.clip_model.eval()
self.model.clip_model.to(self.device_id)

val_params = {
'collate_fn': None,
'batch_size': self.val_batch_size,
'shuffle': False,
'drop_last': False,
'num_workers': 8
}
val_loader = DataLoader(self.val_dataset, **val_params)

tp_cnt_per_batch = []
processed_cnt = 0
with torch.no_grad():
for batch_idx, (img_tensor, text_str_list,
img_id_list) in enumerate(val_loader):
text_info_list = [
self.model.tokenize_text(tmp) for tmp in text_str_list
]
text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list],
dim=0)
text_masks_tensor = torch.cat(
[tmp[1] for tmp in text_info_list], dim=0)

img_tensor = img_tensor.to(self.device_id, non_blocking=True)
img_id_list = img_id_list.to(self.device_id, non_blocking=True)
text_ids_tensor = text_ids_tensor.to(
self.device_id, non_blocking=True)
text_masks_tensor = text_masks_tensor.to(
self.device_id, non_blocking=True)

img_feat = self.model.clip_model(img_tensor, input_type='img')
text_feat = self.model.clip_model(
(text_ids_tensor, text_masks_tensor), input_type='text')

sim_mat = text_feat @ img_feat.t()
text_cnt, img_cnt = sim_mat.shape
top1_scores, match_ids = torch.max(sim_mat, dim=1)

match_ids = match_ids.int()
gt_ids = torch.tensor(range(0, text_cnt)).to(
self.device_id, non_blocking=True).int()
error_cnt = torch.nonzero(match_ids - gt_ids)
processed_cnt += text_cnt

tp_cnt_per_batch.append(text_cnt - 1.0 * error_cnt.numel())
logger.info('current acc: {:.3f}'.format(
sum(tp_cnt_per_batch) / processed_cnt))
# dataset related
self.dataset_cfg = cfg.dataset
if hasattr(self.dataset_cfg, 'column_map'):
# cases where dataset key names are not "img" and "text"
img_key_name = getattr(self.dataset_cfg.column_map, 'img', 'img')
preprocessor[ConfigKeys.train].set_input_img_key(img_key_name)
preprocessor[ConfigKeys.val].set_input_img_key(img_key_name)
text_key_name = getattr(self.dataset_cfg.column_map, 'text',
'text')
preprocessor[ConfigKeys.train].set_input_text_key(text_key_name)
preprocessor[ConfigKeys.val].set_input_text_key(text_key_name)
self.global_batch_size = cfg.train.dataloader.batch_size_per_gpu * world_size

super().__init__(
model=model,
cfg_file=cfg_file,
arg_parse_fn=arg_parse_fn,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
preprocessor=preprocessor,
optimizers=optimizers,
seed=seed,
**kwargs,
)

def train_step(self, model, inputs):
model.train()
inputs['mode'] = ModeKeys.TRAIN
model_outputs = model.forward(
inputs
) # {OutputKeys.IMG_EMBEDDING: Tensor(batch_size, dim), OutputKeys.TEXT_EMBEDDING: Tensor(batch_size, dim)}
loss = get_loss(model_outputs, self.loss_img, self.loss_txt,
self.loss_cfg)
train_outputs = {'loss': loss}
# add model output info to log
if 'log_vars' not in train_outputs:
default_keys_pattern = ['loss']
match_keys = set([])
for key_p in default_keys_pattern:
match_keys.update(
[key for key in train_outputs.keys() if key_p in key])
log_vars = {}
for key in match_keys:
value = train_outputs.get(key, None)
if value is not None:
if dist.is_available() and dist.is_initialized():
value = value.data.clone()
dist.all_reduce(value.div_(dist.get_world_size()))
log_vars.update({key: value.item()})
unwrapped_model = getattr(model, 'module', model)
log_vars[
'logit_scale'] = unwrapped_model.clip_model.logit_scale.data.clone(
).item() # noqa
log_vars['global_batch_size'] = int(self.global_batch_size)
self.log_buffer.update(log_vars)
else:
self.log_buffer.update(train_outputs['log_vars'])
self.train_outputs = train_outputs
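
The exclude/include split feeds build_optimizer two parameter groups so that norm, bias and logit_scale parameters receive zero weight decay; a small sketch of the same grouping on an illustrative module (in the real CLIP model, LayerNorm parameters are matched via 'ln' in their names):

    from torch import nn

    def exclude(n):
        return 'bn' in n or 'ln' in n or 'bias' in n or 'logit_scale' in n

    model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))  # illustrative module
    no_decay = [n for n, p in model.named_parameters() if exclude(n) and p.requires_grad]
    decay = [n for n, p in model.named_parameters() if not exclude(n) and p.requires_grad]
    print(no_decay)  # ['0.bias', '1.bias'] -> weight_decay 0.0 group
    print(decay)     # ['0.weight', '1.weight'] -> weight_decay from optimizer_hparams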

+ 121
- 90
modelscope/trainers/multi_modal/clip/clip_trainer_utils.py View File

@@ -1,94 +1,125 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import math
import os
import random
from functools import partial
from inspect import unwrap

import json
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from modelscope.utils.constant import ModeKeys

train_transform = transforms.Compose([
transforms.RandomResizedCrop(
224, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711))
])

val_transform = transforms.Compose([
transforms.Resize((224, 224), interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711))
])


class ImageWithCaptionDataset(Dataset):

def __init__(self, json_file, img_dir, phase):
self.annotations = json.load(open(json_file))
self.img_dir = img_dir
if phase == ModeKeys.TRAIN:
self.transform = train_transform
elif phase == ModeKeys.EVAL:
self.transform = val_transform

self.img_name2img_id = {}
for anno_dict in self.annotations:
img_name = anno_dict['image']
if img_name not in self.img_name2img_id:
self.img_name2img_id[img_name] = len(self.img_name2img_id)

def __len__(self):
return len(self.annotations)

def __getitem__(self, index):
anno_dict = self.annotations[index]

img_path = os.path.join(self.img_dir, anno_dict['image'])
img_pil = Image.open(img_path).convert('RGB')
img_th = self.transform(img_pil)
img_id = self.img_name2img_id[anno_dict['image']]

text_str = random.choice(anno_dict['caption'])

return img_th, text_str, img_id


def get_params_groups(ddp_model, weight_decay):
decay = []
no_decay = []
for name, param in ddp_model.named_parameters():
if not param.requires_grad:
continue
if len(param.shape) == 1 or name.endswith('.bias'):
no_decay.append(param)
else:
decay.append(param)
params_groups = [{
'params': no_decay,
'weight_decay': 0.
}, {
'params': decay,
'weight_decay': weight_decay
}]
return params_groups


def get_optimizer(ddp_model):
from torch.optim import AdamW
lr_init = 1e-5
betas = [0.9, 0.999]
weight_decay = 0.02
params_groups = get_params_groups(ddp_model, weight_decay=weight_decay)
return AdamW(
params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay)
import torch.distributed as dist
from torch.optim.lr_scheduler import LambdaLR

from modelscope.outputs import OutputKeys


def get_optimizer_params(model_name, cfg):
# get default params
# Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
# base model
if model_name in ['damo/multi-modal_clip-vit-base-patch16_zh']:
params = {
'lr': 5.0e-4,
'beta1': 0.9,
'beta2': 0.98,
'eps': 1.0e-6,
'weight_decay': 0.0
}
# large models
elif model_name in [
'damo/multi-modal_clip-vit-large-patch14_zh',
'damo/multi-modal_clip-vit-large-patch14_336_zh'
]:
params = {
'lr': 4.0e-4,
'beta1': 0.9,
'beta2': 0.98,
'eps': 1.0e-6,
'weight_decay': 0.0
}
else:
params = {
'lr': 5.0e-4,
'beta1': 0.9,
'beta2': 0.999,
'eps': 1.0e-8,
'weight_decay': 0.0
}
# override with config params
for key in ['lr', 'beta1', 'beta2', 'eps', 'weight_decay']:
if hasattr(cfg.train, 'optimizer_hparams'):
params[key] = getattr(cfg.train.optimizer_hparams, key,
params[key])
return params


def get_loss(model_outputs, loss_img, loss_txt, loss_cfg):
image_features = model_outputs[OutputKeys.IMG_EMBEDDING]
text_features = model_outputs[OutputKeys.TEXT_EMBEDDING]
logit_scale = model_outputs['logit_scale']
logit_scale = logit_scale.mean()
if loss_cfg.aggregate and int(os.environ.get('WORLD_SIZE', 1)) > 1:
world_size = dist.get_world_size()
rank = dist.get_rank()

# We gather tensors from all gpus to get more negatives to contrast with.
gathered_image_features = [
torch.zeros_like(image_features) for _ in range(world_size)
]
gathered_text_features = [
torch.zeros_like(text_features) for _ in range(world_size)
]
dist.all_gather(gathered_image_features, image_features)
dist.all_gather(gathered_text_features, text_features)

all_image_features = torch.cat([image_features]
+ gathered_image_features[:rank]
+ gathered_image_features[rank + 1:])
all_text_features = torch.cat([text_features]
+ gathered_text_features[:rank]
+ gathered_text_features[rank + 1:])

# this is needed to send gradients back everywhere.
logits_per_image = logit_scale * all_image_features @ all_text_features.t(
)
logits_per_text = logits_per_image.t()

else:
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()

ground_truth = torch.arange(len(logits_per_image)).long()
ground_truth = ground_truth.cuda(
int(os.environ.get('LOCAL_RANK', 0)), non_blocking=True)

total_loss = (loss_img(logits_per_image, ground_truth)
+ loss_txt(logits_per_text, ground_truth)) / 2

return total_loss


def lr_lambda(num_warmup_steps, num_training_steps, num_cycles, current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps))
return max(
0.0,
0.5 * # noqa
(1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) # noqa


def get_schedule(optimizer,
scheduler,
num_cycles: float = 0.5,
last_epoch: int = -1):
num_warmup_steps = int(scheduler.warmup_proportion
* scheduler.num_train_steps)
num_training_steps = scheduler.num_train_steps

return LambdaLR(
optimizer,
partial(lr_lambda, num_warmup_steps, num_training_steps, num_cycles),
last_epoch)
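
get_schedule is linear warmup followed by half-cosine decay; a self-contained sketch of the multiplier values under assumed settings (10 warmup steps out of 100 total, num_cycles=0.5):

    import math

    def lr_lambda(num_warmup_steps, num_training_steps, num_cycles, current_step):
        # same multiplier as above, repeated here so the sketch runs standalone
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

    print([round(lr_lambda(10, 100, 0.5, step), 3) for step in (0, 5, 10, 55, 100)])
    # -> [0.0, 0.5, 1.0, 0.5, 0.0]: ramp up over warmup, then cosine decay to zero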

+ 11
- 11
modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py View File

@@ -103,20 +103,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):

def __init__(self, args):
super().__init__()
self.sentence_avg = args.sentence_avg
self.eps = args.label_smoothing
self.ignore_prefix_size = args.ignore_prefix_size
self.ignore_eos = args.ignore_eos
self.report_accuracy = args.report_accuracy
self.drop_worst_ratio = args.drop_worst_ratio
self.drop_worst_after = args.drop_worst_after
self.use_rdrop = args.use_rdrop
self.reg_alpha = args.reg_alpha
self.sample_patch_num = args.sample_patch_num
self.sentence_avg = args.get('sentence_avg', False)
self.eps = args.get('label_smoothing', 0.1)
self.ignore_prefix_size = args.get('ignore_prefix_size', 0)
self.ignore_eos = args.get('ignore_eos', False)
self.report_accuracy = args.get('report_accuracy', False)
self.drop_worst_ratio = args.get('drop_worst_ratio', 0.0)
self.drop_worst_after = args.get('drop_worst_after', 0)
self.use_rdrop = args.get('use_rdrop', False)
self.reg_alpha = args.get('reg_alpha', 1.0)
self.sample_patch_num = args.get('sample_patch_num', 196)

self.constraint_start = None
self.constraint_end = None
if args.constraint_range:
if args.get('constraint_range', None):
constraint_start, constraint_end = args.constraint_range.split(',')
self.constraint_start = int(constraint_start)
self.constraint_end = int(constraint_end)


+ 1
- 1
modelscope/trainers/nlp/text_generation_trainer.py View File

@@ -18,7 +18,7 @@ class TextGenerationTrainer(NlpEpochBasedTrainer):
return tokenizer.decode(tokens.tolist(), skip_special_tokens=True)

def evaluation_step(self, data):
model = self.model
model = self.model.module if self._dist else self.model
model.eval()

with torch.no_grad():


+ 4
- 2
modelscope/trainers/nlp_trainer.py View File

@@ -586,14 +586,16 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
preprocessor_mode=ModeKeys.TRAIN,
**model_args,
**self.train_keys,
mode=ModeKeys.TRAIN)
mode=ModeKeys.TRAIN,
use_fast=True)
eval_preprocessor = Preprocessor.from_pretrained(
self.model_dir,
cfg_dict=self.cfg,
preprocessor_mode=ModeKeys.EVAL,
**model_args,
**self.eval_keys,
mode=ModeKeys.EVAL)
mode=ModeKeys.EVAL,
use_fast=True)
return train_preprocessor, eval_preprocessor




+ 9
- 4
modelscope/trainers/trainer.py View File

@@ -15,6 +15,7 @@ from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.hub.utils.utils import create_library_statistics
from modelscope.metainfo import Trainers
from modelscope.metrics import build_metric, task_default_metrics
from modelscope.models.base import Model, TorchModel
@@ -183,7 +184,7 @@ class EpochBasedTrainer(BaseTrainer):
preprocessor=self.eval_preprocessor,
**kwargs)

self.train_data_collator, self.eval_default_collate = None, None
self.train_data_collator, self.eval_data_collator = None, None
if isinstance(data_collator, Mapping):
if not (ConfigKeys.train in data_collator
or ConfigKeys.val in data_collator):
@@ -436,6 +437,8 @@ class EpochBasedTrainer(BaseTrainer):

def train(self, checkpoint_path=None, *args, **kwargs):
self._mode = ModeKeys.TRAIN
if hasattr(self.model, 'name'):
create_library_statistics('train', self.model.name, None)

if self.train_dataset is None:
self.train_dataloader = self.get_train_dataloader()
@@ -456,6 +459,8 @@ class EpochBasedTrainer(BaseTrainer):
self.train_loop(self.train_dataloader)

def evaluate(self, checkpoint_path=None):
if hasattr(self.model, 'name'):
create_library_statistics('evaluate', self.model.name, None)
if checkpoint_path is not None and os.path.isfile(checkpoint_path):
from modelscope.trainers.hooks import CheckpointHook
CheckpointHook.load_checkpoint(checkpoint_path, self)
@@ -672,7 +677,7 @@ class EpochBasedTrainer(BaseTrainer):
self.model, cfg=cfg, default_args=default_args)
except KeyError as e:
self.logger.error(
f'Build optimizer error, the optimizer {cfg} is native torch optimizer, '
f'Build optimizer error, the optimizer {cfg} is a torch native component, '
                f'please check whether your torch version ({torch.__version__}) matches the config.'
)
raise e
@@ -682,7 +687,7 @@ class EpochBasedTrainer(BaseTrainer):
return build_lr_scheduler(cfg=cfg, default_args=default_args)
except KeyError as e:
self.logger.error(
f'Build lr_scheduler error, the lr_scheduler {cfg} is native torch lr_scheduler, '
f'Build lr_scheduler error, the lr_scheduler {cfg} is a torch native component, '
                f'please check whether your torch version ({torch.__version__}) matches the config.'
)
raise e
@@ -876,7 +881,7 @@ class EpochBasedTrainer(BaseTrainer):
Subclass and override to inject custom behavior.

"""
model = self.model
model = self.model.module if self._dist else self.model
model.eval()

if is_parallel(model):


+ 8
- 0
modelscope/utils/constant.py View File

@@ -238,6 +238,14 @@ class DownloadMode(enum.Enum):
FORCE_REDOWNLOAD = 'force_redownload'


class DownloadChannel(enum.Enum):
""" Channels of datasets downloading for uv/pv counting.
"""
LOCAL = 'local'
DSW = 'dsw'
EAIS = 'eais'


class UploadMode(enum.Enum):
""" How to upload object to remote.
"""


+ 65
- 0
modelscope/utils/cv/image_utils.py View File

@@ -91,6 +91,71 @@ def draw_keypoints(output, original_image):
return image


def draw_106face_keypoints(in_path,
keypoints,
boxes,
scale=4.0,
save_path=None):
face_contour_point_index = [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32
]
left_eye_brow_point_index = [33, 34, 35, 36, 37, 38, 39, 40, 41, 33]
right_eye_brow_point_index = [42, 43, 44, 45, 46, 47, 48, 49, 50, 42]
left_eye_point_index = [66, 67, 68, 69, 70, 71, 72, 73, 66]
right_eye_point_index = [75, 76, 77, 78, 79, 80, 81, 82, 75]
nose_bridge_point_index = [51, 52, 53, 54]
nose_contour_point_index = [55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
mouth_outer_point_index = [
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 84
]
mouth_inter_point_index = [96, 97, 98, 99, 100, 101, 102, 103, 96]

img = cv2.imread(in_path)

for i in range(len(boxes)):
draw_box(img, np.array(boxes[i]))

image = cv2.resize(img, dsize=None, fx=scale, fy=scale)

def draw_line(point_index, image, point):
for i in range(len(point_index) - 1):
cur_index = point_index[i]
next_index = point_index[i + 1]
cur_pt = (int(point[cur_index][0] * scale),
int(point[cur_index][1] * scale))
next_pt = (int(point[next_index][0] * scale),
int(point[next_index][1] * scale))
cv2.line(image, cur_pt, next_pt, (0, 0, 255), thickness=2)

for i in range(len(keypoints)):
points = keypoints[i]

draw_line(face_contour_point_index, image, points)
draw_line(left_eye_brow_point_index, image, points)
draw_line(right_eye_brow_point_index, image, points)
draw_line(left_eye_point_index, image, points)
draw_line(right_eye_point_index, image, points)
draw_line(nose_bridge_point_index, image, points)
draw_line(nose_contour_point_index, image, points)
draw_line(mouth_outer_point_index, image, points)
draw_line(mouth_inter_point_index, image, points)

size = len(points)
for i in range(size):
x = int(points[i][0])
y = int(points[i][1])
cv2.putText(image, str(i), (int(x * scale), int(y * scale)),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
cv2.circle(image, (int(x * scale), int(y * scale)), 2, (0, 255, 0),
cv2.FILLED)

if save_path is not None:
cv2.imwrite(save_path, image)

return image


def draw_face_detection_no_lm_result(img_path, detection_result):
bboxes = np.array(detection_result[OutputKeys.BOXES])
scores = np.array(detection_result[OutputKeys.SCORES])
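
A hedged sketch of feeding the face-2d-keypoints pipeline result into the new drawing helper, assuming the builder's registered default model; the image path and save path are illustrative:

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks
    from modelscope.utils.cv.image_utils import draw_106face_keypoints

    result = pipeline(Tasks.face_2d_keypoints)('face.jpg')  # illustrative input image
    draw_106face_keypoints(
        'face.jpg',
        result[OutputKeys.KEYPOINTS],
        result[OutputKeys.BOXES],
        scale=2.0,
        save_path='face_keypoints_vis.jpg')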


+ 1
- 16
modelscope/utils/demo_utils.py View File

@@ -4,11 +4,11 @@ import io

import cv2
import json
import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks, TasksIODescriptions
from modelscope.utils.service_utils import NumpyEncoder

TASKS_INPUT_TEMPLATES = {
# vision tasks
@@ -234,21 +234,6 @@ class DemoCompatibilityCheck(object):
return True


class NumpyEncoder(json.JSONEncoder):

def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()

if isinstance(obj, np.floating):
return float(obj)

if isinstance(obj, np.integer):
return int(obj)

return json.JSONEncoder.default(self, obj)


def preprocess(req):
in_urls = req.get('urlPaths').get('inUrls')
if len(req['inputs']) == 1:


+ 2
- 13
modelscope/utils/regress_test_utils.py View File

@@ -19,6 +19,8 @@ import torch
import torch.optim
from torch import nn

from modelscope.utils.service_utils import NumpyEncoder


class RegressTool:
"""This class is used to stop inference/training results from changing by some unaware affections by unittests.
@@ -117,19 +119,6 @@ class RegressTool:
with open(baseline, 'rb') as f:
base = pickle.load(f)

class NumpyEncoder(json.JSONEncoder):
"""Special json encoder for numpy types
"""

def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)

print(f'baseline: {json.dumps(base, cls=NumpyEncoder)}')
print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}')
if not compare_io_and_print(base, io_json, compare_fn, **kwargs):


+ 179
- 0
modelscope/utils/service_utils.py View File

@@ -0,0 +1,179 @@
import base64
import mimetypes
from io import BytesIO

import json
import numpy as np
import requests
from PIL import Image

from modelscope.outputs import TASK_OUTPUTS, OutputKeys
from modelscope.pipeline_inputs import TASK_INPUTS, InputType
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks, TasksIODescriptions


# A service data decoder decodes data received from the network and converts it into the pipeline's input.
# For example:
def ExampleDecoder(data):
    # Assuming the pipeline input is a dict containing an image and a text,
    # the image field arrives from the network as a base64 string and is decoded here.
data_json = json.loads(data)
# data: {"image": "xxxxxxxx=="(base64 str), "text": "a question"}
# pipeline(inputs) as follows:
# pipeline({'image': image, 'text': text})
inputs = {
'image': decode_base64_to_image(data_json.get('image')),
'text': data_json.get('text')
}
return inputs


# A service data encoder encodes the pipeline outputs and converts them into a network response (such as JSON).
# For example:
def ExampleEncoder(data):
    # Assuming the pipeline output is a dict containing an image and a text
    # to be transmitted over the network, this func encodes the image to base64 and dumps the dict to JSON.
    # data (e.g. a python dict):
    # {"image": a numpy array representing an image, "text": "output"}
image = data['image']
text = data['text']
data = {'image': encode_array_to_img_base64(image), 'text': text}
return json.dumps(data, cls=NumpyEncoder)


CustomEncoder = {
# Tasks.visual_question_answering: ExampleEncoder
}

CustomDecoder = {
# Tasks.visual_question_answering: ExampleDecoder
}


class NumpyEncoder(json.JSONEncoder):

def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()

if isinstance(obj, np.floating):
return float(obj)

if isinstance(obj, np.integer):
return int(obj)

return json.JSONEncoder.default(self, obj)


def get_extension(encoding):
encoding = encoding.replace('audio/wav', 'audio/x-wav')
tp = mimetypes.guess_type(encoding)[0]
if tp == 'audio/flac': # flac is not supported by mimetypes
return 'flac'
extension = mimetypes.guess_extension(tp)
if extension is not None and extension.startswith('.'):
extension = extension[1:]
return extension


def get_mimetype(filename):
mimetype = mimetypes.guess_type(filename)[0]
if mimetype is not None:
mimetype = mimetype.replace('x-wav', 'wav').replace('x-flac', 'flac')
return mimetype


def decode_base64_to_binary(encoding):
extension = get_extension(encoding)
data = encoding.split(',')[1]
return base64.b64decode(data), extension


def decode_base64_to_image(encoding):
content = encoding.split(';')[1]
image_encoded = content.split(',')[1]
return Image.open(BytesIO(base64.b64decode(image_encoded)))


def encode_array_to_img_base64(image_array):
with BytesIO() as output_bytes:
pil_image = Image.fromarray(image_array.astype(np.uint8))
pil_image.save(output_bytes, 'PNG')
bytes_data = output_bytes.getvalue()
base64_str = str(base64.b64encode(bytes_data), 'utf-8')
return 'data:image/png;base64,' + base64_str


def encode_pcm_to_base64(bytes_data):
from scipy.io.wavfile import write
with BytesIO() as out_mem_file:
write(out_mem_file, 16000, bytes_data)
base64_str = str(base64.b64encode(out_mem_file.getvalue()), 'utf-8')
return 'data:audio/pcm;base64,' + base64_str


def encode_url_to_base64(url):
encoded_string = base64.b64encode(requests.get(url).content)
base64_str = str(encoded_string, 'utf-8')
mimetype = get_mimetype(url)
return ('data:' + (mimetype if mimetype is not None else '') + ';base64,'
+ base64_str)


def encode_file_to_base64(f):
with open(f, 'rb') as file:
encoded_string = base64.b64encode(file.read())
base64_str = str(encoded_string, 'utf-8')
mimetype = get_mimetype(f)
return ('data:' + (mimetype if mimetype is not None else '')
+ ';base64,' + base64_str)


def encode_url_or_file_to_base64(path):
try:
requests.get(path)
return encode_url_to_base64(path)
except (requests.exceptions.MissingSchema,
requests.exceptions.InvalidSchema):
return encode_file_to_base64(path)


def service_data_decoder(task, data):
if CustomDecoder.get(task) is not None:
return CustomDecoder[task](data)
input_type = TASK_INPUTS[task]
input_data = data.decode('utf-8')
if input_type == InputType.IMAGE:
return decode_base64_to_image(input_data)
elif input_type == InputType.AUDIO:
return decode_base64_to_binary(input_data)[0]
elif input_type == InputType.TEXT:
return input_data
elif isinstance(input_type, dict):
input_data = {}
for key, val in input_type.items():
if val == InputType.IMAGE:
input_data[key] = decode_base64_to_image(data[key])
elif val == InputType.AUDIO:
input_data[key] = decode_base64_to_binary(data[key])[0]
elif val == InputType.TEXT:
input_data[key] = data[key]

return input_data


def service_data_encoder(task, data):
if CustomEncoder.get(task) is not None:
return CustomEncoder[task](data)
output_keys = TASK_OUTPUTS[task]
result = data
for output_key in output_keys:
if output_key == OutputKeys.OUTPUT_IMG:
result[OutputKeys.OUTPUT_IMG] = encode_array_to_img_base64(
data[OutputKeys.OUTPUT_IMG][..., ::-1])
elif output_key == OutputKeys.OUTPUT_PCM:
result[OutputKeys.OUTPUT_PCM] = encode_pcm_to_base64(
data[OutputKeys.OUTPUT_PCM])
result = bytes(json.dumps(result, cls=NumpyEncoder), encoding='utf8')
return result
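
As a quick round-trip sketch of the helpers added above (assuming the modelscope.utils.service_utils module from this change is importable), an image array can be encoded into the data-URL form these service functions exchange and decoded back:

import numpy as np

from modelscope.utils.service_utils import (decode_base64_to_image,
                                            encode_array_to_img_base64)

# a dummy 64x64 RGB array, used purely for illustration
array = np.zeros((64, 64, 3), dtype=np.uint8)

data_url = encode_array_to_img_base64(array)  # 'data:image/png;base64,...'
restored = decode_base64_to_image(data_url)   # PIL.Image.Image
assert restored.size == (64, 64)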

+ 1
- 1
modelscope/version.py View File

@@ -1,5 +1,5 @@
# Make sure to modify __release_datetime__ to the release time when making an official release.
__version__ = '0.5.0'
__version__ = '1.0.0'
# default release datetime for branches under active development is set
# to be a time far-far-away-into-the-future
__release_datetime__ = '2099-10-13 08:56:12'

+ 2
- 1
requirements/framework.txt View File

@@ -1,6 +1,7 @@
addict
attrs
datasets
# versions beyond 2.5.2 introduce a compatibility issue that is being resolved
datasets<=2.5.2
easydict
einops
filelock>=3.3.0


+ 4
- 0
requirements/multi-modal.txt View File

@@ -2,6 +2,8 @@ ftfy>=6.0.3
ofa>=0.0.2
pycocoevalcap>=1.2
pycocotools>=2.0.4
# compatible with taming-transformers-rom1504
pytorch_lightning<=1.7.7
# rouge-score was recently updated from 0.0.4 to 0.0.7,
# which introduced compatibility issues that are being investigated
rouge_score<=0.0.4
@@ -11,3 +13,5 @@ timm
tokenizers
torchvision
transformers>=4.12.0
unicodedata2
zhconv

+ 0
- 1
requirements/nlp.txt View File

@@ -1,6 +1,5 @@
boto3
en_core_web_sm>=2.3.5
fasttext
filelock
ftfy
jieba>=0.42.1


+ 2
- 0
requirements/science.txt View File

@@ -1,4 +1,6 @@
biopython
iopath
ipdb
lmdb
ml_collections
scipy


+ 1
- 1
tests/export/test_export_sbert_sequence_classification.py View File

@@ -23,7 +23,7 @@ class TestExportSbertSequenceClassification(unittest.TestCase):
shutil.rmtree(self.tmp_dir)
super().tearDown()

@unittest.skip
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_export_sbert_sequence_classification(self):
model = Model.from_pretrained(self.model_id)
print(


+ 82
- 2
tests/hub/test_hub_revision_release_mode.py View File

@@ -115,7 +115,7 @@ class HubRevisionTest(unittest.TestCase):
time.sleep(10)
self.add_new_file_and_tag_to_repo()
t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('Secnod time: %s' % t2)
logger.info('Second time: %s' % t2)
# set
release_datetime_backup = version.__release_datetime__
logger.info('Origin __release_datetime__: %s'
@@ -142,6 +142,43 @@ class HubRevisionTest(unittest.TestCase):
finally:
version.__release_datetime__ = release_datetime_backup

def test_snapshot_download_revision_user_set_revision(self):
with mock.patch.dict(os.environ, self.modified_environ, clear=True):
self.prepare_repo_data_and_tag()
t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('First time: %s' % t1)
time.sleep(10)
self.add_new_file_and_tag_to_repo()
t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('Second time: %s' % t2)
# set
release_datetime_backup = version.__release_datetime__
logger.info('Origin __release_datetime__: %s'
% version.__release_datetime__)
try:
logger.info('Setting __release_datetime__ to: %s' % t1)
version.__release_datetime__ = t1
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id,
revision=self.revision,
cache_dir=temp_cache_dir)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
assert not os.path.exists(
os.path.join(snapshot_path, download_model_file_name2))
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id,
revision=self.revision2,
cache_dir=temp_cache_dir)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name2))
finally:
version.__release_datetime__ = release_datetime_backup

def test_file_download_revision(self):
with mock.patch.dict(os.environ, self.modified_environ, clear=True):
self.prepare_repo_data_and_tag()
@@ -175,7 +212,6 @@ class HubRevisionTest(unittest.TestCase):
self.model_id,
download_model_file_name,
cache_dir=temp_cache_dir)
print('Downloaded file path: %s' % file_path)
assert os.path.exists(file_path)
file_path = model_file_download(
self.model_id,
@@ -185,6 +221,50 @@ class HubRevisionTest(unittest.TestCase):
finally:
version.__release_datetime__ = release_datetime_backup

def test_file_download_revision_user_set_revision(self):
with mock.patch.dict(os.environ, self.modified_environ, clear=True):
self.prepare_repo_data_and_tag()
t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('First time stamp: %s' % t1)
time.sleep(10)
self.add_new_file_and_tag_to_repo()
t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('Second time: %s' % t2)
release_datetime_backup = version.__release_datetime__
logger.info('Origin __release_datetime__: %s'
% version.__release_datetime__)
try:
version.__release_datetime__ = t1
logger.info('Setting __release_datetime__ to: %s' % t1)
with tempfile.TemporaryDirectory() as temp_cache_dir:
file_path = model_file_download(
self.model_id,
download_model_file_name,
revision=self.revision,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)
with self.assertRaises(NotExistError):
model_file_download(
self.model_id,
download_model_file_name2,
revision=self.revision,
cache_dir=temp_cache_dir)
with tempfile.TemporaryDirectory() as temp_cache_dir:
file_path = model_file_download(
self.model_id,
download_model_file_name,
revision=self.revision2,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)
file_path = model_file_download(
self.model_id,
download_model_file_name2,
revision=self.revision2,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)
finally:
version.__release_datetime__ = release_datetime_backup


if __name__ == '__main__':
unittest.main()
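
The new tests above pin version.__release_datetime__ and then download by an explicit revision; a stand-alone sketch of the same call pattern (the model id, tag, and file name are placeholders, not values from this change):

import tempfile

from modelscope.hub.file_download import model_file_download
from modelscope.hub.snapshot_download import snapshot_download

model_id = 'damo/some_public_model'  # placeholder model id

with tempfile.TemporaryDirectory() as cache_dir:
    # full snapshot of a tagged revision into a throwaway cache
    snapshot_path = snapshot_download(
        model_id, revision='v1.0.0', cache_dir=cache_dir)
    # or a single file from the same revision
    file_path = model_file_download(
        model_id, 'configuration.json', revision='v1.0.0', cache_dir=cache_dir)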

+ 6
- 2
tests/msdatasets/test_dataset_upload.py View File

@@ -8,7 +8,8 @@ import zipfile
from modelscope.msdatasets import MsDataset
from modelscope.msdatasets.utils.dataset_utils import list_dataset_objects
from modelscope.utils import logger as logging
from modelscope.utils.constant import DEFAULT_DATASET_REVISION, ModelFile
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DownloadMode,
ModelFile)
from modelscope.utils.test_utils import test_level

logger = logging.get_logger(__name__)
@@ -104,7 +105,10 @@ class DatasetUploadTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ds_download_dir(self):
test_ds = MsDataset.load(self.dataset_name, self.namespace)
test_ds = MsDataset.load(
self.dataset_name,
namespace=self.namespace,
download_mode=DownloadMode.FORCE_REDOWNLOAD)
assert test_ds.config_kwargs['split_config'].values()

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')


+ 2
- 1
tests/outputs/test_model_outputs.py View File

@@ -21,9 +21,10 @@ class TestModelOutput(unittest.TestCase):
self.assertEqual(outputs['logits'], torch.Tensor([1]))
self.assertEqual(outputs[0], torch.Tensor([1]))
self.assertEqual(outputs.logits, torch.Tensor([1]))
outputs.loss = torch.Tensor([2])
logits, loss = outputs
self.assertEqual(logits, torch.Tensor([1]))
self.assertTrue(loss is None)
self.assertTrue(loss is not None)


if __name__ == '__main__':


+ 17
- 12
tests/pipelines/test_face_2d_keypoints.py View File

@@ -1,11 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_106face_keypoints
from modelscope.utils.test_utils import test_level


@@ -13,7 +12,7 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_face_2d_keypoints(self):
img_path = 'data/test/images/keypoints_detect/test_img_face_2d_keypoints.png'
img_path = 'data/test/images/face_detection.png'
model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment'

face_2d_keypoints_align = pipeline(
@@ -21,15 +20,21 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase):
output = face_2d_keypoints_align(img_path)

output_keypoints = output[OutputKeys.KEYPOINTS]
output_pose = output[OutputKeys.POSES]

img = cv2.imread(img_path)
img = face_2d_keypoints_align.show_result(
img, output_keypoints, scale=2, save_path='face_keypoints.jpg')

self.assertEqual(output_keypoints.shape[0], 106)
self.assertEqual(output_keypoints.shape[1], 2)
self.assertEqual(output_pose.shape[0], 3)
output_poses = output[OutputKeys.POSES]
output_boxes = output[OutputKeys.BOXES]

draw_106face_keypoints(
img_path,
output_keypoints,
output_boxes,
scale=2,
save_path='face_keypoints.jpg')

for idx in range(len(output_keypoints)):
self.assertEqual(output_keypoints[idx].shape[0], 106)
self.assertEqual(output_keypoints[idx].shape[1], 2)
self.assertEqual(output_poses[idx].shape[0], 3)
self.assertEqual(output_boxes[idx].shape[0], 4)


if __name__ == '__main__':


+ 2
- 2
tests/pipelines/test_face_emotion.py View File

@@ -17,12 +17,12 @@ class FaceEmotionTest(unittest.TestCase):
result = pipeline(input)
print(result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skip('skip since the model is set to private for now')
def test_run_modelhub(self):
face_emotion = pipeline(Tasks.face_emotion, model=self.model)
self.pipeline_inference(face_emotion, self.img)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skip('skip since the model is set to private for now')
def test_run_modelhub_default_model(self):
face_emotion = pipeline(Tasks.face_emotion)
self.pipeline_inference(face_emotion, self.img)


+ 1
- 1
tests/pipelines/test_facial_expression_recognition.py View File

@@ -23,7 +23,7 @@ class FacialExpressionRecognitionTest(unittest.TestCase):
cv2.imwrite('result.png', img)
print(f'output written to {osp.abspath("result.png")}')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skip('skip since the model is set to private for now')
def test_run_modelhub(self):
fer = pipeline(
Tasks.facial_expression_recognition, model=self.model_id)


+ 3
- 3
tests/pipelines/test_multi_modal_embedding.py View File

@@ -24,7 +24,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
def test_run(self):
pipeline_multi_modal_embedding = pipeline(
Tasks.multi_modal_embedding, model=self.model_id)
text_embedding = pipeline_multi_modal_embedding(
text_embedding = pipeline_multi_modal_embedding.forward(
self.test_input)[OutputKeys.TEXT_EMBEDDING]
print('l1-norm: {}'.format(
torch.norm(text_embedding, p=1, dim=-1).item()))
@@ -36,7 +36,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
model = Model.from_pretrained(self.model_id)
pipeline_multi_modal_embedding = pipeline(
task=Tasks.multi_modal_embedding, model=model)
text_embedding = pipeline_multi_modal_embedding(
text_embedding = pipeline_multi_modal_embedding.forward(
self.test_input)[OutputKeys.TEXT_EMBEDDING]
print('l1-norm: {}'.format(
torch.norm(text_embedding, p=1, dim=-1).item()))
@@ -47,7 +47,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
def test_run_with_default_model(self):
pipeline_multi_modal_embedding = pipeline(
task=Tasks.multi_modal_embedding)
text_embedding = pipeline_multi_modal_embedding(
text_embedding = pipeline_multi_modal_embedding.forward(
self.test_input)[OutputKeys.TEXT_EMBEDDING]
print('l1-norm: {}'.format(
torch.norm(text_embedding, p=1, dim=-1).item()))


+ 8
- 0
tests/pipelines/test_named_entity_recognition.py View File

@@ -19,9 +19,11 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
self.task = Tasks.named_entity_recognition
self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'

english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom'
tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
sentence = '这与温岭市新河镇的一个神秘的传说有关。'
sentence_en = 'pizza shovel'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_tcrf_by_direct_model_download(self):
@@ -89,6 +91,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
task=Tasks.named_entity_recognition, model=self.lcrf_model_id)
print(pipeline_ins(input=self.sentence))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_english_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.named_entity_recognition, model=self.english_model_id)
print(pipeline_ins(input='pizza shovel'))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.named_entity_recognition)


+ 1
- 1
tests/pipelines/test_unifold.py View File

@@ -19,7 +19,7 @@ class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck):
self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \
'NIAALKNHIDKIKPIAMQIYKKYSKNIP'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_by_direct_model_download(self):
model_dir = snapshot_download(self.model_id)
mono_pipeline_ins = pipeline(task=self.task, model=model_dir)


+ 2
- 1
tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py View File

@@ -50,7 +50,8 @@ class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase):
trainer = build_trainer(trainer_name, kwargs)
trainer.train()

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skip(
'skip since face_2d_keypoints_dataset is set to private for now')
def test_trainer_single_gpu(self):
temp_file_dir = tempfile.TemporaryDirectory()
tmp_dir = temp_file_dir.name


+ 83
- 0
tests/trainers/test_clip_trainer.py View File

@@ -0,0 +1,83 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import unittest

import json

from modelscope.metainfo import Metrics, Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


class TestClipTrainer(unittest.TestCase):

def setUp(self) -> None:
self.finetune_cfg = \
{'framework': 'pytorch',
'task': 'multi-modal-embedding',
'pipeline': {'type': 'multi-modal-embedding'},
'pretrained_model': {'model_name': 'damo/multi-modal_clip-vit-base-patch16_zh'},
'dataset': {'column_map': {'img': 'image', 'text': 'query'}},
'train': {'work_dir': './workspace/ckpts/clip',
# 'launcher': 'pytorch',
'max_epochs': 1,
'use_fp16': True,
'dataloader': {'batch_size_per_gpu': 8,
'workers_per_gpu': 0,
'shuffle': True,
'drop_last': True},
'lr_scheduler': {'name': 'cosine',
'warmup_proportion': 0.01},
'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
'optimizer': {'type': 'AdamW'},
'optimizer_hparams': {'lr': 5e-05, 'weight_decay': 0.01},
'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
'cumulative_iters': 1,
'loss_keys': 'loss'},
'loss_cfg': {'aggregate': True},
'hooks': [{'type': 'BestCkptSaverHook',
'metric_key': 'inbatch_t2i_recall_at_1',
'interval': 100},
{'type': 'TextLoggerHook', 'interval': 1},
{'type': 'IterTimerHook'},
{'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1},
{'type': 'ClipClampLogitScaleHook'}]},
'evaluation': {'dataloader': {'batch_size_per_gpu': 8,
'workers_per_gpu': 0,
'shuffle': True,
'drop_last': True},
'metrics': [{'type': 'inbatch_recall'}]},
'preprocessor': []}

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer_std(self):
WORKSPACE = './workspace/ckpts/clip'
os.makedirs(WORKSPACE, exist_ok=True)
config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
with open(config_file, 'w') as writer:
json.dump(self.finetune_cfg, writer)

pretrained_model = 'damo/multi-modal_clip-vit-base-patch16_zh'
args = dict(
model=pretrained_model,
work_dir=WORKSPACE,
train_dataset=MsDataset.load(
'muge', namespace='modelscope', split='train[:200]'),
eval_dataset=MsDataset.load(
'muge', namespace='modelscope', split='validation[:100]'),
metrics=[Metrics.inbatch_recall],
cfg_file=config_file)
trainer = build_trainer(
name=Trainers.clip_multi_modal_embedding, default_args=args)
trainer.train()

self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
os.listdir(os.path.join(WORKSPACE, 'output')))
shutil.rmtree(WORKSPACE)


if __name__ == '__main__':
unittest.main()
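
Assuming the run above completes, the best checkpoint saved by BestCkptSaverHook under ./workspace/ckpts/clip/output could be loaded back through the embedding pipeline; a hedged sketch (the 'text' input key is an assumption, not something this test asserts):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# point the pipeline at the fine-tuned output directory written by the trainer
pipe = pipeline(
    Tasks.multi_modal_embedding, model='./workspace/ckpts/clip/output')
result = pipe.forward({'text': '一张猫的照片'})  # "a photo of a cat"; assumed input format
text_embedding = result[OutputKeys.TEXT_EMBEDDING]
print(text_embedding.shape)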

+ 1
- 1
tests/trainers/test_finetune_sequence_classification.py View File

@@ -38,7 +38,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
shutil.rmtree(self.tmp_dir)
super().tearDown()

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skip
def test_trainer_cfg_class(self):
dataset = MsDataset.load('clue', subset_name='tnews')
train_dataset = dataset['train']


+ 1
- 1
tests/trainers/test_finetune_token_classificatin.py View File

@@ -87,7 +87,7 @@ class TestFinetuneTokenClassification(unittest.TestCase):
cfg['dataset'] = {
'train': {
'labels': label_enumerate_values,
'first_sequence': 'first_sequence',
'first_sequence': 'tokens',
'label': 'labels',
}
}


+ 1
- 1
tests/trainers/test_ofa_trainer.py View File

@@ -85,7 +85,7 @@ class TestOfaTrainer(unittest.TestCase):
'ocr_fudanvi_zh',
subset_name='scene',
namespace='modelscope',
split='train[:200]',
split='train[800:900]',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
eval_dataset=MsDataset.load(
'ocr_fudanvi_zh',


+ 1
- 1
tests/trainers/test_trainer_with_nlp.py View File

@@ -72,7 +72,7 @@ class TestTrainerWithNlp(unittest.TestCase):
output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
pipeline_sentence_similarity(output_dir)

@unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
@unittest.skip
def test_trainer_with_backbone_head(self):
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
kwargs = dict(

