
[to #42322933] add image instance segmentation pipeline and finetune to MaaS-lib

master
hejunjie.hjj 3 years ago
parent
commit
321292ceda
25 changed files with 2368 additions and 0 deletions
  1. +3 -0 data/test/images/image_instance_segmentation.jpg
  2. +5 -0 modelscope/metainfo.py
  3. +2 -0 modelscope/metrics/__init__.py
  4. +1 -0 modelscope/metrics/builder.py
  5. +312 -0 modelscope/metrics/image_instance_segmentation_metric.py
  6. +2 -0 modelscope/models/cv/image_instance_segmentation/__init__.py
  7. +1 -0 modelscope/models/cv/image_instance_segmentation/backbones/__init__.py
  8. +694 -0 modelscope/models/cv/image_instance_segmentation/backbones/swin_transformer.py
  9. +266 -0 modelscope/models/cv/image_instance_segmentation/cascade_mask_rcnn_swin.py
  10. +2 -0 modelscope/models/cv/image_instance_segmentation/datasets/__init__.py
  11. +332 -0 modelscope/models/cv/image_instance_segmentation/datasets/dataset.py
  12. +109 -0 modelscope/models/cv/image_instance_segmentation/datasets/transforms.py
  13. +49 -0 modelscope/models/cv/image_instance_segmentation/model.py
  14. +203 -0 modelscope/models/cv/image_instance_segmentation/postprocess_utils.py
  15. +1 -0 modelscope/outputs.py
  16. +3 -0 modelscope/pipelines/builder.py
  17. +1 -0 modelscope/pipelines/cv/__init__.py
  18. +105 -0 modelscope/pipelines/cv/image_instance_segmentation_pipeline.py
  19. +1 -0 modelscope/preprocessors/__init__.py
  20. +69 -0 modelscope/preprocessors/image.py
  21. +1 -0 modelscope/trainers/__init__.py
  22. +2 -0 modelscope/trainers/cv/__init__.py
  23. +27 -0 modelscope/trainers/cv/image_instance_segmentation_trainer.py
  24. +60 -0 tests/pipelines/test_image_instance_segmentation.py
  25. +117 -0 tests/trainers/test_image_instance_segmentation_trainer.py

+3 -0 data/test/images/image_instance_segmentation.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8e9ab135da7eacabdeeeee11ba4b7bcdd1bfac128cf92a9de9c79f984060ae1e
size 259865

+5 -0 modelscope/metainfo.py

@@ -11,6 +11,7 @@ class Models(object):
"""
# vision models
csrnet = 'csrnet'
cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'

# nlp models
bert = 'bert'
@@ -67,6 +68,7 @@ class Pipelines(object):
image_super_resolution = 'rrdb-image-super-resolution'
face_image_generation = 'gan-face-image-generation'
style_transfer = 'AAMS-style-transfer'
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'

# nlp tasks
sentence_similarity = 'sentence-similarity'
@@ -124,6 +126,7 @@ class Preprocessors(object):
# cv preprocessor
load_image = 'load-image'
image_color_enhance_preprocessor = 'image-color-enhance-preprocessor'
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'

# nlp preprocessor
sen_sim_tokenizer = 'sen-sim-tokenizer'
@@ -157,6 +160,8 @@ class Metrics(object):
# accuracy
accuracy = 'accuracy'

# metric for image instance segmentation task
image_ins_seg_coco_metric = 'image-ins-seg-coco-metric'
# metrics for sequence classification task
seq_cls_metric = 'seq_cls_metric'
# metrics for token-classification task


+2 -0 modelscope/metrics/__init__.py

@@ -1,5 +1,7 @@
from .base import Metric
from .builder import METRICS, build_metric, task_default_metrics
from .image_color_enhance_metric import ImageColorEnhanceMetric
from .image_instance_segmentation_metric import \
ImageInstanceSegmentationCOCOMetric
from .sequence_classification_metric import SequenceClassificationMetric
from .text_generation_metric import TextGenerationMetric

+1 -0 modelscope/metrics/builder.py

@@ -18,6 +18,7 @@ class MetricKeys(object):


task_default_metrics = {
Tasks.image_segmentation: [Metrics.image_ins_seg_coco_metric],
Tasks.sentence_similarity: [Metrics.seq_cls_metric],
Tasks.sentiment_classification: [Metrics.seq_cls_metric],
Tasks.text_generation: [Metrics.text_gen_metric],


+312 -0 modelscope/metrics/image_instance_segmentation_metric.py

@@ -0,0 +1,312 @@
import os.path as osp
import tempfile
from collections import OrderedDict
from typing import Any, Dict

import numpy as np
import pycocotools.mask as mask_util
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from modelscope.fileio import dump, load
from modelscope.metainfo import Metrics
from modelscope.metrics import METRICS, Metric
from modelscope.utils.registry import default_group


@METRICS.register_module(
group_key=default_group, module_name=Metrics.image_ins_seg_coco_metric)
class ImageInstanceSegmentationCOCOMetric(Metric):
"""The metric computation class for COCO-style image instance segmentation.
"""

def __init__(self):
self.ann_file = None
self.classes = None
self.metrics = ['bbox', 'segm']
self.proposal_nums = (100, 300, 1000)
self.iou_thrs = np.linspace(
.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
self.results = []

def add(self, outputs: Dict[str, Any], inputs: Dict[str, Any]):
result = outputs['eval_result']
# encode mask results
if isinstance(result[0], tuple):
result = [(bbox_results, encode_mask_results(mask_results))
for bbox_results, mask_results in result]
self.results.extend(result)
if self.ann_file is None:
self.ann_file = outputs['img_metas'][0]['ann_file']
self.classes = outputs['img_metas'][0]['classes']

def evaluate(self):
cocoGt = COCO(self.ann_file)
self.cat_ids = cocoGt.getCatIds(catNms=self.classes)
self.img_ids = cocoGt.getImgIds()

result_files, tmp_dir = self.format_results(self.results, self.img_ids)

eval_results = OrderedDict()
for metric in self.metrics:
iou_type = metric
if metric not in result_files:
raise KeyError(f'{metric} is not in results')
try:
predictions = load(result_files[metric])
if iou_type == 'segm':
# Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa
# When evaluating mask AP, if the results contain bbox,
# cocoapi will use the box area instead of the mask area
# for calculating the instance area. Though the overall AP
# is not affected, this leads to different
# small/medium/large mask AP results.
for x in predictions:
x.pop('bbox')
cocoDt = cocoGt.loadRes(predictions)
except IndexError:
print('The testing results of the whole dataset are empty.')
break

cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
cocoEval.params.catIds = self.cat_ids
cocoEval.params.imgIds = self.img_ids
cocoEval.params.maxDets = list(self.proposal_nums)
cocoEval.params.iouThrs = self.iou_thrs
# mapping of cocoEval.stats
coco_metric_names = {
'mAP': 0,
'mAP_50': 1,
'mAP_75': 2,
'mAP_s': 3,
'mAP_m': 4,
'mAP_l': 5,
'AR@100': 6,
'AR@300': 7,
'AR@1000': 8,
'AR_s@1000': 9,
'AR_m@1000': 10,
'AR_l@1000': 11
}

cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()

metric_items = [
'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
]

for metric_item in metric_items:
key = f'{metric}_{metric_item}'
val = float(
f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}')
eval_results[key] = val
ap = cocoEval.stats[:6]
eval_results[f'{metric}_mAP_copypaste'] = (
f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
f'{ap[4]:.3f} {ap[5]:.3f}')
if tmp_dir is not None:
tmp_dir.cleanup()
return eval_results

def format_results(self, results, img_ids, jsonfile_prefix=None, **kwargs):
"""Format the results to json (standard format for COCO evaluation).

Args:
results (list[tuple | numpy.ndarray]): Testing results of the
dataset.
img_ids (list[int]): Image ids of the evaluated dataset.
jsonfile_prefix (str | None): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.

Returns:
tuple: (result_files, tmp_dir), result_files is a dict containing \
the json filepaths, tmp_dir is the temporary directory created \
for saving json files when jsonfile_prefix is not specified.
"""
assert isinstance(results, list), 'results must be a list'
assert len(results) == len(img_ids), (
'The length of results is not equal to the dataset len: {} != {}'.
format(len(results), len(img_ids)))

if jsonfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
jsonfile_prefix = osp.join(tmp_dir.name, 'results')
else:
tmp_dir = None
result_files = self.results2json(results, jsonfile_prefix)
return result_files, tmp_dir

def xyxy2xywh(self, bbox):
"""Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
evaluation.

Args:
bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
``xyxy`` order.

Returns:
list[float]: The converted bounding boxes, in ``xywh`` order.
"""

_bbox = bbox.tolist()
return [
_bbox[0],
_bbox[1],
_bbox[2] - _bbox[0],
_bbox[3] - _bbox[1],
]
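
A quick worked example of this conversion: a box given as corners (x1, y1, x2, y2) = (10, 20, 50, 80) becomes (x, y, w, h) = (10, 20, 40, 60), which is the layout COCO's evaluation API expects. A minimal check, assuming plain numpy:

import numpy as np

bbox_xyxy = np.array([10., 20., 50., 80.])      # [x1, y1, x2, y2]
x1, y1, x2, y2 = bbox_xyxy.tolist()
bbox_xywh = [x1, y1, x2 - x1, y2 - y1]          # [x, y, w, h], as returned by xyxy2xywh
assert bbox_xywh == [10.0, 20.0, 40.0, 60.0]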

def _proposal2json(self, results):
"""Convert proposal results to COCO json style."""
json_results = []
for idx in range(len(self.img_ids)):
img_id = self.img_ids[idx]
bboxes = results[idx]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self.xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = 1
json_results.append(data)
return json_results

def _det2json(self, results):
"""Convert detection results to COCO json style."""
json_results = []
for idx in range(len(self.img_ids)):
img_id = self.img_ids[idx]
result = results[idx]
for label in range(len(result)):
# Here we skip invalid predicted labels, as the model uses a fixed num_classes of 80 (COCO)
# (assuming the input dataset has no more than 80 classes).
# In practice it is recommended to manually set `num_classes=${your test dataset class num}`
# in configuration.json.
if label >= len(self.classes):
break
bboxes = result[label]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self.xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = self.cat_ids[label]
json_results.append(data)
return json_results

def _segm2json(self, results):
"""Convert instance segmentation results to COCO json style."""
bbox_json_results = []
segm_json_results = []
for idx in range(len(self.img_ids)):
img_id = self.img_ids[idx]
det, seg = results[idx]
for label in range(len(det)):
# Here we skip invalid predicted labels, as the model uses a fixed num_classes of 80 (COCO)
# (assuming the input dataset has no more than 80 classes).
# In practice it is recommended to manually set `num_classes=${your test dataset class num}`
# in configuration.json.
if label >= len(self.classes):
break
# bbox results
bboxes = det[label]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self.xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = self.cat_ids[label]
bbox_json_results.append(data)

# segm results
# some detectors use different scores for bbox and mask
if isinstance(seg, tuple):
segms = seg[0][label]
mask_score = seg[1][label]
else:
segms = seg[label]
mask_score = [bbox[4] for bbox in bboxes]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self.xyxy2xywh(bboxes[i])
data['score'] = float(mask_score[i])
data['category_id'] = self.cat_ids[label]
if isinstance(segms[i]['counts'], bytes):
segms[i]['counts'] = segms[i]['counts'].decode()
data['segmentation'] = segms[i]
segm_json_results.append(data)
return bbox_json_results, segm_json_results

def results2json(self, results, outfile_prefix):
"""Dump the detection results to a COCO style json file.

There are 3 types of results: proposals, bbox predictions, mask
predictions, and they have different data types. This method will
automatically recognize the type, and dump them to json files.

Args:
results (list[list | tuple | ndarray]): Testing results of the
dataset.
outfile_prefix (str): The filename prefix of the json files. If the
prefix is "somepath/xxx", the json files will be named
"somepath/xxx.bbox.json", "somepath/xxx.segm.json",
"somepath/xxx.proposal.json".

Returns:
dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \
values are corresponding filenames.
"""
result_files = dict()
if isinstance(results[0], list):
json_results = self._det2json(results)
result_files['bbox'] = f'{outfile_prefix}.bbox.json'
result_files['proposal'] = f'{outfile_prefix}.bbox.json'
dump(json_results, result_files['bbox'])
elif isinstance(results[0], tuple):
json_results = self._segm2json(results)
result_files['bbox'] = f'{outfile_prefix}.bbox.json'
result_files['proposal'] = f'{outfile_prefix}.bbox.json'
result_files['segm'] = f'{outfile_prefix}.segm.json'
dump(json_results[0], result_files['bbox'])
dump(json_results[1], result_files['segm'])
elif isinstance(results[0], np.ndarray):
json_results = self._proposal2json(results)
result_files['proposal'] = f'{outfile_prefix}.proposal.json'
dump(json_results, result_files['proposal'])
else:
raise TypeError('invalid type of results')
return result_files


def encode_mask_results(mask_results):
"""Encode bitmap mask to RLE code.

Args:
mask_results (list | tuple[list]): bitmap mask results.
In mask scoring rcnn, mask_results is a tuple of (segm_results,
segm_cls_score).

Returns:
list | tuple: RLE encoded mask.
"""
if isinstance(mask_results, tuple): # mask scoring
cls_segms, cls_mask_scores = mask_results
else:
cls_segms = mask_results
num_classes = len(cls_segms)
encoded_mask_results = [[] for _ in range(num_classes)]
for i in range(len(cls_segms)):
for cls_segm in cls_segms[i]:
encoded_mask_results[i].append(
mask_util.encode(
np.array(
cls_segm[:, :, np.newaxis], order='F',
dtype='uint8'))[0]) # encoded with RLE
if isinstance(mask_results, tuple):
return encoded_mask_results, cls_mask_scores
else:
return encoded_mask_results
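
A rough sketch of how this metric is expected to be driven during evaluation (the `model` and `val_loader` objects below are assumptions standing in for the trainer's evaluation loop): `add` accumulates per-image (bbox_results, mask_results) pairs produced by `forward_test`, and `evaluate` replays them through pycocotools.

from modelscope.metrics import ImageInstanceSegmentationCOCOMetric

metric = ImageInstanceSegmentationCOCOMetric()
for batch in val_loader:                  # hypothetical dataloader over a COCO-style val set
    outputs = model(**batch)              # dict(eval_result=..., img_metas=...) from forward_test
    metric.add(outputs, batch)
print(metric.evaluate())                  # {'bbox_mAP': ..., 'segm_mAP': ..., ...}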

+2 -0 modelscope/models/cv/image_instance_segmentation/__init__.py

@@ -0,0 +1,2 @@
from .cascade_mask_rcnn_swin import CascadeMaskRCNNSwin
from .model import CascadeMaskRCNNSwinModel

+1 -0 modelscope/models/cv/image_instance_segmentation/backbones/__init__.py

@@ -0,0 +1 @@
from .swin_transformer import SwinTransformer

+694 -0 modelscope/models/cv/image_instance_segmentation/backbones/swin_transformer.py

@@ -0,0 +1,694 @@
# Modified from: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
""" Multilayer perceptron."""

def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)

def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x


def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size

Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
C)
windows = x.permute(0, 1, 3, 2, 4,
5).contiguous().view(-1, window_size, window_size, C)
return windows


def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image

Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
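
A shape sanity check for the two helpers above, assuming torch is available and both functions are in scope: partitioning a (B, H, W, C) feature map into 7x7 windows and then reversing the operation reproduces the input whenever H and W are multiples of the window size.

import torch

x = torch.randn(2, 14, 28, 96)                          # (B, H, W, C), H and W divisible by 7
win = window_partition(x, window_size=7)                # 2 * (14 // 7) * (28 // 7) = 16 windows
assert win.shape == (16, 7, 7, 96)
x_back = window_reverse(win, window_size=7, H=14, W=28)
assert torch.equal(x, x_back)                           # exact round trip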


class WindowAttention(nn.Module):
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both shifted and non-shifted windows.

Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""

def __init__(self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.,
proj_drop=0.):

super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5

# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads)) # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :,
None] - coords_flatten[:,
None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :,
0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer('relative_position_index',
relative_position_index)

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)

def forward(self, x, mask=None):
""" Forward function.

Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)

q = q * self.scale
attn = (q @ k.transpose(-2, -1))

relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)

if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)

attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x


class SwinTransformerBlock(nn.Module):
""" Swin Transformer Block.

Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""

def __init__(self,
dim,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, 'shift_size must be in range [0, window_size)'

self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)

self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)

self.H = None
self.W = None

def forward(self, x, mask_matrix):
""" Forward function.

Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, 'input feature has wrong size'

shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)

# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape

# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None

# partition windows
x_windows = window_partition(
shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size,
C) # nW*B, window_size*window_size, C

# W-MSA/SW-MSA
attn_windows = self.attn(
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C

# merge windows
attn_windows = attn_windows.view(-1, self.window_size,
self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp,
Wp) # B H' W' C

# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x,
shifts=(self.shift_size, self.shift_size),
dims=(1, 2))
else:
x = shifted_x

if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()

x = x.view(B, H * W, C)

# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))

return x


class PatchMerging(nn.Module):
""" Patch Merging Layer

Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""

def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)

def forward(self, x, H, W):
""" Forward function.

Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'

x = x.view(B, H, W, C)

# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C

x = self.norm(x)
x = self.reduction(x)

return x


class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.

Args:
dim (int): Number of feature channels
depth (int): Depth (number of blocks) of this stage.
num_heads (int): Number of attention heads.
window_size (int): Local window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""

def __init__(self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint

# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i]
if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer) for i in range(depth)
])

# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None

def forward(self, x, H, W):
""" Forward function.

Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""

# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1

mask_windows = window_partition(
img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1,
self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(
attn_mask == 0, float(0.0))

for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W


class PatchEmbed(nn.Module):
""" Image to Patch Embedding

Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""

def __init__(self,
patch_size=4,
in_chans=3,
embed_dim=96,
norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size

self.in_chans = in_chans
self.embed_dim = embed_dim

self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None

def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x,
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

return x


class SwinTransformer(nn.Module):
""" Swin Transformer backbone.
A PyTorch implementation of `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030

Inspiration from
https://github.com/SwinTransformer/Swin-Transformer-Object-Detection

Args:
pretrain_img_size (int): Input image size for training the pretrained model,
used in absolute position embedding. Default 224.
patch_size (int | tuple(int)): Patch size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
num_heads (tuple[int]): Number of attention heads of each stage.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""

def __init__(self,
pretrain_img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
use_checkpoint=False):
super().__init__()

self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages

# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)

# absolute position embedding
if self.ape:
pretrain_img_size = to_2tuple(pretrain_img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
pretrain_img_size[0] // patch_size[0],
pretrain_img_size[1] // patch_size[1]
]

self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, embed_dim, patches_resolution[0],
patches_resolution[1]))
trunc_normal_(self.absolute_pos_embed, std=.02)

self.pos_drop = nn.Dropout(p=drop_rate)

# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule

# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2**i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if
(i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)

num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
self.num_features = num_features

# add a norm layer for each output
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)

self._freeze_stages()

def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False

if self.frozen_stages >= 1 and self.ape:
self.absolute_pos_embed.requires_grad = False

if self.frozen_stages >= 2:
self.pos_drop.eval()
for i in range(0, self.frozen_stages - 1):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False

def init_weights(self):
"""Initialize the weights in backbone."""

def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

self.apply(_init_weights)

def forward(self, x):
"""Forward function."""
x = self.patch_embed(x)

Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(
self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
x = (x + absolute_pos_embed).flatten(2).transpose(1,
2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)

outs = []
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)

out = x_out.view(-1, H, W,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)

return tuple(outs)

def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
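
A small smoke test of the backbone, assuming torch and timm are installed: with the default Swin-T style settings, a 224x224 input yields four feature maps at strides 4, 8, 16 and 32, matching what the FPN neck consumes.

import torch

backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2],
                           num_heads=[3, 6, 12, 24], out_indices=(0, 1, 2, 3))
backbone.init_weights()
feats = backbone(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)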

+266 -0 modelscope/models/cv/image_instance_segmentation/cascade_mask_rcnn_swin.py

@@ -0,0 +1,266 @@
import os
from collections import OrderedDict

import torch
import torch.distributed as dist
import torch.nn as nn

from modelscope.models.cv.image_instance_segmentation.backbones import \
SwinTransformer
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger

logger = get_logger()


def build_backbone(cfg):
assert isinstance(cfg, dict)
cfg = cfg.copy()
type = cfg.pop('type')
if type == 'SwinTransformer':
return SwinTransformer(**cfg)
else:
raise ValueError(f'backbone \'{type}\' is not supported.')


def build_neck(cfg):
assert isinstance(cfg, dict)
cfg = cfg.copy()
type = cfg.pop('type')
if type == 'FPN':
from mmdet.models import FPN
return FPN(**cfg)
else:
raise ValueError(f'neck \'{type}\' is not supported.')


def build_rpn_head(cfg):
assert isinstance(cfg, dict)
cfg = cfg.copy()
type = cfg.pop('type')
if type == 'RPNHead':
from mmdet.models import RPNHead
return RPNHead(**cfg)
else:
raise ValueError(f'rpn head \'{type}\' is not supported.')


def build_roi_head(cfg):
assert isinstance(cfg, dict)
cfg = cfg.copy()
type = cfg.pop('type')
if type == 'CascadeRoIHead':
from mmdet.models import CascadeRoIHead
return CascadeRoIHead(**cfg)
else:
raise ValueError(f'roi head \'{type}\' is not supported.')


class CascadeMaskRCNNSwin(nn.Module):

def __init__(self,
backbone,
neck,
rpn_head,
roi_head,
pretrained=None,
**kwargs):
"""
Args:
backbone (dict): backbone config.
neck (dict): neck config.
rpn_head (dict): rpn_head config.
roi_head (dict): roi_head config.
pretrained (bool): whether to use pretrained model
"""
super(CascadeMaskRCNNSwin, self).__init__()

self.backbone = build_backbone(backbone)
self.neck = build_neck(neck)
self.rpn_head = build_rpn_head(rpn_head)
self.roi_head = build_roi_head(roi_head)

self.classes = kwargs.pop('classes', None)

if pretrained:
assert 'model_dir' in kwargs, 'pretrained model dir is missing.'
model_path = os.path.join(kwargs['model_dir'],
ModelFile.TORCH_MODEL_FILE)
logger.info(f'loading model from {model_path}')
weight = torch.load(model_path)['state_dict']
tgt_weight = self.state_dict()
for name in list(weight.keys()):
if name in tgt_weight:
load_size = weight[name].size()
tgt_size = tgt_weight[name].size()
mis_match = False
if len(load_size) != len(tgt_size):
mis_match = True
else:
for n1, n2 in zip(load_size, tgt_size):
if n1 != n2:
mis_match = True
break
if mis_match:
logger.info(f'size mismatch for {name}, skip loading.')
del weight[name]

self.load_state_dict(weight, strict=False)
logger.info('load model done')

from mmcv.parallel import DataContainer, scatter

self.data_container = DataContainer
self.scatter = scatter

def extract_feat(self, img):
x = self.backbone(img)
x = self.neck(x)
return x

def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None,
**kwargs):
"""
Args:
img (Tensor): of shape (N, C, H, W) encoding input images.
Typically these should be mean centered and std scaled.

img_metas (list[dict]): list of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmdet/datasets/pipelines/formatting.py:Collect`.

gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.

gt_labels (list[Tensor]): class indices corresponding to each box

gt_bboxes_ignore (None | list[Tensor]): specify which bounding
boxes can be ignored when computing the loss.

gt_masks (None | Tensor) : true segmentation masks for each box
used if the architecture supports a segmentation task.

proposals : override rpn proposals with custom proposals. Use when
`with_rpn` is False.

Returns:
dict[str, Tensor]: a dictionary of loss components
"""
x = self.extract_feat(img)

losses = dict()

# RPN forward and loss
proposal_cfg = self.rpn_head.train_cfg.get('rpn_proposal',
self.rpn_head.test_cfg)
rpn_losses, proposal_list = self.rpn_head.forward_train(
x,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_ignore=gt_bboxes_ignore,
proposal_cfg=proposal_cfg,
**kwargs)
losses.update(rpn_losses)

roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
gt_bboxes, gt_labels,
gt_bboxes_ignore, gt_masks,
**kwargs)
losses.update(roi_losses)

return losses

def forward_test(self, img, img_metas, proposals=None, rescale=True):

x = self.extract_feat(img)
if proposals is None:
proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
else:
proposal_list = proposals

result = self.roi_head.simple_test(
x, proposal_list, img_metas, rescale=rescale)
return dict(eval_result=result, img_metas=img_metas)

def forward(self, img, img_metas, **kwargs):

# currently only support cpu or single gpu
if isinstance(img, self.data_container):
img = img.data[0]
if isinstance(img_metas, self.data_container):
img_metas = img_metas.data[0]
for k, w in kwargs.items():
if isinstance(w, self.data_container):
w = w.data[0]
kwargs[k] = w

if next(self.parameters()).is_cuda:
device = next(self.parameters()).device
img = self.scatter(img, [device])[0]
img_metas = self.scatter(img_metas, [device])[0]
for k, w in kwargs.items():
kwargs[k] = self.scatter(w, [device])[0]

if self.training:
losses = self.forward_train(img, img_metas, **kwargs)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(img_metas))
return outputs
else:
return self.forward_test(img, img_metas, **kwargs)

def _parse_losses(self, losses):

log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')

loss = sum(_value for _key, _value in log_vars.items()
if 'loss' in _key)

log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()

return loss, log_vars

def train_step(self, data, optimizer):

losses = self(**data)
loss, log_vars = self._parse_losses(losses)

outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))

return outputs

def val_step(self, data, optimizer=None):

losses = self(**data)
loss, log_vars = self._parse_losses(losses)

outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))

return outputs
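
The pretrained-weight loading in __init__ silently drops any checkpoint tensor whose shape disagrees with the freshly built model (e.g. a detection head trained with a different number of classes). A standalone illustration of that filtering logic with toy state dicts (the element-wise size comparison in the class is equivalent to comparing the full sizes, as done here):

import torch

ckpt = {'head.weight': torch.zeros(81, 256), 'backbone.w': torch.zeros(96, 3, 4, 4)}
model_sd = {'head.weight': torch.zeros(5, 256), 'backbone.w': torch.zeros(96, 3, 4, 4)}

for name in list(ckpt.keys()):
    if name in model_sd and ckpt[name].size() != model_sd[name].size():
        print(f'size mismatch for {name}, skip loading.')
        del ckpt[name]

assert list(ckpt.keys()) == ['backbone.w']   # the mismatched head weight was dropped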

+2 -0 modelscope/models/cv/image_instance_segmentation/datasets/__init__.py

@@ -0,0 +1,2 @@
from .dataset import ImageInstanceSegmentationCocoDataset
from .transforms import build_preprocess_transform

+332 -0 modelscope/models/cv/image_instance_segmentation/datasets/dataset.py

@@ -0,0 +1,332 @@
import os.path as osp

import numpy as np
from pycocotools.coco import COCO
from torch.utils.data import Dataset


class ImageInstanceSegmentationCocoDataset(Dataset):
"""Coco-style dataset for image instance segmentation.

Args:
ann_file (str): Annotation file path.
classes (Sequence[str], optional): Specify classes to load.
If is None, ``cls.CLASSES`` will be used. Default: None.
data_root (str, optional): Data root for ``ann_file``,
``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
test_mode (bool, optional): If set True, annotation will not be loaded.
filter_empty_gt (bool, optional): If set true, images without bounding
boxes of the dataset's classes will be filtered out. This option
only works when `test_mode=False`, i.e., we never filter images
during tests.
"""

CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')

def __init__(self,
ann_file,
classes=None,
data_root=None,
img_prefix='',
seg_prefix=None,
test_mode=False,
filter_empty_gt=True):
self.ann_file = ann_file
self.data_root = data_root
self.img_prefix = img_prefix
self.seg_prefix = seg_prefix
self.test_mode = test_mode
self.filter_empty_gt = filter_empty_gt
self.CLASSES = self.get_classes(classes)

# join paths if data_root is specified
if self.data_root is not None:
if not osp.isabs(self.ann_file):
self.ann_file = osp.join(self.data_root, self.ann_file)
if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
self.img_prefix = osp.join(self.data_root, self.img_prefix)
if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
self.seg_prefix = osp.join(self.data_root, self.seg_prefix)

# load annotations
self.data_infos = self.load_annotations(self.ann_file)

# filter out images that are too small or contain no annotations
if not test_mode:
valid_inds = self._filter_imgs()
self.data_infos = [self.data_infos[i] for i in valid_inds]
# set group flag for the sampler
self._set_group_flag()

self.preprocessor = None

def __len__(self):
"""Total number of samples of data."""
return len(self.data_infos)

def load_annotations(self, ann_file):
"""Load annotation from COCO style annotation file.

Args:
ann_file (str): Path of annotation file.

Returns:
list[dict]: Annotation info from COCO api.
"""

self.coco = COCO(ann_file)
# The order of returned `cat_ids` will not
# change with the order of the CLASSES
self.cat_ids = self.coco.getCatIds(catNms=self.CLASSES)

self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.img_ids = self.coco.getImgIds()
data_infos = []
total_ann_ids = []
for i in self.img_ids:
info = self.coco.loadImgs([i])[0]
info['filename'] = info['file_name']
info['ann_file'] = ann_file
info['classes'] = self.CLASSES
data_infos.append(info)
ann_ids = self.coco.getAnnIds(imgIds=[i])
total_ann_ids.extend(ann_ids)
assert len(set(total_ann_ids)) == len(
total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
return data_infos

def get_ann_info(self, idx):
"""Get COCO annotation by index.

Args:
idx (int): Index of data.

Returns:
dict: Annotation info of specified index.
"""

img_id = self.data_infos[idx]['id']
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
ann_info = self.coco.loadAnns(ann_ids)
return self._parse_ann_info(self.data_infos[idx], ann_info)

def get_cat_ids(self, idx):
"""Get COCO category ids by index.

Args:
idx (int): Index of data.

Returns:
list[int]: All categories in the image of specified index.
"""

img_id = self.data_infos[idx]['id']
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
ann_info = self.coco.loadAnns(ann_ids)
return [ann['category_id'] for ann in ann_info]

def pre_pipeline(self, results):
"""Prepare results dict for pipeline."""
results['img_prefix'] = self.img_prefix
results['seg_prefix'] = self.seg_prefix
results['bbox_fields'] = []
results['mask_fields'] = []
results['seg_fields'] = []

def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths."""
valid_inds = []
# obtain images that contain annotation
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
# obtain images that contain annotations of the required categories
ids_in_cat = set()
for i, class_id in enumerate(self.cat_ids):
ids_in_cat |= set(self.coco.catToImgs[class_id])
# merge the image id sets of the two conditions and use the merged set
# to filter out images if self.filter_empty_gt=True
ids_in_cat &= ids_with_ann

valid_img_ids = []
for i, img_info in enumerate(self.data_infos):
img_id = self.img_ids[i]
if self.filter_empty_gt and img_id not in ids_in_cat:
continue
if min(img_info['width'], img_info['height']) >= min_size:
valid_inds.append(i)
valid_img_ids.append(img_id)
self.img_ids = valid_img_ids
return valid_inds

def _parse_ann_info(self, img_info, ann_info):
"""Parse bbox and mask annotation.

Args:
ann_info (list[dict]): Annotation info of an image.

Returns:
dict: A dict containing the following keys: bboxes, bboxes_ignore,\
labels, masks, seg_map. "masks" are raw annotations and not \
decoded into binary masks.
"""
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
gt_masks_ann = []
for i, ann in enumerate(ann_info):
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if ann['area'] <= 0 or w < 1 or h < 1:
continue
if ann['category_id'] not in self.cat_ids:
continue
bbox = [x1, y1, x1 + w, y1 + h]
if ann.get('iscrowd', False):
gt_bboxes_ignore.append(bbox)
else:
gt_bboxes.append(bbox)
gt_labels.append(self.cat2label[ann['category_id']])
gt_masks_ann.append(ann.get('segmentation', None))

if gt_bboxes:
gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
gt_labels = np.array(gt_labels, dtype=np.int64)
else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
gt_labels = np.array([], dtype=np.int64)

if gt_bboxes_ignore:
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
else:
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

seg_map = img_info['filename'].replace('jpg', 'png')

ann = dict(
bboxes=gt_bboxes,
labels=gt_labels,
bboxes_ignore=gt_bboxes_ignore,
masks=gt_masks_ann,
seg_map=seg_map)

return ann

def _set_group_flag(self):
"""Set flag according to image aspect ratio.

Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
self.flag = np.zeros(len(self), dtype=np.uint8)
for i in range(len(self)):
img_info = self.data_infos[i]
if img_info['width'] / img_info['height'] > 1:
self.flag[i] = 1

def _rand_another(self, idx):
"""Get another random index from the same group as the given index."""
pool = np.where(self.flag == self.flag[idx])[0]
return np.random.choice(pool)

def __getitem__(self, idx):
"""Get training/test data after pipeline.

Args:
idx (int): Index of data.

Returns:
dict: Training/test data (with annotation if `test_mode` is set \
False).
"""

if self.test_mode:
return self.prepare_test_img(idx)
while True:
data = self.prepare_train_img(idx)
if data is None:
idx = self._rand_another(idx)
continue
return data

def prepare_train_img(self, idx):
"""Get training data and annotations after pipeline.

Args:
idx (int): Index of data.

Returns:
dict: Training data and annotation after pipeline with new keys \
introduced by pipeline.
"""

img_info = self.data_infos[idx]
ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info)
self.pre_pipeline(results)
if self.preprocessor is None:
return results
self.preprocessor.train()
return self.preprocessor(results)

def prepare_test_img(self, idx):
"""Get testing data after pipeline.

Args:
idx (int): Index of data.

Returns:
dict: Testing data after pipeline with new keys introduced by \
pipeline.
"""

img_info = self.data_infos[idx]
results = dict(img_info=img_info)
self.pre_pipeline(results)
if self.preprocessor is None:
return results
self.preprocessor.eval()
results = self.preprocessor(results)
return results

@classmethod
def get_classes(cls, classes=None):
"""Get class names of current dataset.

Args:
classes (Sequence[str] | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is
a tuple or list, override the CLASSES defined by the dataset.

Returns:
tuple[str] or list[str]: Names of categories of the dataset.
"""
if classes is None:
return cls.CLASSES

if isinstance(classes, (tuple, list)):
class_names = classes
else:
raise ValueError(f'Unsupported type {type(classes)} of classes.')

return class_names

def to_torch_dataset(self, preprocessors=None):
self.preprocessor = preprocessors
return self
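
A sketch of constructing the dataset for fine-tuning on a custom COCO-style split; the paths and the two-class tuple below are placeholders, and in this commit the dataset is normally built by the trainer from configuration.json:

from modelscope.models.cv.image_instance_segmentation.datasets import \
    ImageInstanceSegmentationCocoDataset

train_set = ImageInstanceSegmentationCocoDataset(
    ann_file='annotations/instances_train.json',   # placeholder path
    classes=('cat', 'dog'),                        # overrides the 80 default COCO classes
    data_root='/path/to/dataset',                  # placeholder path
    img_prefix='images/train',
    filter_empty_gt=True)
print(len(train_set), train_set.CLASSES)
# a preprocessor is attached later via train_set.to_torch_dataset(preprocessor)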

+109 -0 modelscope/models/cv/image_instance_segmentation/datasets/transforms.py

@@ -0,0 +1,109 @@
import os.path as osp

import numpy as np

from modelscope.fileio import File


def build_preprocess_transform(cfg):
assert isinstance(cfg, dict)
cfg = cfg.copy()
type = cfg.pop('type')
if type == 'LoadImageFromFile':
return LoadImageFromFile(**cfg)
elif type == 'LoadAnnotations':
from mmdet.datasets.pipelines import LoadAnnotations
return LoadAnnotations(**cfg)
elif type == 'Resize':
if 'img_scale' in cfg:
if isinstance(cfg.img_scale[0], list):
elems = []
for elem in cfg.img_scale:
elems.append(tuple(elem))
cfg.img_scale = elems
else:
cfg.img_scale = tuple(cfg.img_scale)
from mmdet.datasets.pipelines import Resize
return Resize(**cfg)
elif type == 'RandomFlip':
from mmdet.datasets.pipelines import RandomFlip
return RandomFlip(**cfg)
elif type == 'Normalize':
from mmdet.datasets.pipelines import Normalize
return Normalize(**cfg)
elif type == 'Pad':
from mmdet.datasets.pipelines import Pad
return Pad(**cfg)
elif type == 'DefaultFormatBundle':
from mmdet.datasets.pipelines import DefaultFormatBundle
return DefaultFormatBundle(**cfg)
elif type == 'ImageToTensor':
from mmdet.datasets.pipelines import ImageToTensor
return ImageToTensor(**cfg)
elif type == 'Collect':
from mmdet.datasets.pipelines import Collect
return Collect(**cfg)
else:
raise ValueError(f'preprocess transform \'{type}\' is not supported.')


class LoadImageFromFile:
"""Load an image from file.

Required keys are "img_prefix" and "img_info" (a dict that must contain the
key "filename"). Added or updated keys are "filename", "img", "img_shape",
"ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
"scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).

Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
"""

def __init__(self, to_float32=False, mode='rgb'):
self.to_float32 = to_float32
self.mode = mode

from mmcv import imfrombytes

self.imfrombytes = imfrombytes

def __call__(self, results):
"""Call functions to load image and get image meta information.

Args:
results (dict): Result dict from :obj:`ImageInstanceSegmentationDataset`.

Returns:
dict: The dict contains loaded image and meta information.
"""

if results['img_prefix'] is not None:
filename = osp.join(results['img_prefix'],
results['img_info']['filename'])
else:
filename = results['img_info']['filename']

img_bytes = File.read(filename)

img = self.imfrombytes(img_bytes, 'color', 'bgr', backend='pillow')

if self.to_float32:
img = img.astype(np.float32)

results['filename'] = filename
results['ori_filename'] = results['img_info']['filename']
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
results['img_fields'] = ['img']
results['ann_file'] = results['img_info']['ann_file']
results['classes'] = results['img_info']['classes']
return results

def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'to_float32={self.to_float32}, '
f"mode='{self.mode}'")
return repr_str
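
An illustrative call showing what build_preprocess_transform and LoadImageFromFile produce, assuming mmcv is installed and the snippet runs from a repository checkout with the LFS test image fetched:

from modelscope.models.cv.image_instance_segmentation.datasets import \
    build_preprocess_transform

loader = build_preprocess_transform(dict(type='LoadImageFromFile', to_float32=True))
results = loader(
    dict(
        img_prefix=None,
        img_info=dict(
            filename='data/test/images/image_instance_segmentation.jpg',
            ann_file=None,
            classes=None)))
print(results['img'].shape, results['img'].dtype)   # (H, W, 3) float32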

+49 -0 modelscope/models/cv/image_instance_segmentation/model.py

@@ -0,0 +1,49 @@
import os
from typing import Any, Dict

import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.image_instance_segmentation import \
CascadeMaskRCNNSwin
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks


@MODELS.register_module(
Tasks.image_segmentation, module_name=Models.cascade_mask_rcnn_swin)
class CascadeMaskRCNNSwinModel(TorchModel):

def __init__(self, model_dir=None, *args, **kwargs):
"""
Args:
model_dir (str): model directory.

"""
super(CascadeMaskRCNNSwinModel, self).__init__(
model_dir=model_dir, *args, **kwargs)

if 'backbone' not in kwargs:
config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
cfg = Config.from_file(config_path)
model_cfg = cfg.model
kwargs.update(model_cfg)

self.model = CascadeMaskRCNNSwin(model_dir=model_dir, **kwargs)

self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
output = self.model(**input)
return output

def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:

return input

def compute_loss(self, outputs: Dict[str, Any], labels):
pass
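
Direct construction of the wrapper from a local model directory is expected to look like the sketch below; the directory path is a placeholder for a downloaded snapshot of 'damo/cv_swin-b_image-instance-segmentation_coco' containing configuration.json and the torch checkpoint:

from modelscope.models.cv.image_instance_segmentation import CascadeMaskRCNNSwinModel

model = CascadeMaskRCNNSwinModel(model_dir='/path/to/model_dir')   # placeholder path
model.model.eval()   # the underlying CascadeMaskRCNNSwin nn.Module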

+203 -0 modelscope/models/cv/image_instance_segmentation/postprocess_utils.py

@@ -0,0 +1,203 @@
import itertools

import cv2
import numpy as np
import pycocotools.mask as maskUtils
import torch

from modelscope.outputs import OutputKeys


def get_seg_bboxes(bboxes, labels, segms=None, class_names=None, score_thr=0.):
assert bboxes.ndim == 2, \
f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
assert labels.ndim == 1, \
f' labels ndim should be 1, but its ndim is {labels.ndim}.'
assert bboxes.shape[0] == labels.shape[0], \
'bboxes.shape[0] and labels.shape[0] should have the same length.'
assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \
f' bboxes.shape[1] should be 4 or 5, but it is {bboxes.shape[1]}.'

if score_thr > 0:
assert bboxes.shape[1] == 5
scores = bboxes[:, -1]
inds = scores > score_thr
bboxes = bboxes[inds, :]
labels = labels[inds]
if segms is not None:
segms = segms[inds, ...]

bboxes_names = []
for i, (bbox, label) in enumerate(zip(bboxes, labels)):
label_name = class_names[
label] if class_names is not None else f'class {label}'
bbox = [0 if b < 0 else b for b in list(bbox)]
bbox.append(label_name)
bbox.append(segms[i].astype(bool))
bboxes_names.append(bbox)

return bboxes_names


def get_img_seg_results(det_rawdata=None,
class_names=None,
score_thr=0.3,
is_decode=True):
'''
Get all boxes of one image.
score_thr: Classification probability threshold.
output format: [ [x1,y1,x2,y2, prob, cls_name, mask], [x1,y1,x2,y2, prob, cls_name, mask], ... ]
'''
assert det_rawdata is not None, 'det_rawdata should be not None.'
assert class_names is not None, 'class_names should be not None.'

if isinstance(det_rawdata, tuple):
bbox_result, segm_result = det_rawdata
if isinstance(segm_result, tuple):
segm_result = segm_result[0] # ms rcnn
else:
bbox_result, segm_result = det_rawdata, None
bboxes = np.vstack(bbox_result)
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(bbox_result)
]
labels = np.concatenate(labels)

segms = None
if segm_result is not None and len(labels) > 0: # non empty
segms = list(itertools.chain(*segm_result))
if is_decode:
segms = maskUtils.decode(segms)
segms = segms.transpose(2, 0, 1)
if isinstance(segms[0], torch.Tensor):
segms = torch.stack(segms, dim=0).detach().cpu().numpy()
else:
segms = np.stack(segms, axis=0)

bboxes_names = get_seg_bboxes(
bboxes,
labels,
segms=segms,
class_names=class_names,
score_thr=score_thr)

return bboxes_names


def get_img_ins_seg_result(img_seg_result=None,
class_names=None,
score_thr=0.3):
    assert img_seg_result is not None, 'img_seg_result should not be None.'
    assert class_names is not None, 'class_names should not be None.'

img_seg_result = get_img_seg_results(
det_rawdata=(img_seg_result[0], img_seg_result[1]),
class_names=class_names,
score_thr=score_thr,
is_decode=False)

results_dict = {
OutputKeys.BOXES: [],
OutputKeys.MASKS: [],
OutputKeys.LABELS: [],
OutputKeys.SCORES: []
}
    for seg_result in img_seg_result:
        box = {
            'x': int(seg_result[0]),
            'y': int(seg_result[1]),
            'w': int(seg_result[2] - seg_result[0]),
            'h': int(seg_result[3] - seg_result[1])
        }
        score = float(seg_result[4])
        category = seg_result[5]

        mask = np.array(seg_result[6], order='F', dtype='uint8')
        mask = mask.astype(float)

results_dict[OutputKeys.BOXES].append(box)
results_dict[OutputKeys.MASKS].append(mask)
results_dict[OutputKeys.SCORES].append(score)
results_dict[OutputKeys.LABELS].append(category)

return results_dict


def show_result(
img,
result,
out_file='result.jpg',
show_box=True,
show_label=True,
show_score=True,
alpha=0.5,
fontScale=0.5,
fontFace=cv2.FONT_HERSHEY_COMPLEX_SMALL,
thickness=1,
):

assert isinstance(img, (str, np.ndarray)), \
f'img must be str or np.ndarray, but got {type(img)}.'

if isinstance(img, str):
img = cv2.imread(img)
if len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

img = img.astype(np.float32)

labels = result[OutputKeys.LABELS]
scores = result[OutputKeys.SCORES]
boxes = result[OutputKeys.BOXES]
masks = result[OutputKeys.MASKS]

for label, score, box, mask in zip(labels, scores, boxes, masks):

random_color = np.array([
np.random.random() * 255.0,
np.random.random() * 255.0,
np.random.random() * 255.0
])

x1 = int(box['x'])
y1 = int(box['y'])
w = int(box['w'])
h = int(box['h'])
x2 = x1 + w
y2 = y1 + h

if show_box:
cv2.rectangle(
img, (x1, y1), (x2, y2), random_color, thickness=thickness)
if show_label or show_score:
if show_label and show_score:
text = '{}|{}'.format(label, round(float(score), 2))
elif show_label:
text = '{}'.format(label)
else:
text = '{}'.format(round(float(score), 2))

retval, baseLine = cv2.getTextSize(
text,
fontFace=fontFace,
fontScale=fontScale,
thickness=thickness)
cv2.rectangle(
img, (x1, y1 - retval[1] - baseLine), (x1 + retval[0], y1),
thickness=-1,
color=(0, 0, 0))
cv2.putText(
img,
text, (x1, y1 - baseLine),
fontScale=fontScale,
fontFace=fontFace,
thickness=thickness,
color=random_color)

idx = np.nonzero(mask)
img[idx[0], idx[1], :] *= 1.0 - alpha
img[idx[0], idx[1], :] += alpha * random_color

cv2.imwrite(out_file, img)
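
A minimal sketch of how these helpers chain together, assuming `model_output` is the dict returned by CascadeMaskRCNNSwinModel.forward for a single image (the pipeline below reads its 'eval_result' key the same way):

model_output = ...  # placeholder: forward() output containing 'eval_result'
raw = model_output['eval_result'][0]  # (bbox_result, segm_result) of the first image
seg = get_img_ins_seg_result(
    img_seg_result=raw,
    class_names=('Cat', 'Dog'),  # assumption: toy class names from the trainer test below
    score_thr=0.3)
show_result('data/test/images/image_instance_segmentation.jpg', seg,
            out_file='result.jpg', show_box=True, show_label=True, show_score=True)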

+ 1
- 0
modelscope/outputs.py View File

@@ -13,6 +13,7 @@ class OutputKeys(object):
POSES = 'poses'
CAPTION = 'caption'
BOXES = 'boxes'
MASKS = 'masks'
TEXT = 'text'
POLYGONS = 'polygons'
OUTPUT = 'output'


+ 3
- 0
modelscope/pipelines/builder.py View File

@@ -76,6 +76,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_daflow_virtual-tryon_base'),
Tasks.image_colorization: (Pipelines.image_colorization,
'damo/cv_unet_image-colorization'),
Tasks.image_segmentation:
(Pipelines.image_instance_segmentation,
'damo/cv_swin-b_image-instance-segmentation_coco'),
Tasks.style_transfer: (Pipelines.style_transfer,
'damo/cv_aams_style-transfer_damo'),
Tasks.face_image_generation: (Pipelines.face_image_generation,


+ 1
- 0
modelscope/pipelines/cv/__init__.py View File

@@ -11,6 +11,7 @@ try:
from .image_colorization_pipeline import ImageColorizationPipeline
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
from .face_image_generation_pipeline import FaceImageGenerationPipeline
from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
except ModuleNotFoundError as e:
if str(e) == "No module named 'torch'":
pass


+ 105
- 0
modelscope/pipelines/cv/image_instance_segmentation_pipeline.py View File

@@ -0,0 +1,105 @@
import os
from typing import Any, Dict, Optional, Union

import cv2
import numpy as np
import torch
from PIL import Image

from modelscope.metainfo import Pipelines
from modelscope.models.cv.image_instance_segmentation.model import \
CascadeMaskRCNNSwinModel
from modelscope.models.cv.image_instance_segmentation.postprocess_utils import \
get_img_ins_seg_result
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import (ImageInstanceSegmentationPreprocessor,
build_preprocessor, load_image)
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.image_segmentation,
module_name=Pipelines.image_instance_segmentation)
class ImageInstanceSegmentationPipeline(Pipeline):

def __init__(self,
model: Union[CascadeMaskRCNNSwinModel, str],
preprocessor: Optional[
ImageInstanceSegmentationPreprocessor] = None,
**kwargs):
"""use `model` and `preprocessor` to create a image instance segmentation pipeline for prediction

Args:
model (CascadeMaskRCNNSwinModel | str): a model instance
preprocessor (CascadeMaskRCNNSwinPreprocessor | None): a preprocessor instance
"""
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

if preprocessor is None:
config_path = os.path.join(self.model.model_dir,
ModelFile.CONFIGURATION)
cfg = Config.from_file(config_path)
self.preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
else:
self.preprocessor = preprocessor

self.preprocessor.eval()
self.model.eval()

def _collate_fn(self, data):
        # collation is not required here; preprocess already returns a batched sample
return data

def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
filename = None
img = None
if isinstance(input, str):
filename = input
img = np.array(load_image(input))
img = img[:, :, ::-1] # convert to bgr
elif isinstance(input, Image.Image):
img = np.array(input.convert('RGB'))
img = img[:, :, ::-1] # convert to bgr
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            else:
                img = input  # assume the array is already a BGR image
        else:
            raise TypeError(f'input should be either str, PIL.Image or'
                            f' np.ndarray, but got {type(input)}')

result = {
'img': img,
'img_shape': img.shape,
'ori_shape': img.shape,
'img_fields': ['img'],
'img_prefix': '',
'img_info': {
'filename': filename,
'ann_file': None,
'classes': None
},
}
result = self.preprocessor(result)

# stacked as a batch
result['img'] = torch.stack([result['img']], dim=0)
result['img_metas'] = [result['img_metas'].data]

return result

def forward(self, input: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
output = self.model(input)
return output

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
result = get_img_ins_seg_result(
img_seg_result=inputs['eval_result'][0],
class_names=self.model.model.classes)
return result
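
The returned dict uses the OutputKeys fields filled in postprocess_utils.py; a minimal consumption sketch (model id taken from the tests below):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

segmentor = pipeline(
    Tasks.image_segmentation,
    model='damo/cv_swin-b_image-instance-segmentation_coco')
result = segmentor('data/test/images/image_instance_segmentation.jpg')
for label, score, box, mask in zip(result[OutputKeys.LABELS],
                                   result[OutputKeys.SCORES],
                                   result[OutputKeys.BOXES],
                                   result[OutputKeys.MASKS]):
    # box is {'x', 'y', 'w', 'h'}; mask is a float HxW array with non-zero values on instance pixels
    print(label, round(float(score), 2), (box['w'], box['h']), mask.shape)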

+ 1
- 0
modelscope/preprocessors/__init__.py View File

@@ -20,6 +20,7 @@ try:
from .space.dialog_modeling_preprocessor import * # noqa F403
from .space.dialog_state_tracking_preprocessor import * # noqa F403
from .image import ImageColorEnhanceFinetunePreprocessor
from .image import ImageInstanceSegmentationPreprocessor
except ModuleNotFoundError as e:
if str(e) == "No module named 'tensorflow'":
print(TENSORFLOW_IMPORT_ERROR.format('tts'))


+ 69
- 0
modelscope/preprocessors/image.py View File

@@ -136,3 +136,72 @@ class ImageColorEnhanceFinetunePreprocessor(Preprocessor):
"""

return data


@PREPROCESSORS.register_module(
Fields.cv,
module_name=Preprocessors.image_instance_segmentation_preprocessor)
class ImageInstanceSegmentationPreprocessor(Preprocessor):

def __init__(self, *args, **kwargs):
"""image instance segmentation preprocessor in the fine-tune scenario
"""

super().__init__(*args, **kwargs)

self.training = kwargs.pop('training', True)
self.preprocessor_train_cfg = kwargs.pop('train', None)
self.preprocessor_test_cfg = kwargs.pop('val', None)

self.train_transforms = []
self.test_transforms = []

from modelscope.models.cv.image_instance_segmentation.datasets import \
build_preprocess_transform

if self.preprocessor_train_cfg is not None:
if isinstance(self.preprocessor_train_cfg, dict):
self.preprocessor_train_cfg = [self.preprocessor_train_cfg]
for cfg in self.preprocessor_train_cfg:
transform = build_preprocess_transform(cfg)
self.train_transforms.append(transform)

if self.preprocessor_test_cfg is not None:
if isinstance(self.preprocessor_test_cfg, dict):
self.preprocessor_test_cfg = [self.preprocessor_test_cfg]
for cfg in self.preprocessor_test_cfg:
transform = build_preprocess_transform(cfg)
self.test_transforms.append(transform)

def train(self):
self.training = True
return

def eval(self):
self.training = False
return

@type_assert(object, object)
def __call__(self, results: Dict[str, Any]):
"""process the raw input data

Args:
results (dict): Result dict from loading pipeline.

Returns:
Dict[str, Any] | None: the preprocessed data
"""

if self.training:
transforms = self.train_transforms
else:
transforms = self.test_transforms

        for t in transforms:
            results = t(results)
            if results is None:
                return None

return results
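
In the pipeline above, this preprocessor is built from the model's configuration.json; a minimal standalone sketch of the same path (keys assumed to match that file's `preprocessor` section):

import os

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile

model_dir = snapshot_download('damo/cv_swin-b_image-instance-segmentation_coco')
cfg = Config.from_file(os.path.join(model_dir, ModelFile.CONFIGURATION))
preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
preprocessor.eval()  # switch to the 'val' transform list for inference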

+ 1
- 0
modelscope/trainers/__init__.py View File

@@ -1,4 +1,5 @@
from .base import DummyTrainer
from .builder import build_trainer
from .cv import ImageInstanceSegmentationTrainer
from .nlp import SequenceClassificationTrainer
from .trainer import EpochBasedTrainer

+ 2
- 0
modelscope/trainers/cv/__init__.py View File

@@ -0,0 +1,2 @@
from .image_instance_segmentation_trainer import \
ImageInstanceSegmentationTrainer

+ 27
- 0
modelscope/trainers/cv/image_instance_segmentation_trainer.py View File

@@ -0,0 +1,27 @@
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.trainer import EpochBasedTrainer


@TRAINERS.register_module(module_name='image-instance-segmentation')
class ImageInstanceSegmentationTrainer(EpochBasedTrainer):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def collate_fn(self, data):
        # skip collation here: some fields use special data types (e.g., BitmapMasks)
        # that cannot be stacked into tensors by the default collate function
return data

def train(self, *args, **kwargs):
super().train(*args, **kwargs)

def evaluate(self, *args, **kwargs):
metric_values = super().evaluate(*args, **kwargs)
return metric_values

def prediction_step(self, model, inputs):
pass

def to_task_dataset(self, datasets, mode, preprocessor=None):
# wait for dataset interface to become stable...
return datasets.to_torch_dataset(preprocessor)

+ 60
- 0
tests/pipelines/test_image_instance_segmentation.py View File

@@ -0,0 +1,60 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.cv.image_instance_segmentation.model import \
CascadeMaskRCNNSwinModel
from modelscope.outputs import OutputKeys
from modelscope.pipelines import ImageInstanceSegmentationPipeline, pipeline
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class ImageInstanceSegmentationTest(unittest.TestCase):
model_id = 'damo/cv_swin-b_image-instance-segmentation_coco'
image = 'data/test/images/image_instance_segmentation.jpg'

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
config_path = os.path.join(model.model_dir, ModelFile.CONFIGURATION)
cfg = Config.from_file(config_path)
preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
pipeline_ins = pipeline(
task=Tasks.image_segmentation,
model=model,
preprocessor=preprocessor)
print(pipeline_ins(input=self.image)[OutputKeys.LABELS])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.image_segmentation, model=self.model_id)
print(pipeline_ins(input=self.image)[OutputKeys.LABELS])

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.image_segmentation)
print(pipeline_ins(input=self.image)[OutputKeys.LABELS])

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
cfg = Config.from_file(config_path)
preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
model = CascadeMaskRCNNSwinModel(cache_path)
pipeline1 = ImageInstanceSegmentationPipeline(
model, preprocessor=preprocessor)
pipeline2 = pipeline(
Tasks.image_segmentation, model=model, preprocessor=preprocessor)
print(f'pipeline1:{pipeline1(input=self.image)[OutputKeys.LABELS]}')
print(f'pipeline2: {pipeline2(input=self.image)[OutputKeys.LABELS]}')


if __name__ == '__main__':
unittest.main()

+ 117
- 0
tests/trainers/test_image_instance_segmentation_trainer.py View File

@@ -0,0 +1,117 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import zipfile
from functools import partial

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.image_instance_segmentation import \
CascadeMaskRCNNSwinModel
from modelscope.models.cv.image_instance_segmentation.datasets import \
ImageInstanceSegmentationCocoDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


class TestImageInstanceSegmentationTrainer(unittest.TestCase):

model_id = 'damo/cv_swin-b_image-instance-segmentation_coco'

def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

cache_path = snapshot_download(self.model_id)
config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
cfg = Config.from_file(config_path)

data_root = cfg.dataset.data_root
classes = tuple(cfg.dataset.classes)
max_epochs = cfg.train.max_epochs
samples_per_gpu = cfg.train.dataloader.batch_size_per_gpu

if data_root is None:
# use default toy data
dataset_path = os.path.join(cache_path, 'toydata.zip')
with zipfile.ZipFile(dataset_path, 'r') as zipf:
zipf.extractall(cache_path)
data_root = cache_path + '/toydata/'
classes = ('Cat', 'Dog')

self.train_dataset = ImageInstanceSegmentationCocoDataset(
data_root + 'annotations/instances_train.json',
classes=classes,
data_root=data_root,
img_prefix=data_root + 'images/train/',
seg_prefix=None,
test_mode=False)

self.eval_dataset = ImageInstanceSegmentationCocoDataset(
data_root + 'annotations/instances_val.json',
classes=classes,
data_root=data_root,
img_prefix=data_root + 'images/val/',
seg_prefix=None,
test_mode=True)

from mmcv.parallel import collate

self.collate_fn = partial(collate, samples_per_gpu=samples_per_gpu)

self.max_epochs = max_epochs

self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer(self):
kwargs = dict(
model=self.model_id,
data_collator=self.collate_fn,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
work_dir=self.tmp_dir)

trainer = build_trainer(
name='image-instance-segmentation', default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(self.max_epochs):
self.assertIn(f'epoch_{i+1}.pth', results_files)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_trainer_with_model_and_args(self):
tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)

cache_path = snapshot_download(self.model_id)
model = CascadeMaskRCNNSwinModel.from_pretrained(cache_path)
kwargs = dict(
cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
model=model,
data_collator=self.collate_fn,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
work_dir=self.tmp_dir)

trainer = build_trainer(
name='image-instance-segmentation', default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(self.max_epochs):
self.assertIn(f'epoch_{i+1}.pth', results_files)


if __name__ == '__main__':
unittest.main()
