
Merge branch 'master-gitlab' into master-github

master · wenmeng.zwm · 2 years ago · commit 266851bbbb
40 changed files with 623 additions and 187 deletions
1. +2 -0 .dev_scripts/dockerci.sh
2. +3 -3 README.md
3. +2 -0 modelscope/metainfo.py
4. +4 -3 modelscope/metrics/accuracy_metric.py
5. +3 -2 modelscope/metrics/builder.py
6. +67 -0 modelscope/metrics/map_metric.py
7. +3 -14 modelscope/metrics/text_generation_metric.py
8. +9 -2 modelscope/models/audio/kws/farfield/model.py
9. +2 -24 modelscope/models/multi_modal/mplug_for_all_tasks.py
10. +3 -2 modelscope/models/nlp/task_models/text_generation.py
11. +1 -0 modelscope/pipelines/base.py
12. +3 -3 modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
13. +10 -8 modelscope/pipelines/nlp/text_classification_pipeline.py
14. +2 -23 modelscope/pipelines/nlp/text_generation_pipeline.py
15. +9 -8 modelscope/preprocessors/base.py
16. +12 -6 modelscope/preprocessors/nlp/token_classification_preprocessor.py
17. +1 -1 modelscope/preprocessors/ofa/image_captioning.py
18. +79 -11 modelscope/preprocessors/ofa/image_classification.py
19. +2 -5 modelscope/preprocessors/ofa/ocr_recognition.py
20. +34 -2 modelscope/preprocessors/ofa/summarization.py
21. +54 -5 modelscope/preprocessors/ofa/visual_entailment.py
22. +94 -12 modelscope/preprocessors/ofa/visual_grounding.py
23. +47 -3 modelscope/preprocessors/ofa/visual_question_answering.py
24. +5 -2 modelscope/trainers/audio/kws_farfield_trainer.py
25. +5 -1 modelscope/trainers/multi_modal/__init__.py
26. +3 -0 modelscope/trainers/multi_modal/mplug/__init__.py
27. +40 -0 modelscope/trainers/multi_modal/mplug/mplug_trainer.py
28. +13 -4 modelscope/trainers/multi_modal/ofa/ofa_trainer.py
29. +35 -0 modelscope/utils/chinese_utils.py
30. +18 -0 modelscope/utils/regress_test_utils.py
31. +7 -3 tests/pipelines/test_fill_mask.py
32. +10 -0 tests/pipelines/test_multilingual_named_entity_recognition.py
33. +9 -1 tests/pipelines/test_named_entity_recognition.py
34. +4 -3 tests/pipelines/test_nli.py
35. +4 -2 tests/pipelines/test_sentence_similarity.py
36. +7 -3 tests/pipelines/test_word_segmentation.py
37. +4 -2 tests/pipelines/test_zero_shot_classification.py
38. +10 -28 tests/trainers/test_finetune_mplug.py
39. +2 -0 tests/trainers/test_ofa_trainer.py
40. +1 -1 tests/utils/test_ast.py

+2 -0 .dev_scripts/dockerci.sh

@@ -37,6 +37,7 @@ do
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
-e TEST_LEVEL=$TEST_LEVEL \
-e MODELSCOPE_ENVIRONMENT='ci' \
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
-e MODEL_TAG_URL=$MODEL_TAG_URL \
--workdir=$CODE_DIR_IN_CONTAINER \
@@ -59,6 +60,7 @@ do
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
-e TEST_LEVEL=$TEST_LEVEL \
-e MODELSCOPE_ENVIRONMENT='ci' \
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
-e MODEL_TAG_URL=$MODEL_TAG_URL \
--workdir=$CODE_DIR_IN_CONTAINER \


+3 -3 README.md

@@ -1,10 +1,10 @@
# Introduction

- [ModelScope]( https://www.modelscope.cn) is a “Model-as-a-Service” (MaaS) platform that seeks to bringing together most advanced machine learning models from the AI community, and to streamlining the process of leveraging and applying AI models . The core ModelScope library enables developers to perform model inference, training and evaluation, through rich layers of API designs that facilitate a unified experience across state-of-the-art models from different AI domains.
[ModelScope]( https://www.modelscope.cn) is a “Model-as-a-Service” (MaaS) platform that seeks to bring together most advanced machine learning models from the AI community, and to streamline the process of leveraging AI models in real applications. The core ModelScope library enables developers to perform inference, training and evaluation, through rich layers of API designs that facilitate a unified experience across state-of-the-art models from different AI domains.

- The Python library offers the layered-APIs necessary for model contributors to integrate models from CV, NLP, Speech, Multi-Modality, as well as Scientific-computation, into the ModelScope ecosystem. Implementations for all these different models are encapsulated within the library in a way that allows easy and unified access. With such integration, model inference, finetuning, and evaluations can be done within only a few lines of codes. In the meantime, flexibilities are provided so that different components in the model applications can be customized as well, where necessary.
The Python library offers the layered-APIs necessary for model contributors to integrate models from CV, NLP, Speech, Multi-Modality, as well as Scientific-computation, into the ModelScope ecosystem. Implementations for all these different models are encapsulated within the library in a way that allows easy and unified access. With such integration, model inference, finetuning, and evaluations can be done with only a few lines of codes. In the meantime, flexibilities are provided so that different components in the model applications can be customized as well, where necessary.

- Apart from harboring implementations of various models, ModelScope library also enables the necessary interactions with the backend services of ModelScope, particularly with the Model-Hub and Dataset-Hub. Such interactions facilitate various entity (models and datasets) management to be performed seamlessly under-the-hood, such as entity lookup, version control, and cache management.
Apart from harboring implementations of various models, ModelScope library also enables the necessary interactions with ModelScope backend services, particularly with the Model-Hub and Dataset-Hub. Such interactions facilitate management of various entities (models and datasets) to be performed seamlessly under-the-hood, including entity lookup, version control, cache management, and many others.

# Installation



+2 -0 modelscope/metainfo.py

@@ -299,6 +299,7 @@ class Trainers(object):
# multi-modal trainers
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
ofa = 'ofa'
mplug = 'mplug'

# cv trainers
image_instance_segmentation = 'image-instance-segmentation'
@@ -402,6 +403,7 @@ class Metrics(object):

# accuracy
accuracy = 'accuracy'
multi_average_precision = 'mAP'
audio_noise_metric = 'audio-noise-metric'

# text gen


+4 -3 modelscope/metrics/accuracy_metric.py

@@ -6,6 +6,7 @@ import numpy as np

from modelscope.metainfo import Metrics
from modelscope.outputs import OutputKeys
from modelscope.utils.chinese_utils import remove_space_between_chinese_chars
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys
@@ -26,10 +27,10 @@ class AccuracyMetric(Metric):
def add(self, outputs: Dict, inputs: Dict):
label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
ground_truths = inputs[label_name]
- eval_results = outputs[label_name]
eval_results = None
for key in [
OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
- OutputKeys.LABELS, OutputKeys.SCORES
OutputKeys.LABEL, OutputKeys.LABELS, OutputKeys.SCORES
]:
if key in outputs and outputs[key] is not None:
eval_results = outputs[key]
@@ -39,7 +40,7 @@ class AccuracyMetric(Metric):
self.labels.append(truth)
for result in eval_results:
if isinstance(truth, str):
- self.preds.append(result.strip().replace(' ', ''))
self.preds.append(remove_space_between_chinese_chars(result))
else:
self.preds.append(result)



+3 -2 modelscope/metrics/builder.py

@@ -24,6 +24,7 @@ class MetricKeys(object):
ROUGE_1 = 'rouge-1'
ROUGE_L = 'rouge-l'
NED = 'ned' # ocr metric
mAP = 'mAP'
BatchAcc = 'inbatch_t2i_recall_at_1'


@@ -40,8 +41,8 @@ task_default_metrics = {
Tasks.image_portrait_enhancement:
[Metrics.image_portrait_enhancement_metric],
Tasks.video_summarization: [Metrics.video_summarization_metric],
- Tasks.image_captioning: [Metrics.text_gen_metric],
- Tasks.visual_question_answering: [Metrics.text_gen_metric],
Tasks.image_captioning: [Metrics.accuracy],
Tasks.visual_question_answering: [Metrics.accuracy],
Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric],
Tasks.image_inpainting: [Metrics.image_inpainting_metric],
Tasks.referring_video_object_segmentation:


+67 -0 modelscope/metrics/map_metric.py

@@ -0,0 +1,67 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Dict

import numpy as np

from modelscope.metainfo import Metrics
from modelscope.outputs import OutputKeys
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys


@METRICS.register_module(
group_key=default_group, module_name=Metrics.multi_average_precision)
class AveragePrecisionMetric(Metric):
"""The metric computation class for multi avarage precision classes.

This metric class calculates multi avarage precision for the whole input batches.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.preds = []
self.labels = []
self.thresh = kwargs.get('threshold', 0.5)

def add(self, outputs: Dict, inputs: Dict):
label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
ground_truths = inputs[label_name]
eval_results = outputs[label_name]
for key in [
OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
OutputKeys.LABELS, OutputKeys.SCORES
]:
if key in outputs and outputs[key] is not None:
eval_results = outputs[key]
break
assert type(ground_truths) == type(eval_results)
for truth in ground_truths:
self.labels.append(truth)
for result in eval_results:
if isinstance(truth, str):
self.preds.append(result.strip().replace(' ', ''))
else:
self.preds.append(result)

def evaluate(self):
assert len(self.preds) == len(self.labels)
scores = self._calculate_ap_score(self.preds, self.labels, self.thresh)
return {MetricKeys.mAP: scores.mean().item()}

def _calculate_ap_score(self, preds, labels, thresh=0.5):
hyps = np.array(preds)
refs = np.array(labels)
a = np.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2])
b = np.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])
interacts = np.concatenate([a, b], axis=1)
area_predictions = (hyps[:, 2] - hyps[:, 0]) * (
hyps[:, 3] - hyps[:, 1])
area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
interacts_w = interacts[:, 2] - interacts[:, 0]
interacts_h = interacts[:, 3] - interacts[:, 1]
area_interacts = interacts_w * interacts_h
ious = area_interacts / (
area_predictions + area_targets - area_interacts + 1e-6)
return (ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)
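
A quick way to sanity-check the IoU arithmetic above: for axis-aligned boxes in [x0, y0, x1, y1] form, the intersection comes from the element-wise max of the top-left corners and min of the bottom-right corners, and a prediction counts only when IoU clears the threshold with a non-degenerate overlap. A minimal standalone restatement (not the class itself; box values are made up for illustration):

import numpy as np

preds = np.array([[10., 10., 60., 60.]])  # hypothetical predicted box
refs = np.array([[20., 20., 70., 70.]])   # hypothetical ground-truth box

tl = np.maximum(preds[:, :2], refs[:, :2])    # intersection top-left
br = np.minimum(preds[:, 2:], refs[:, 2:])    # intersection bottom-right
wh = np.clip(br - tl, a_min=0, a_max=None)    # intersection width/height
inter = wh[:, 0] * wh[:, 1]
area_p = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1])
area_r = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
iou = inter / (area_p + area_r - inter + 1e-6)
print(iou)  # ~0.47, just under the default 0.5 threshold, so this pair scores 0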

+3 -14 modelscope/metrics/text_generation_metric.py

@@ -8,6 +8,7 @@ from rouge import Rouge
from modelscope.metainfo import Metrics
from modelscope.metrics.base import Metric
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.utils.chinese_utils import rebuild_chinese_str
from modelscope.utils.registry import default_group


@@ -24,25 +25,13 @@ class TextGenerationMetric(Metric):
self.tgts: List[str] = []
self.rouge = Rouge()

- @staticmethod
- def is_chinese_char(char: str):
- # the length of char must be 1
- return '\u4e00' <= char <= '\u9fa5'
-
- # add space for each chinese char
- def rebuild_str(self, string: str):
- return ' '.join(''.join([
- f' {char} ' if self.is_chinese_char(char) else char
- for char in string
- ]).split())

def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]):
ground_truths = inputs['tgts']
eval_results = outputs['preds']
for truth in ground_truths:
- self.tgts.append(self.rebuild_str(truth))
self.tgts.append(rebuild_chinese_str(truth))
for result in eval_results:
- self.preds.append(self.rebuild_str(result))
self.preds.append(rebuild_chinese_str(result))

def _check(self, pred: str, tgt: str) -> bool:



+9 -2 modelscope/models/audio/kws/farfield/model.py

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import tempfile
from typing import Dict, Optional

from modelscope.metainfo import Models
@@ -36,12 +37,15 @@ class FSMNSeleNetV2Decorator(TorchModel):
else:
sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
self.tmp_dir = tempfile.TemporaryDirectory()
new_config_file = os.path.join(self.tmp_dir.name, self.SC_CONFIG)

self._sc = None
if os.path.exists(model_txt_file):
conf_dict = dict(mode=56542, kws_model=model_txt_file)
- update_conf(sc_config_file, sc_config_file, conf_dict)
update_conf(sc_config_file, new_config_file, conf_dict)
import py_sound_connect
- self._sc = py_sound_connect.SoundConnect(sc_config_file)
self._sc = py_sound_connect.SoundConnect(new_config_file)
self.size_in = self._sc.bytesPerBlockIn()
self.size_out = self._sc.bytesPerBlockOut()
else:
@@ -49,6 +53,9 @@ class FSMNSeleNetV2Decorator(TorchModel):
f'Invalid model directory! Failed to load model file: {model_txt_file}.'
)

def __del__(self):
self.tmp_dir.cleanup()

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
return self.model.forward(input)
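
The tempfile change above stops the decorator from rewriting the config that ships inside model_dir: the updated sound-connect config is written into a per-instance temporary directory and removed when the wrapper is garbage-collected. A minimal sketch of the same pattern (class and file names here are illustrative, not part of the diff):

import os
import tempfile

class ConfigScratchDir:
    def __init__(self):
        # per-instance scratch dir; lives as long as the object does
        self.tmp_dir = tempfile.TemporaryDirectory()
        self.config_path = os.path.join(self.tmp_dir.name, 'sc.conf')

    def __del__(self):
        # mirror the diff: explicit cleanup when the instance goes away
        self.tmp_dir.cleanup()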



+2 -24 modelscope/models/multi_modal/mplug_for_all_tasks.py

@@ -45,10 +45,6 @@ class MPlugForAllTasks(TorchModel):
}
"""

- replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
- ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
- ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))

# get task from config file
task = Config.from_file(
osp.join(self.model_dir, ModelFile.CONFIGURATION)).task
@@ -60,10 +56,7 @@ class MPlugForAllTasks(TorchModel):
return {OutputKeys.SCORES: output[0].tolist()}
topk_ids, _ = output
pred_string: List[str] = \
- self.tokenizer.decode(topk_ids[0][0])
- for _old, _new in replace_tokens_bert:
- pred_string = pred_string.replace(_old, _new)
- pred_string = pred_string.strip()
self.tokenizer.decode(topk_ids[0][0], skip_special_tokens=True)
output_key = OutputKeys.CAPTION \
if task == Tasks.image_captioning else OutputKeys.TEXT
return {output_key: pred_string}
@@ -87,19 +80,4 @@ class MPlugForAllTasks(TorchModel):

# evaluate
topk_ids, _ = output
- preds: List[str] = [
- self.tokenizer.decode(batch[0]) for batch in topk_ids
- ]
- for i in range(len(preds)):
- for _old, _new in replace_tokens_bert:
- preds[i] = preds[i].replace(_old, _new)
- preds[i] = preds[i].strip()
- tgts: List[str] = [
- self.tokenizer.decode(batch)
- for batch in input['answer_input_ids'].cpu().numpy().tolist()
- ]
- for i in range(len(tgts)):
- for _old, _new in replace_tokens_bert:
- tgts[i] = tgts[i].replace(_old, _new)
- preds[i] = preds[i].strip()
- return {'preds': preds, 'tgts': tgts}
return {'sequences': [list_tensor[0] for list_tensor in topk_ids]}
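
The manual [unused0]/[PAD]/[CLS] scrubbing this file used to do is exactly what Hugging Face tokenizers provide via skip_special_tokens. A quick illustration (the checkpoint name is illustrative, not taken from the diff):

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained('bert-base-chinese')
ids = tok.encode('一只猫')                         # adds [CLS] ... [SEP]
print(tok.decode(ids))                             # '[CLS] 一 只 猫 [SEP]'
print(tok.decode(ids, skip_special_tokens=True))   # '一 只 猫'

Note the decoded string still has spaces between the Chinese characters; collapsing those is what the new chinese_utils helpers further down handle.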

+3 -2 modelscope/models/nlp/task_models/text_generation.py

@@ -2,7 +2,7 @@
from typing import Any, Dict

import numpy as np
- from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_utils import GenerationMixin

from modelscope.metainfo import TaskModels
from modelscope.models.builder import MODELS
@@ -17,7 +17,8 @@ __all__ = ['TaskModelForTextGeneration']

@MODELS.register_module(
Tasks.text_generation, module_name=TaskModels.text_generation)
- class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
class TaskModelForTextGeneration(SingleBackboneTaskModelBase, GenerationMixin):
main_input_name = 'input_ids'

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the text generation model from the `model_dir` path.


+1 -0 modelscope/pipelines/base.py

@@ -366,6 +366,7 @@ class DistributedPipeline(Pipeline):
master_port=master_port,
**self.cfg.model,
**kwargs), ranks)
self.models = []

def __del__(self):
if hasattr(self, 'model_pool') and self.model_pool is not None:


+3 -3 modelscope/pipelines/cv/body_3d_keypoints_pipeline.py

@@ -132,8 +132,8 @@ class Body3DKeypointsPipeline(Pipeline):
device='gpu' if torch.cuda.is_available() else 'cpu')

def preprocess(self, input: Input) -> Dict[str, Any]:
- video_url = input
- video_frames = self.read_video_frames(video_url)
self.video_url = input
video_frames = self.read_video_frames(self.video_url)
if 0 == len(video_frames):
res = {'success': False, 'msg': 'get video frame failed.'}
return res
@@ -198,7 +198,7 @@ class Body3DKeypointsPipeline(Pipeline):
}

if not input['success']:
- pass
res[OutputKeys.OUTPUT_VIDEO] = self.video_url
else:
poses = input[KeypointsTypes.POSES_CAMERA]
pred_3d_pose = poses.data.cpu().numpy()[


+10 -8 modelscope/pipelines/nlp/text_classification_pipeline.py

@@ -3,14 +3,13 @@ from typing import Any, Dict, Union

import numpy as np

- from modelscope.metainfo import Pipelines
from modelscope.metainfo import Pipelines, Preprocessors
from modelscope.models.base import Model
- from modelscope.models.multi_modal import OfaForAllTasks
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
- from modelscope.preprocessors import OfaPreprocessor, Preprocessor
- from modelscope.utils.constant import Tasks
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Fields, Tasks


@PIPELINES.register_module(
@@ -58,8 +57,11 @@ class TextClassificationPipeline(Pipeline):
str) else model

if preprocessor is None:
- if isinstance(model, OfaForAllTasks):
- preprocessor = OfaPreprocessor(model_dir=model.model_dir)
if model.__class__.__name__ == 'OfaForAllTasks':
preprocessor = Preprocessor.from_pretrained(
model_name_or_path=model.model_dir,
type=Preprocessors.ofa_tasks_preprocessor,
field=Fields.multi_modal)
else:
first_sequence = kwargs.pop('first_sequence', 'first_sequence')
second_sequence = kwargs.pop('second_sequence', None)
@@ -76,7 +78,7 @@ class TextClassificationPipeline(Pipeline):

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
- if isinstance(self.model, OfaForAllTasks):
if self.model.__class__.__name__ == 'OfaForAllTasks':
return super().forward(inputs, **forward_params)
return self.model(**inputs, **forward_params)

@@ -95,7 +97,7 @@ class TextClassificationPipeline(Pipeline):
labels: The real labels.
Label at index 0 is the smallest probability.
"""
- if isinstance(self.model, OfaForAllTasks):
if self.model.__class__.__name__ == 'OfaForAllTasks':
return inputs
else:
assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \
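
The switch from isinstance(model, OfaForAllTasks) to a class-name string comparison means this NLP pipeline no longer imports the multi-modal package at load time. The replacement preprocessor construction can be exercised on its own; a minimal sketch, assuming a local OFA text-classification checkpoint (the directory path is hypothetical):

from modelscope.metainfo import Preprocessors
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Fields

preprocessor = Preprocessor.from_pretrained(
    model_name_or_path='/path/to/ofa_model_dir',  # hypothetical checkpoint dir
    type=Preprocessors.ofa_tasks_preprocessor,    # preprocessor type to build
    field=Fields.multi_modal)                     # the new 'field' kwarg added in preprocessors/base.py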


+2 -23 modelscope/pipelines/nlp/text_generation_pipeline.py

@@ -10,6 +10,7 @@ from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import Preprocessor, build_preprocessor
from modelscope.utils.chinese_utils import remove_space_between_chinese_chars
from modelscope.utils.constant import Fields, Tasks
from modelscope.utils.hub import read_config

@@ -78,28 +79,6 @@ class TextGenerationPipeline(Pipeline):
with torch.no_grad():
return self.model.generate(inputs, **forward_params)

- def _is_chinese_char(self, word: str):
- chinese_punctuations = (',', '。', ';', ':' '!', '?', '《', '》')
- return len(word) == 1 \
- and ('\u4e00' <= word <= '\u9fa5' or word in chinese_punctuations)
-
- def _remove_space_between_chinese_chars(self, decoded: str):
- old_word_list = decoded.split(' ')
- new_word_list = []
- start = -1
- for i, word in enumerate(old_word_list):
- if self._is_chinese_char(word):
- if start == -1:
- start = i
- else:
- if start != -1:
- new_word_list.append(''.join(old_word_list[start:i]))
- start = -1
- new_word_list.append(word)
- if start != -1:
- new_word_list.append(''.join(old_word_list[start:]))
- return ' '.join(new_word_list)

def decode(self, inputs) -> str:
tokenizer = self.preprocessor.tokenizer
return tokenizer.decode(inputs.tolist(), skip_special_tokens=True)
@@ -128,5 +107,5 @@ class TextGenerationPipeline(Pipeline):
if isinstance(inputs, list) or len(inputs.shape) > 1:
inputs = inputs[0]
decoded = getattr(self, self.postprocessor)(inputs)
- text = self._remove_space_between_chinese_chars(decoded)
text = remove_space_between_chinese_chars(decoded)
return {OutputKeys.TEXT: text}

+9 -8 modelscope/preprocessors/base.py

@@ -205,10 +205,12 @@ class Preprocessor(ABC):
if 'task' in kwargs:
task = kwargs.pop('task')
field_name = Tasks.find_field_by_task(task)
if 'field' in kwargs:
field_name = kwargs.pop('field')
sub_key = 'train' if preprocessor_mode == ModeKeys.TRAIN else 'val'

- if not hasattr(cfg, 'preprocessor'):
- logger.error('No preprocessor field found in cfg.')
if not hasattr(cfg, 'preprocessor') or len(cfg.preprocessor) == 0:
logger.warn('No preprocessor field found in cfg.')
preprocessor_cfg = ConfigDict()
else:
preprocessor_cfg = cfg.preprocessor
@@ -217,9 +219,8 @@ class Preprocessor(ABC):
if sub_key in preprocessor_cfg:
sub_cfg = getattr(preprocessor_cfg, sub_key)
else:
- logger.error(
- f'No {sub_key} key and type key found in '
- f'preprocessor domain of configuration.json file.')
logger.warn(f'No {sub_key} key and type key found in '
f'preprocessor domain of configuration.json file.')
sub_cfg = preprocessor_cfg
else:
sub_cfg = preprocessor_cfg
@@ -235,7 +236,7 @@ class Preprocessor(ABC):

preprocessor = build_preprocessor(sub_cfg, field_name)
else:
- logger.error(
logger.warn(
f'Cannot find available config to build preprocessor at mode {preprocessor_mode}, '
f'current config: {sub_cfg}. trying to build by task and model information.'
)
@@ -243,13 +244,13 @@ class Preprocessor(ABC):
model_type = model_cfg.type if hasattr(
model_cfg, 'type') else getattr(model_cfg, 'model_type', None)
if task is None or model_type is None:
- logger.error(
logger.warn(
f'Find task: {task}, model type: {model_type}. '
f'Insufficient information to build preprocessor, skip building preprocessor'
)
return None
if (model_type, task) not in PREPROCESSOR_MAP:
- logger.error(
logger.warn(
f'No preprocessor key {(model_type, task)} found in PREPROCESSOR_MAP, '
f'skip building preprocessor.')
return None


+12 -6 modelscope/preprocessors/nlp/token_classification_preprocessor.py

@@ -73,10 +73,12 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
super().__init__(model_dir, mode=mode, **kwargs)

if 'is_split_into_words' in kwargs:
- self.is_split_into_words = kwargs.pop('is_split_into_words')
self.tokenize_kwargs['is_split_into_words'] = kwargs.pop(
'is_split_into_words')
else:
- self.is_split_into_words = self.tokenizer.init_kwargs.get(
- 'is_split_into_words', False)
self.tokenize_kwargs[
'is_split_into_words'] = self.tokenizer.init_kwargs.get(
'is_split_into_words', False)
if 'label2id' in kwargs:
kwargs.pop('label2id')

@@ -99,7 +101,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
if isinstance(data, str):
# for inference inputs without label
text = data
- self.tokenize_kwargs['add_special_tokens'] = False
elif isinstance(data, dict):
# for finetune inputs with label
text = data.get(self.first_sequence)
@@ -107,11 +108,15 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
if isinstance(text, list):
self.tokenize_kwargs['is_split_into_words'] = True

if self._mode == ModeKeys.INFERENCE:
self.tokenize_kwargs['add_special_tokens'] = False

input_ids = []
label_mask = []
offset_mapping = []
token_type_ids = []
- if self.is_split_into_words and self._mode == ModeKeys.INFERENCE:
if self.tokenize_kwargs[
'is_split_into_words'] and self._mode == ModeKeys.INFERENCE:
for offset, token in enumerate(list(text)):
subtoken_ids = self.tokenizer.encode(token,
**self.tokenize_kwargs)
@@ -125,7 +130,8 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
encodings = self.tokenizer(
text, return_offsets_mapping=True, **self.tokenize_kwargs)
attention_mask = encodings['attention_mask']
- token_type_ids = encodings['token_type_ids']
if 'token_type_ids' in encodings:
token_type_ids = encodings['token_type_ids']
input_ids = encodings['input_ids']
word_ids = encodings.word_ids()
for i in range(len(word_ids)):


+1 -1 modelscope/preprocessors/ofa/image_captioning.py

@@ -43,7 +43,7 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
- target = data[self.column_map['text']]
target = sample['label']
target = target.translate(self.transtab).strip()
target_token_list = target.strip().split()
target = ' '.join(target_token_list[:self.max_tgt_length])


+79 -11 modelscope/preprocessors/ofa/image_classification.py

@@ -1,13 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import functools
from typing import Any, Dict

import torch
- from PIL import Image
from PIL import Image, ImageFile
from timm.data import create_transform
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor
from .utils.vision_helper import RandomAugment

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None


class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
@@ -28,18 +35,77 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
super(OfaImageClassificationPreprocessor,
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
- self.patch_resize_transform = transforms.Compose([
- lambda image: image.convert('RGB'),
- transforms.Resize(
- (self.patch_image_size, self.patch_image_size),
- interpolation=transforms.InterpolationMode.BICUBIC),
- transforms.ToTensor(),
- transforms.Normalize(mean=self.mean, std=self.std),
- ])
if self.mode != ModeKeys.TRAIN:
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize(
(self.patch_image_size, self.patch_image_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])
else:
self.patch_resize_transform = create_transform(
input_size=self.patch_image_size,
is_training=True,
color_jitter=0.4,
auto_augment='rand-m9-mstd0.5-inc1',
interpolation='bicubic',
re_prob=0.25,
re_mode='pixel',
re_count=1,
mean=self.mean,
std=self.std)
self.patch_resize_transform = transforms.Compose(
functools.reduce(lambda x, y: x + y, [
[
lambda image: image.convert('RGB'),
],
self.patch_resize_transform.transforms[:2],
[self.patch_resize_transform.transforms[2]],
[
RandomAugment(
2,
7,
isPIL=True,
augs=[
'Identity', 'AutoContrast', 'Equalize',
'Brightness', 'Sharpness', 'ShearX', 'ShearY',
'TranslateX', 'TranslateY', 'Rotate'
]),
],
self.patch_resize_transform.transforms[3:],
]))

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
- image = data['image'] if isinstance(
- data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = ' {}'.format(sample['label'])
sample['ref_dict'] = {sample['label']: 1.0}
sample['target'] = self.tokenize_text(target, add_bos=False)
sample['prev_output_tokens'] = torch.cat(
[self.bos_item, sample['target'][:-1]])

if self.constraint_trie is not None:
constraint_mask = torch.zeros((len(sample['prev_output_tokens']),
len(self.tgt_dict))).bool()
for i in range(len(sample['prev_output_tokens'])):
constraint_prefix_token = sample[
'prev_output_tokens'][:i + 1].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(
constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
sample['constraint_mask'] = constraint_mask

return sample

def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
patch_image = self.patch_resize_transform(image)
prompt = self.cfg.model.get('prompt', ' what does the image describe?')
inputs = self.tokenize_text(prompt)
@@ -48,4 +114,6 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
'patch_image': patch_image,
'patch_mask': torch.tensor([True])
}
if 'text' in self.column_map and self.column_map['text'] in data:
sample['label'] = data[self.column_map['text']]
return sample

+2 -5 modelscope/preprocessors/ofa/ocr_recognition.py

@@ -11,9 +11,6 @@ from zhconv import convert
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor

- IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
- IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def ocr_resize(img, patch_image_size, is_document=False):
img = img.convert('RGB')
@@ -112,6 +109,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
}
if 'text' in self.column_map and self.column_map['text'] in data:
target = data[self.column_map['text']]
- target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
- sample['label'] = target
sample['label'] = unicodedata2.normalize(
'NFKC', convert(target, 'zh-hans'))
return sample

+34 -2 modelscope/preprocessors/ofa/summarization.py

@@ -1,6 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import torch

from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor

@@ -24,9 +26,26 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
self).__init__(cfg, model_dir, mode, *args, **kwargs)

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target_str = sample['label'].lower()
target = super().pre_caption(target_str, max_words=self.max_tgt_length)
target = target.replace('[unk]', 'unk').replace('<unk>', 'unk')
sample['target'] = self.tokenize_text(target, add_bos=False)
noise_target_item = self.add_noise_to_tgt(
sample['target'][:-1].clone())
sample['prev_output_tokens'] = torch.cat(
[self.bos_item, noise_target_item])
return sample

def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
source = super().pre_caption(
- data['text'], max_words=self.max_src_length)
- source = source.strip()[:self.max_src_length]
data[self.column_map['text']], max_words=self.max_src_length)
source = source.replace('[unk]', 'unk').replace('<unk>', 'unk')
prompt = self.cfg.model.get(
'prompt', ' " {} " Summarize the article with a title: ')
@@ -42,4 +61,17 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
'source': inputs,
'decoder_prompt': decoder_prompt,
}
if 'summary' in self.column_map and self.column_map['summary'] in data:
sample['label'] = data[self.column_map['summary']]
return sample

def add_noise_to_tgt(self, target):
noise_indices = torch.FloatTensor(
target.size(0)).uniform_() < self.cfg.model.get(
'noise_ratio', 0.0)
target[noise_indices] = torch.randint(
4,
len(self.src_dict) - self.cfg.model.get('num_codes', 8192)
- self.cfg.model.get('num_bins', 1000),
size=(noise_indices.sum(), ))
return target
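
The add_noise_to_tgt helper implements a simple token-level corruption for training: each target position is independently picked with probability noise_ratio and overwritten by a random ordinary-vocabulary id (ids 0-3 and the trailing image-code/location-bin ids are excluded from the sampling range). A standalone restatement, with made-up sizes:

import torch

def add_noise(target, noise_ratio=0.2, vocab_size=60000, num_codes=8192, num_bins=1000):
    # Bernoulli(noise_ratio) mask over target positions
    noise_indices = torch.rand(target.size(0)) < noise_ratio
    # resample masked positions from [4, vocab_size - num_codes - num_bins)
    target[noise_indices] = torch.randint(
        4, vocab_size - num_codes - num_bins, size=(int(noise_indices.sum()),))
    return target

print(add_noise(torch.arange(10) + 100))  # a few entries replaced at random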

+54 -5 modelscope/preprocessors/ofa/visual_entailment.py

@@ -38,18 +38,64 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
- image = data['image'] if isinstance(
- data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = ' {}'.format(sample['label'])
sample['ref_dict'] = {sample['label']: 1.0}
tgt_item = self.tokenize_text(target, add_bos=False, add_eos=False)

if self.prompt_type == 'none':
prev_output_item = torch.cat([self.bos_item, tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'src':
prev_output_item = torch.cat([sample['source'], tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'prev_output':
prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
else:
raise NotImplementedError

target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
sample['target'] = target_item
sample['prev_output_tokens'] = prev_output_item

if self.constraint_trie is not None:
constraint_mask = torch.zeros(
(len(target_item), len(self.tgt_dict))).bool()
start_idx = len(target_item) - len(tgt_item) - 1
for i in range(
len(target_item) - len(tgt_item) - 1, len(target_item)):
constraint_prefix_token = [
self.tgt_dict.bos()
] + target_item[start_idx:i].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(
constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
sample['constraint_mask'] = constraint_mask

return sample

def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
patch_image = self.patch_resize_transform(image)
if 'text2' not in data:
- hypothesis = self.pre_caption(data['text'], self.max_src_length)
hypothesis = self.pre_caption(data[self.column_map['text']],
self.max_src_length)
prompt = self.cfg.model.get('prompt',
' does the image describe " {} "?')
text = prompt.format(hypothesis)
else:
assert 'text' in data, f'text must be in the input {data.keys()}'
- caption = self.pre_caption(data['text2'], self.max_src_length)
- hypothesis = self.pre_caption(data['text'], self.max_src_length)
caption = self.pre_caption(data[self.column_map['text2']],
self.max_src_length)
hypothesis = self.pre_caption(data[self.column_map['text']],
self.max_src_length)
prompt = self.cfg.model.get(
'prompt', ' can image and text1 " {} " imply text2 " {} "?')
text = prompt.format(caption, hypothesis)
@@ -68,4 +114,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
'patch_mask': torch.tensor([True]),
'decoder_prompt': decoder_prompt,
}
if 'relation' in self.column_map and self.column_map[
'relation'] in data:
sample['label'] = data[self.column_map['relation']]
return sample
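
Both this preprocessor and the visual question answering one below build a constraint_mask from a token trie: at each decoding step, only token ids that extend some valid answer prefix are allowed. The trie class itself is not part of this diff; a minimal sketch of the interface the code relies on (get_next_layer returns the admissible next tokens for a prefix) could look like:

class ConstraintTrie:
    def __init__(self):
        self.root = {}

    def insert(self, token_ids):
        # register one valid answer as a path of token ids
        node = self.root
        for tid in token_ids:
            node = node.setdefault(tid, {})

    def get_next_layer(self, prefix_token_ids):
        # token ids that can legally follow the given prefix
        node = self.root
        for tid in prefix_token_ids:
            node = node.get(tid)
            if node is None:
                return []
        return list(node.keys())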

+94 -12 modelscope/preprocessors/ofa/visual_grounding.py

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
@@ -8,6 +9,7 @@ from torchvision import transforms
from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor
from .utils import transforms as T


class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
@@ -27,24 +29,98 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
"""
super(OfaVisualGroundingPreprocessor,
self).__init__(cfg, model_dir, mode, *args, **kwargs)
- # Initialize transform
- self.patch_resize_transform = transforms.Compose([
- lambda image: image.convert('RGB'),
- transforms.Resize(
- (self.patch_image_size, self.patch_image_size),
- interpolation=transforms.InterpolationMode.BICUBIC),
- transforms.ToTensor(),
- transforms.Normalize(mean=self.mean, std=self.std),
- ])

self.num_bins = self.cfg.model.get('num_bins', 1000)
if self.mode == ModeKeys.TRAIN:
# for positioning
self.positioning_transform = T.Compose([
T.RandomResize([self.patch_image_size],
max_size=self.patch_image_size),
T.ToTensor(),
T.Normalize(
mean=self.mean,
std=self.std,
max_image_size=self.max_image_size)
])
else:
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize(
(self.patch_image_size, self.patch_image_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
- image = data['image'] if isinstance(
- data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
w, h = image.size
boxes_target = {
'boxes': [],
'labels': [],
'area': [],
'size': torch.tensor([h, w])
}
x0, y0, x1, y1 = data[self.column_map['region_coord']].strip().split(
',')
region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
boxes_target['boxes'] = torch.tensor(
[[float(x0), float(y0), float(x1),
float(y1)]])
boxes_target['labels'] = np.array([0])
area = [(float(x1) - float(x0)) * (float(y1) - float(y0))]
boxes_target['area'] = torch.tensor(area)

patch_image, patch_boxes = self.positioning_transform(
image, boxes_target)
resize_h, resize_w = patch_boxes['size'][0], patch_boxes['size'][1]
quant_x0 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][0] * (self.num_bins - 1)).round()))
quant_y0 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][1] * (self.num_bins - 1)).round()))
quant_x1 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][2] * (self.num_bins - 1)).round()))
quant_y1 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][3] * (self.num_bins - 1)).round()))
region_coord = '{} {} {} {}'.format(quant_x0, quant_y0, quant_x1,
quant_y1)
src_caption = self.pre_caption(data[self.column_map['text']],
self.max_src_length)
prompt = self.cfg.model.get(
'prompt', ' which region does the text " {} " describe?')
text = prompt.format(src_caption)
src_item = self.tokenize_text(text)
target_item = self.tokenize_text(
region_coord, add_bos=False) # !!! use_bpe=False
prev_output_item = torch.cat([self.bos_item, target_item[:-1]])

sample = {
'source': src_item,
'patch_image': patch_image,
'patch_mask': torch.tensor([True]),
'target': target_item,
'prev_output_tokens': prev_output_item,
'w_resize_ratio': resize_w / w,
'h_resize_ratio': resize_h / h,
'region_coord': region
}
return sample

def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
w, h = image.size
patch_image = self.patch_resize_transform(image)
w_resize_ratio = torch.tensor(self.patch_image_size / w)
h_resize_ratio = torch.tensor(self.patch_image_size / h)
- src_caption = self.pre_caption(data['text'], self.max_src_length)
src_caption = self.pre_caption(data[self.column_map['text']],
self.max_src_length)
prompt = self.cfg.model.get(
'prompt', ' which region does the text " {} " describe?')
text = prompt.format(src_caption)
@@ -56,4 +132,10 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
'w_resize_ratio': w_resize_ratio,
'h_resize_ratio': h_resize_ratio,
}

if 'region_coord' in self.column_map and self.column_map[
'region_coord'] in data:
x0, y0, x1, y1 = data[
self.column_map['region_coord']].strip().split(',')
sample['label'] = [float(x0), float(y0), float(x1), float(y1)]
return sample
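
The <bin_k> tokens above discretize normalized box coordinates into num_bins buckets so a region can be emitted as four ordinary tokens. A standalone restatement of the quantization step (coordinates assumed already normalized to [0, 1] by the positioning transform):

num_bins = 1000

def quantize(coord):
    return '<bin_{}>'.format(int(round(coord * (num_bins - 1))))

# e.g. a box occupying the central quarter of the image
print(' '.join(quantize(c) for c in (0.25, 0.25, 0.75, 0.75)))
# -> <bin_250> <bin_250> <bin_749> <bin_749>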

+47 -3 modelscope/preprocessors/ofa/visual_question_answering.py

@@ -38,10 +38,52 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
- image = data['image'] if isinstance(
- data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
tgt_item = self.tokenize_text(
' {}'.format(sample['label']), add_bos=False, add_eos=False)

if self.prompt_type == 'none':
prev_output_item = torch.cat([self.bos_item, tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'src':
prev_output_item = torch.cat([sample['source'], tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'prev_output':
prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
else:
raise NotImplementedError
target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id

sample['prev_output_tokens'] = prev_output_item
sample['target'] = target_item

if self.constraint_trie is not None:
constraint_mask = torch.zeros(
(len(target_item), len(self.tgt_dict))).bool()
start_idx = len(target_item) - len(tgt_item) - 1
for i in range(
len(target_item) - len(tgt_item) - 1, len(target_item)):
constraint_prefix_token = [
self.tgt_dict.bos()
] + target_item[start_idx:i].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(
constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
sample['constraint_mask'] = constraint_mask

return sample

def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
patch_image = self.patch_resize_transform(image)
- text = ' {}'.format(data['text'])
text = ' {}'.format(data[self.column_map['text']])
inputs = self.tokenize_text(text)
if self.prompt_type == 'none':
decoder_prompt = self.bos_item
@@ -57,4 +99,6 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
'patch_mask': torch.tensor([True]),
'decoder_prompt': decoder_prompt,
}
if 'answer' in self.column_map and self.column_map['answer'] in data:
sample['label'] = data[self.column_map['answer']]
return sample

+5 -2 modelscope/trainers/audio/kws_farfield_trainer.py

@@ -69,11 +69,14 @@ class KWSFarfieldTrainer(BaseTrainer):

super().__init__(cfg_file, arg_parse_fn)

- self.model = self.build_model()
- self.work_dir = work_dir
# the number of model output dimension
# should update config outside the trainer, if user need more wake word
num_syn = kwargs.get('num_syn', None)
if num_syn:
self.cfg.model.num_syn = num_syn
self._num_classes = self.cfg.model.num_syn
self.model = self.build_model()
self.work_dir = work_dir

if kwargs.get('launcher', None) is not None:
init_dist(kwargs['launcher'])


+5 -1 modelscope/trainers/multi_modal/__init__.py

@@ -6,11 +6,15 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .clip import CLIPTrainer
from .team import TEAMImgClsTrainer
from .ofa import OFATrainer
from .mplug import MPlugTrainer

else:
_import_structure = {
'clip': ['CLIPTrainer'],
- 'team': ['TEAMImgClsTrainer']
'team': ['TEAMImgClsTrainer'],
'ofa': ['OFATrainer'],
'mplug': ['MPlugTrainer'],
}

import sys


+3 -0 modelscope/trainers/multi_modal/mplug/__init__.py

@@ -0,0 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .mplug_trainer import MPlugTrainer

+40 -0 modelscope/trainers/multi_modal/mplug/mplug_trainer.py

@@ -0,0 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from collections.abc import Mapping

import torch

from modelscope.metainfo import Trainers
from modelscope.outputs import OutputKeys
from modelscope.trainers import NlpEpochBasedTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.file_utils import func_receive_dict_inputs


@TRAINERS.register_module(module_name=Trainers.mplug)
class MPlugTrainer(NlpEpochBasedTrainer):

def _decode(self, tokens):
tokenizer = self.eval_preprocessor.tokenizer
return tokenizer.decode(tokens, skip_special_tokens=True)

def evaluation_step(self, data):
model = self.model.module if self._dist else self.model
model.eval()

with torch.no_grad():
if isinstance(
data,
Mapping) and not func_receive_dict_inputs(model.forward):
result = model.forward(**data)
else:
result = model.forward(data)

result[OutputKeys.TEXT] = [
self._decode(seq) for seq in result['sequences']
]
data[OutputKeys.LABELS] = [
self._decode(seq) for seq in data['answer_input_ids']
]

return result

+13 -4 modelscope/trainers/multi_modal/ofa/ofa_trainer.py

@@ -34,6 +34,7 @@ class OFATrainer(EpochBasedTrainer):
self,
model: Optional[Union[TorchModel, nn.Module, str]] = None,
cfg_file: Optional[str] = None,
cfg_modify_fn: Optional[Callable] = None,
arg_parse_fn: Optional[Callable] = None,
data_collator: Optional[Union[Callable, Dict[str,
Callable]]] = None,
@@ -49,7 +50,8 @@ class OFATrainer(EpochBasedTrainer):
**kwargs):
model = Model.from_pretrained(model, revision=model_revision)
model_dir = model.model_dir
- cfg = Config.from_file(cfg_file)
self.cfg_modify_fn = cfg_modify_fn
cfg = self.rebuild_config(Config.from_file(cfg_file))
if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0:
work_dir = cfg.train.work_dir
else:
@@ -57,10 +59,12 @@ class OFATrainer(EpochBasedTrainer):
tokenizer_files = {
'zh': [
'tokenizer.json', 'tokenizer_config.json', 'vocab.txt',
- 'config.json'
'config.json', 'ans2label.json'
],
'en': [
'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json',
'ans2label.json'
],
- 'en':
- ['tokenizer.json', 'vocab.json', 'merges.txt', 'config.json'],
}
for filename in tokenizer_files[cfg.model.get('language', 'en')]:
finetune_file = os.path.join(work_dir, filename)
@@ -127,6 +131,11 @@ class OFATrainer(EpochBasedTrainer):
**kwargs,
)

def rebuild_config(self, cfg: Config):
if self.cfg_modify_fn is not None:
cfg = self.cfg_modify_fn(cfg)
return cfg

def train_step(self, model, inputs):
model.train()
loss, sample_size, logging_output = self.criterion(model, inputs)
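
The new cfg_modify_fn hook lets callers patch the loaded configuration.json before the trainer consumes it, via rebuild_config above. A sketch of how it would be passed, assuming the trainer is registered as Trainers.ofa (the modified field is a hypothetical tweak; the model id is the one used in the OFA trainer test below):

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

def cfg_modify_fn(cfg):
    cfg.train.max_epochs = 2  # hypothetical override
    return cfg

kwargs = dict(
    model='damo/ofa_ocr-recognition_scene_base_zh',
    work_dir='./work_dir',
    cfg_modify_fn=cfg_modify_fn)
trainer = build_trainer(name=Trainers.ofa, default_args=kwargs)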


+35 -0 modelscope/utils/chinese_utils.py

@@ -0,0 +1,35 @@
# Copyright (c) Alibaba, Inc. and its affiliates.


def is_chinese_char(word: str):
chinese_punctuations = {
',', '。', ';', ':'
'!', '?', '《', '》', '‘', '’', '“', '”', '(', ')', '【', '】'
}
return len(word) == 1 \
and ('\u4e00' <= word <= '\u9fa5' or word in chinese_punctuations)


def remove_space_between_chinese_chars(decoded_str: str):
old_word_list = decoded_str.split(' ')
new_word_list = []
start = -1
for i, word in enumerate(old_word_list):
if is_chinese_char(word):
if start == -1:
start = i
else:
if start != -1:
new_word_list.append(''.join(old_word_list[start:i]))
start = -1
new_word_list.append(word)
if start != -1:
new_word_list.append(''.join(old_word_list[start:]))
return ' '.join(new_word_list).strip()


# add space for each chinese char
def rebuild_chinese_str(string: str):
return ' '.join(''.join([
f' {char} ' if is_chinese_char(char) else char for char in string
]).split())
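
The two helpers are rough inverses: rebuild_chinese_str spaces out CJK characters so word-level metrics like ROUGE can treat them as tokens, while remove_space_between_chinese_chars collapses the tokenizer's spaces back out of decoded text. For example:

from modelscope.utils.chinese_utils import (remove_space_between_chinese_chars,
                                            rebuild_chinese_str)

print(rebuild_chinese_str('今天天气nice'))                     # '今 天 天 气 nice'
print(remove_space_between_chinese_chars('今 天 天 气 nice'))  # '今天天气 nice'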

+18 -0 modelscope/utils/regress_test_utils.py

@@ -5,6 +5,7 @@ import hashlib
import os
import pickle
import random
import re
import shutil
import tempfile
from collections import OrderedDict
@@ -759,3 +760,20 @@ def compare_cfg_and_optimizers(baseline_json,
state2, **kwargs) and match

return match


class IgnoreKeyFn:

def __init__(self, keys):
if isinstance(keys, str):
keys = [keys]
self.keys = keys if isinstance(keys, list) else []

def __call__(self, v1output, v2output, key, type):
if key == 'encoder.encoder.layer.0.intermediate.intermediate_act_fn':
print()
for _key in self.keys:
pattern = re.compile(_key)
if key is not None and pattern.fullmatch(key):
return True
return None
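
IgnoreKeyFn is the compare_fn used by the regression tests below: for any module key matching one of its regex patterns it returns True, which short-circuits the comparison and effectively exempts that key from regression checking, while returning None falls back to the default comparison. For example:

ignore = IgnoreKeyFn('.*intermediate_act_fn')

ignore(None, None, 'encoder.layer.0.intermediate.intermediate_act_fn', 'output')  # True -> skipped
ignore(None, None, 'encoder.layer.0.attention.self.query', 'output')              # None -> default compare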

+7 -3 tests/pipelines/test_fill_mask.py

@@ -11,7 +11,7 @@ from modelscope.pipelines.nlp import FillMaskPipeline
from modelscope.preprocessors import NLPPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
- from modelscope.utils.regress_test_utils import MsRegressTool
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
from modelscope.utils.test_utils import test_level


@@ -109,7 +109,9 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline(
task=Tasks.fill_mask, model=model, preprocessor=preprocessor)
with self.regress_tool.monitor_module_single_forward(
- pipeline_ins.model, f'fill_mask_sbert_{language}'):
pipeline_ins.model,
f'fill_mask_sbert_{language}',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(
f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
f'{pipeline_ins(self.test_inputs[language])}\n')
@@ -124,7 +126,9 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
ori_text = self.ori_texts[language]
test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
with self.regress_tool.monitor_module_single_forward(
- pipeline_ins.model, f'fill_mask_veco_{language}'):
pipeline_ins.model,
f'fill_mask_veco_{language}',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(
f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
f'{pipeline_ins(test_input)}\n')


+10 -0 tests/pipelines/test_multilingual_named_entity_recognition.py

@@ -27,6 +27,9 @@ class MultilingualNamedEntityRecognitionTest(unittest.TestCase,
viet_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_viet-ecommerce-title'
viet_sentence = 'Nón vành dễ thương cho bé gái'

multilingual_model_id = 'damo/nlp_raner_named-entity-recognition_multilingual-large-generic'
ml_stc = 'সমস্ত বেতন নিলামের সাধারণ ব্যবহারিক উদাহরণ বিভিন্ন পেনি নিলাম / বিডিং ফি নিলাম ওয়েবসাইটে পাওয়া যাবে।'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_tcrf_by_direct_model_download_thai(self):
cache_path = snapshot_download(self.thai_tcrf_model_id)
@@ -60,6 +63,13 @@ class MultilingualNamedEntityRecognitionTest(unittest.TestCase,
task=Tasks.named_entity_recognition, model=self.thai_tcrf_model_id)
print(pipeline_ins(input=self.thai_sentence))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_tcrf_with_model_name_multilingual(self):
pipeline_ins = pipeline(
task=Tasks.named_entity_recognition,
model=self.multilingual_model_id)
print(pipeline_ins(input=self.ml_stc))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_tcrf_by_direct_model_download_viet(self):
cache_path = snapshot_download(self.viet_tcrf_model_id)


+9 -1 tests/pipelines/test_named_entity_recognition.py

@@ -20,10 +20,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'

english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom'
chinese_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-large-generic'
tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
sentence = '这与温岭市新河镇的一个神秘的传说有关。'
sentence_en = 'pizza shovel'
sentence_zh = '他 继 续 与 貝 塞 斯 達 遊 戲 工 作 室 在 接 下 来 辐 射 4 游 戏 。'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_tcrf_by_direct_model_download(self):
@@ -91,11 +93,17 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
task=Tasks.named_entity_recognition, model=self.lcrf_model_id)
print(pipeline_ins(input=self.sentence))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_lcrf_with_chinese_model_name(self):
pipeline_ins = pipeline(
task=Tasks.named_entity_recognition, model=self.chinese_model_id)
print(pipeline_ins(input=self.sentence_zh))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_english_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.named_entity_recognition, model=self.english_model_id)
- print(pipeline_ins(input='pizza shovel'))
print(pipeline_ins(input=self.sentence_en))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):


+4 -3 tests/pipelines/test_nli.py

@@ -3,13 +3,12 @@ import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
- from modelscope.models.nlp import SbertForSequenceClassification
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TextClassificationPipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
- from modelscope.utils.regress_test_utils import MsRegressTool
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
from modelscope.utils.test_utils import test_level


@@ -48,7 +47,9 @@ class NLITest(unittest.TestCase, DemoCompatibilityCheck):
def test_run_with_model_name(self):
pipeline_ins = pipeline(task=Tasks.nli, model=self.model_id)
with self.regress_tool.monitor_module_single_forward(
- pipeline_ins.model, 'sbert_nli'):
pipeline_ins.model,
'sbert_nli',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(pipeline_ins(input=(self.sentence1, self.sentence2)))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')


+4 -2 tests/pipelines/test_sentence_similarity.py

@@ -9,7 +9,7 @@ from modelscope.pipelines.nlp import TextClassificationPipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
- from modelscope.utils.regress_test_utils import MsRegressTool
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
from modelscope.utils.test_utils import test_level


@@ -54,7 +54,9 @@ class SentenceSimilarityTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline(
task=Tasks.sentence_similarity, model=self.model_id)
with self.regress_tool.monitor_module_single_forward(
- pipeline_ins.model, 'sbert_sen_sim'):
pipeline_ins.model,
'sbert_sen_sim',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(pipeline_ins(input=(self.sentence1, self.sentence2)))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')


+7 -3 tests/pipelines/test_word_segmentation.py

@@ -9,7 +9,7 @@ from modelscope.pipelines.nlp import WordSegmentationPipeline
from modelscope.preprocessors import TokenClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
- from modelscope.utils.regress_test_utils import MsRegressTool
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
from modelscope.utils.test_utils import test_level


@@ -48,10 +48,14 @@ class WordSegmentationTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline(
task=Tasks.word_segmentation, model=self.model_id)
with self.regress_tool.monitor_module_single_forward(
- pipeline_ins.model, 'sbert_ws_zh'):
pipeline_ins.model,
'sbert_ws_zh',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(pipeline_ins(input=self.sentence))
with self.regress_tool.monitor_module_single_forward(
- pipeline_ins.model, 'sbert_ws_en'):
pipeline_ins.model,
'sbert_ws_en',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(pipeline_ins(input=self.sentence_eng))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')


+4 -2 tests/pipelines/test_zero_shot_classification.py

@@ -9,7 +9,7 @@ from modelscope.pipelines.nlp import ZeroShotClassificationPipeline
from modelscope.preprocessors import ZeroShotClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
- from modelscope.utils.regress_test_utils import MsRegressTool
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
from modelscope.utils.test_utils import test_level


@@ -65,7 +65,9 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline(
task=Tasks.zero_shot_classification, model=self.model_id)
with self.regress_tool.monitor_module_single_forward(
- pipeline_ins.model, 'sbert_zero_shot'):
pipeline_ins.model,
'sbert_zero_shot',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(
pipeline_ins(
input=self.sentence, candidate_labels=self.labels))


+10 -28 tests/trainers/test_finetune_mplug.py

@@ -20,10 +20,7 @@ class TestFinetuneMPlug(unittest.TestCase):
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
- from modelscope.utils.constant import DownloadMode
- datadict = MsDataset.load(
- 'coco_captions_small_slice',
- download_mode=DownloadMode.FORCE_REDOWNLOAD)
datadict = MsDataset.load('coco_captions_small_slice')
self.train_dataset = MsDataset(
datadict['train'].remap_columns({
'image:FILE': 'image',
@@ -40,18 +37,6 @@ class TestFinetuneMPlug(unittest.TestCase):
shutil.rmtree(self.tmp_dir)
super().tearDown()

- def _cfg_modify_fn(self, cfg):
- cfg.train.hooks = [{
- 'type': 'CheckpointHook',
- 'interval': self.max_epochs
- }, {
- 'type': 'TextLoggerHook',
- 'interval': 1
- }, {
- 'type': 'IterTimerHook'
- }]
- return cfg

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer_with_caption(self):
kwargs = dict(
@@ -59,11 +44,10 @@ class TestFinetuneMPlug(unittest.TestCase):
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
max_epochs=self.max_epochs,
- work_dir=self.tmp_dir,
- cfg_modify_fn=self._cfg_modify_fn)
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
- name=Trainers.nlp_base_trainer, default_args=kwargs)
name=Trainers.mplug, default_args=kwargs)
trainer.train()

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@@ -80,7 +64,7 @@ class TestFinetuneMPlug(unittest.TestCase):
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
- name=Trainers.nlp_base_trainer, default_args=kwargs)
name=Trainers.mplug, default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@@ -94,11 +78,10 @@ class TestFinetuneMPlug(unittest.TestCase):
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
max_epochs=self.max_epochs,
- work_dir=self.tmp_dir,
- cfg_modify_fn=self._cfg_modify_fn)
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
- name=Trainers.nlp_base_trainer, default_args=kwargs)
name=Trainers.mplug, default_args=kwargs)
trainer.train()

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@@ -115,7 +98,7 @@ class TestFinetuneMPlug(unittest.TestCase):
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
- name=Trainers.nlp_base_trainer, default_args=kwargs)
name=Trainers.mplug, default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@@ -129,11 +112,10 @@ class TestFinetuneMPlug(unittest.TestCase):
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
max_epochs=self.max_epochs,
- work_dir=self.tmp_dir,
- cfg_modify_fn=self._cfg_modify_fn)
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
- name=Trainers.nlp_base_trainer, default_args=kwargs)
name=Trainers.mplug, default_args=kwargs)
trainer.train()

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@@ -150,7 +132,7 @@ class TestFinetuneMPlug(unittest.TestCase):
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
- name=Trainers.nlp_base_trainer, default_args=kwargs)
name=Trainers.mplug, default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)


+2 -0 tests/trainers/test_ofa_trainer.py

@@ -9,6 +9,7 @@ from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.hub import read_config
from modelscope.utils.test_utils import test_level


@@ -78,6 +79,7 @@ class TestOfaTrainer(unittest.TestCase):
json.dump(self.finetune_cfg, writer)

pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'

args = dict(
model=pretrained_model,
work_dir=WORKSPACE,


+1 -1 tests/utils/test_ast.py

@@ -41,7 +41,7 @@ class AstScaningTest(unittest.TestCase):
self.assertIsInstance(from_imports, dict)
self.assertIsInstance(decorators, list)
self.assertListEqual(list(set(imports.keys()) - set(['torch'])), [])
- self.assertEqual(len(from_imports.keys()), 9)
self.assertEqual(len(from_imports.keys()), 10)
self.assertTrue(from_imports['modelscope.metainfo'] is not None)
self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines'])
self.assertEqual(decorators,

