diff --git a/data/test/images/image_instance_segmentation.jpg b/data/test/images/image_instance_segmentation.jpg new file mode 100644 index 00000000..f390fc90 --- /dev/null +++ b/data/test/images/image_instance_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e9ab135da7eacabdeeeee11ba4b7bcdd1bfac128cf92a9de9c79f984060ae1e +size 259865 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 345a49a6..ad914618 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -11,6 +11,7 @@ class Models(object): """ # vision models csrnet = 'csrnet' + cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' # nlp models bert = 'bert' @@ -67,6 +68,7 @@ class Pipelines(object): image_super_resolution = 'rrdb-image-super-resolution' face_image_generation = 'gan-face-image-generation' style_transfer = 'AAMS-style-transfer' + image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' # nlp tasks sentence_similarity = 'sentence-similarity' @@ -124,6 +126,7 @@ class Preprocessors(object): # cv preprocessor load_image = 'load-image' image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' + image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' # nlp preprocessor sen_sim_tokenizer = 'sen-sim-tokenizer' @@ -157,6 +160,8 @@ class Metrics(object): # accuracy accuracy = 'accuracy' + # metric for image instance segmentation task + image_ins_seg_coco_metric = 'image-ins-seg-coco-metric' # metrics for sequence classification task seq_cls_metric = 'seq_cls_metric' # metrics for token-classification task diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index 0c7dec95..1f159687 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -1,5 +1,7 @@ from .base import Metric from .builder import METRICS, build_metric, task_default_metrics from .image_color_enhance_metric import ImageColorEnhanceMetric +from .image_instance_segmentation_metric import \ + ImageInstanceSegmentationCOCOMetric from .sequence_classification_metric import SequenceClassificationMetric from .text_generation_metric import TextGenerationMetric diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index 6cb8be7d..c7c62a7a 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -18,6 +18,7 @@ class MetricKeys(object): task_default_metrics = { + Tasks.image_segmentation: [Metrics.image_ins_seg_coco_metric], Tasks.sentence_similarity: [Metrics.seq_cls_metric], Tasks.sentiment_classification: [Metrics.seq_cls_metric], Tasks.text_generation: [Metrics.text_gen_metric], diff --git a/modelscope/metrics/image_instance_segmentation_metric.py b/modelscope/metrics/image_instance_segmentation_metric.py new file mode 100644 index 00000000..7deafbce --- /dev/null +++ b/modelscope/metrics/image_instance_segmentation_metric.py @@ -0,0 +1,312 @@ +import os.path as osp +import tempfile +from collections import OrderedDict +from typing import Any, Dict + +import numpy as np +import pycocotools.mask as mask_util +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +from modelscope.fileio import dump, load +from modelscope.metainfo import Metrics +from modelscope.metrics import METRICS, Metric +from modelscope.utils.registry import default_group + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.image_ins_seg_coco_metric) +class ImageInstanceSegmentationCOCOMetric(Metric): + """The metric 
computation class for COCO-style image instance segmentation. + """ + + def __init__(self): + self.ann_file = None + self.classes = None + self.metrics = ['bbox', 'segm'] + self.proposal_nums = (100, 300, 1000) + self.iou_thrs = np.linspace( + .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) + self.results = [] + + def add(self, outputs: Dict[str, Any], inputs: Dict[str, Any]): + result = outputs['eval_result'] + # encode mask results + if isinstance(result[0], tuple): + result = [(bbox_results, encode_mask_results(mask_results)) + for bbox_results, mask_results in result] + self.results.extend(result) + if self.ann_file is None: + self.ann_file = outputs['img_metas'][0]['ann_file'] + self.classes = outputs['img_metas'][0]['classes'] + + def evaluate(self): + cocoGt = COCO(self.ann_file) + self.cat_ids = cocoGt.getCatIds(catNms=self.classes) + self.img_ids = cocoGt.getImgIds() + + result_files, tmp_dir = self.format_results(self.results, self.img_ids) + + eval_results = OrderedDict() + for metric in self.metrics: + iou_type = metric + if metric not in result_files: + raise KeyError(f'{metric} is not in results') + try: + predictions = load(result_files[metric]) + if iou_type == 'segm': + # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa + # When evaluating mask AP, if the results contain bbox, + # cocoapi will use the box area instead of the mask area + # for calculating the instance area. Though the overall AP + # is not affected, this leads to different + # small/medium/large mask AP results. + for x in predictions: + x.pop('bbox') + cocoDt = cocoGt.loadRes(predictions) + except IndexError: + print('The testing results of the whole dataset is empty.') + break + + cocoEval = COCOeval(cocoGt, cocoDt, iou_type) + cocoEval.params.catIds = self.cat_ids + cocoEval.params.imgIds = self.img_ids + cocoEval.params.maxDets = list(self.proposal_nums) + cocoEval.params.iouThrs = self.iou_thrs + # mapping of cocoEval.stats + coco_metric_names = { + 'mAP': 0, + 'mAP_50': 1, + 'mAP_75': 2, + 'mAP_s': 3, + 'mAP_m': 4, + 'mAP_l': 5, + 'AR@100': 6, + 'AR@300': 7, + 'AR@1000': 8, + 'AR_s@1000': 9, + 'AR_m@1000': 10, + 'AR_l@1000': 11 + } + + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + + metric_items = [ + 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l' + ] + + for metric_item in metric_items: + key = f'{metric}_{metric_item}' + val = float( + f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}') + eval_results[key] = val + ap = cocoEval.stats[:6] + eval_results[f'{metric}_mAP_copypaste'] = ( + f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' + f'{ap[4]:.3f} {ap[5]:.3f}') + if tmp_dir is not None: + tmp_dir.cleanup() + return eval_results + + def format_results(self, results, img_ids, jsonfile_prefix=None, **kwargs): + """Format the results to json (standard format for COCO evaluation). + + Args: + results (list[tuple | numpy.ndarray]): Testing results of the + dataset. + data_infos(list[tuple | numpy.ndarray]): data information + jsonfile_prefix (str | None): The prefix of json files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Default: None. + + Returns: + tuple: (result_files, tmp_dir), result_files is a dict containing \ + the json filepaths, tmp_dir is the temporal directory created \ + for saving json files when jsonfile_prefix is not specified. 
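+
+            Example (illustrative sketch; the temporary output path is omitted on purpose):
+                >>> result_files, tmp_dir = self.format_results(results, img_ids)
+                >>> sorted(result_files)  # for (bbox, segm) result tuples
+                ['bbox', 'proposal', 'segm']
+                >>> # remember to call tmp_dir.cleanup() once the json files are consumed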
+ """ + assert isinstance(results, list), 'results must be a list' + assert len(results) == len(img_ids), ( + 'The length of results is not equal to the dataset len: {} != {}'. + format(len(results), len(img_ids))) + + if jsonfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + jsonfile_prefix = osp.join(tmp_dir.name, 'results') + else: + tmp_dir = None + result_files = self.results2json(results, jsonfile_prefix) + return result_files, tmp_dir + + def xyxy2xywh(self, bbox): + """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO + evaluation. + + Args: + bbox (numpy.ndarray): The bounding boxes, shape (4, ), in + ``xyxy`` order. + + Returns: + list[float]: The converted bounding boxes, in ``xywh`` order. + """ + + _bbox = bbox.tolist() + return [ + _bbox[0], + _bbox[1], + _bbox[2] - _bbox[0], + _bbox[3] - _bbox[1], + ] + + def _proposal2json(self, results): + """Convert proposal results to COCO json style.""" + json_results = [] + for idx in range(len(self.img_ids)): + img_id = self.img_ids[idx] + bboxes = results[idx] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = 1 + json_results.append(data) + return json_results + + def _det2json(self, results): + """Convert detection results to COCO json style.""" + json_results = [] + for idx in range(len(self.img_ids)): + img_id = self.img_ids[idx] + result = results[idx] + for label in range(len(result)): + # Here we skip invalid predicted labels, as we use the fixed num_classes of 80 (COCO) + # (assuming the num class of input dataset is no more than 80). + # Recommended manually set `num_classes=${your test dataset class num}` in the + # configuration.json in practice. + if label >= len(self.classes): + break + bboxes = result[label] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = self.cat_ids[label] + json_results.append(data) + return json_results + + def _segm2json(self, results): + """Convert instance segmentation results to COCO json style.""" + bbox_json_results = [] + segm_json_results = [] + for idx in range(len(self.img_ids)): + img_id = self.img_ids[idx] + det, seg = results[idx] + for label in range(len(det)): + # Here we skip invalid predicted labels, as we use the fixed num_classes of 80 (COCO) + # (assuming the num class of input dataset is no more than 80). + # Recommended manually set `num_classes=${your test dataset class num}` in the + # configuration.json in practice. 
+ if label >= len(self.classes): + break + # bbox results + bboxes = det[label] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = self.cat_ids[label] + bbox_json_results.append(data) + + # segm results + # some detectors use different scores for bbox and mask + if isinstance(seg, tuple): + segms = seg[0][label] + mask_score = seg[1][label] + else: + segms = seg[label] + mask_score = [bbox[4] for bbox in bboxes] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(mask_score[i]) + data['category_id'] = self.cat_ids[label] + if isinstance(segms[i]['counts'], bytes): + segms[i]['counts'] = segms[i]['counts'].decode() + data['segmentation'] = segms[i] + segm_json_results.append(data) + return bbox_json_results, segm_json_results + + def results2json(self, results, outfile_prefix): + """Dump the detection results to a COCO style json file. + + There are 3 types of results: proposals, bbox predictions, mask + predictions, and they have different data types. This method will + automatically recognize the type, and dump them to json files. + + Args: + results (list[list | tuple | ndarray]): Testing results of the + dataset. + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json files will be named + "somepath/xxx.bbox.json", "somepath/xxx.segm.json", + "somepath/xxx.proposal.json". + + Returns: + dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \ + values are corresponding filenames. + """ + result_files = dict() + if isinstance(results[0], list): + json_results = self._det2json(results) + result_files['bbox'] = f'{outfile_prefix}.bbox.json' + result_files['proposal'] = f'{outfile_prefix}.bbox.json' + dump(json_results, result_files['bbox']) + elif isinstance(results[0], tuple): + json_results = self._segm2json(results) + result_files['bbox'] = f'{outfile_prefix}.bbox.json' + result_files['proposal'] = f'{outfile_prefix}.bbox.json' + result_files['segm'] = f'{outfile_prefix}.segm.json' + dump(json_results[0], result_files['bbox']) + dump(json_results[1], result_files['segm']) + elif isinstance(results[0], np.ndarray): + json_results = self._proposal2json(results) + result_files['proposal'] = f'{outfile_prefix}.proposal.json' + dump(json_results, result_files['proposal']) + else: + raise TypeError('invalid type of results') + return result_files + + +def encode_mask_results(mask_results): + """Encode bitmap mask to RLE code. + + Args: + mask_results (list | tuple[list]): bitmap mask results. + In mask scoring rcnn, mask_results is a tuple of (segm_results, + segm_cls_score). + + Returns: + list | tuple: RLE encoded mask. 
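+
+    Example (minimal sketch; a single all-zero 4x4 mask, for illustration only):
+        >>> import numpy as np
+        >>> masks = [[np.zeros((4, 4), dtype=np.uint8)], []]  # 2 classes, one mask
+        >>> encoded = encode_mask_results(masks)
+        >>> sorted(encoded[0][0].keys())  # each bitmap becomes an RLE dict
+        ['counts', 'size']
+        >>> encoded[1]
+        []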
+ """ + if isinstance(mask_results, tuple): # mask scoring + cls_segms, cls_mask_scores = mask_results + else: + cls_segms = mask_results + num_classes = len(cls_segms) + encoded_mask_results = [[] for _ in range(num_classes)] + for i in range(len(cls_segms)): + for cls_segm in cls_segms[i]: + encoded_mask_results[i].append( + mask_util.encode( + np.array( + cls_segm[:, :, np.newaxis], order='F', + dtype='uint8'))[0]) # encoded with RLE + if isinstance(mask_results, tuple): + return encoded_mask_results, cls_mask_scores + else: + return encoded_mask_results diff --git a/modelscope/models/cv/image_instance_segmentation/__init__.py b/modelscope/models/cv/image_instance_segmentation/__init__.py new file mode 100644 index 00000000..d069cfaf --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/__init__.py @@ -0,0 +1,2 @@ +from .cascade_mask_rcnn_swin import CascadeMaskRCNNSwin +from .model import CascadeMaskRCNNSwinModel diff --git a/modelscope/models/cv/image_instance_segmentation/backbones/__init__.py b/modelscope/models/cv/image_instance_segmentation/backbones/__init__.py new file mode 100644 index 00000000..c052cc50 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/backbones/__init__.py @@ -0,0 +1 @@ +from .swin_transformer import SwinTransformer diff --git a/modelscope/models/cv/image_instance_segmentation/backbones/swin_transformer.py b/modelscope/models/cv/image_instance_segmentation/backbones/swin_transformer.py new file mode 100644 index 00000000..3e7609e1 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/backbones/swin_transformer.py @@ -0,0 +1,694 @@ +# Modified from: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. 
+ It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. 
+ num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. 
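+
+        Returns:
+            Tensor of the same shape as the input ``x``, i.e. (B, H*W, C).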
+ """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, 'input feature has wrong size' + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, + Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Inspiration from + https://github.com/SwinTransformer/Swin-Transformer-Object-Detection + + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
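+
+        Example (illustrative sketch; requires torch and timm, uses the default Swin-T layout):
+            >>> import torch
+            >>> backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2],
+            ...                            num_heads=[3, 6, 12, 24], window_size=7)
+            >>> feats = backbone(torch.randn(1, 3, 224, 224))
+            >>> [tuple(f.shape) for f in feats]  # one feature map per stage, strides 4/8/16/32
+            [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]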
+ """ + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1] + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], + patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if + (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + """Initialize the weights in backbone.""" + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, + 2) # B Wh*Ww C + else: + x = 
x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() diff --git a/modelscope/models/cv/image_instance_segmentation/cascade_mask_rcnn_swin.py b/modelscope/models/cv/image_instance_segmentation/cascade_mask_rcnn_swin.py new file mode 100644 index 00000000..30e70f82 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/cascade_mask_rcnn_swin.py @@ -0,0 +1,266 @@ +import os +from collections import OrderedDict + +import torch +import torch.distributed as dist +import torch.nn as nn + +from modelscope.models.cv.image_instance_segmentation.backbones import \ + SwinTransformer +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def build_backbone(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + if type == 'SwinTransformer': + return SwinTransformer(**cfg) + else: + raise ValueError(f'backbone \'{type}\' is not supported.') + + +def build_neck(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + if type == 'FPN': + from mmdet.models import FPN + return FPN(**cfg) + else: + raise ValueError(f'neck \'{type}\' is not supported.') + + +def build_rpn_head(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + if type == 'RPNHead': + from mmdet.models import RPNHead + return RPNHead(**cfg) + else: + raise ValueError(f'rpn head \'{type}\' is not supported.') + + +def build_roi_head(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + if type == 'CascadeRoIHead': + from mmdet.models import CascadeRoIHead + return CascadeRoIHead(**cfg) + else: + raise ValueError(f'roi head \'{type}\' is not supported.') + + +class CascadeMaskRCNNSwin(nn.Module): + + def __init__(self, + backbone, + neck, + rpn_head, + roi_head, + pretrained=None, + **kwargs): + """ + Args: + backbone (dict): backbone config. + neck (dict): neck config. + rpn_head (dict): rpn_head config. + roi_head (dict): roi_head config. + pretrained (bool): whether to use pretrained model + """ + super(CascadeMaskRCNNSwin, self).__init__() + + self.backbone = build_backbone(backbone) + self.neck = build_neck(neck) + self.rpn_head = build_rpn_head(rpn_head) + self.roi_head = build_roi_head(roi_head) + + self.classes = kwargs.pop('classes', None) + + if pretrained: + assert 'model_dir' in kwargs, 'pretrained model dir is missing.' 
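+            # Load the checkpoint from ``model_dir`` and drop any parameters
+            # whose shapes do not match the current model (e.g. box/mask heads
+            # built for a different number of classes), so the remaining
+            # weights can still be loaded with ``strict=False``.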
+ model_path = os.path.join(kwargs['model_dir'], + ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + weight = torch.load(model_path)['state_dict'] + tgt_weight = self.state_dict() + for name in list(weight.keys()): + if name in tgt_weight: + load_size = weight[name].size() + tgt_size = tgt_weight[name].size() + mis_match = False + if len(load_size) != len(tgt_size): + mis_match = True + else: + for n1, n2 in zip(load_size, tgt_size): + if n1 != n2: + mis_match = True + break + if mis_match: + logger.info(f'size mismatch for {name}, skip loading.') + del weight[name] + + self.load_state_dict(weight, strict=False) + logger.info('load model done') + + from mmcv.parallel import DataContainer, scatter + + self.data_container = DataContainer + self.scatter = scatter + + def extract_feat(self, img): + x = self.backbone(img) + x = self.neck(x) + return x + + def forward_train(self, + img, + img_metas, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None, + gt_masks=None, + proposals=None, + **kwargs): + """ + Args: + img (Tensor): of shape (N, C, H, W) encoding input images. + Typically these should be mean centered and std scaled. + + img_metas (list[dict]): list of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmdet/datasets/pipelines/formatting.py:Collect`. + + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + + gt_labels (list[Tensor]): class indices corresponding to each box + + gt_bboxes_ignore (None | list[Tensor]): specify which bounding + boxes can be ignored when computing the loss. + + gt_masks (None | Tensor) : true segmentation masks for each box + used if the architecture supports a segmentation task. + + proposals : override rpn proposals with custom proposals. Use when + `with_rpn` is False. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + x = self.extract_feat(img) + + losses = dict() + + # RPN forward and loss + proposal_cfg = self.rpn_head.train_cfg.get('rpn_proposal', + self.rpn_head.test_cfg) + rpn_losses, proposal_list = self.rpn_head.forward_train( + x, + img_metas, + gt_bboxes, + gt_labels=None, + gt_bboxes_ignore=gt_bboxes_ignore, + proposal_cfg=proposal_cfg, + **kwargs) + losses.update(rpn_losses) + + roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list, + gt_bboxes, gt_labels, + gt_bboxes_ignore, gt_masks, + **kwargs) + losses.update(roi_losses) + + return losses + + def forward_test(self, img, img_metas, proposals=None, rescale=True): + + x = self.extract_feat(img) + if proposals is None: + proposal_list = self.rpn_head.simple_test_rpn(x, img_metas) + else: + proposal_list = proposals + + result = self.roi_head.simple_test( + x, proposal_list, img_metas, rescale=rescale) + return dict(eval_result=result, img_metas=img_metas) + + def forward(self, img, img_metas, **kwargs): + + # currently only support cpu or single gpu + if isinstance(img, self.data_container): + img = img.data[0] + if isinstance(img_metas, self.data_container): + img_metas = img_metas.data[0] + for k, w in kwargs.items(): + if isinstance(w, self.data_container): + w = w.data[0] + kwargs[k] = w + + if next(self.parameters()).is_cuda: + device = next(self.parameters()).device + img = self.scatter(img, [device])[0] + img_metas = self.scatter(img_metas, [device])[0] + for k, w in kwargs.items(): + kwargs[k] = self.scatter(w, [device])[0] + + if self.training: + losses = self.forward_train(img, img_metas, **kwargs) + loss, log_vars = self._parse_losses(losses) + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(img_metas)) + return outputs + else: + return self.forward_test(img, img_metas, **kwargs) + + def _parse_losses(self, losses): + + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def train_step(self, data, optimizer): + + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) + + return outputs + + def val_step(self, data, optimizer=None): + + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) + + return outputs diff --git a/modelscope/models/cv/image_instance_segmentation/datasets/__init__.py b/modelscope/models/cv/image_instance_segmentation/datasets/__init__.py new file mode 100644 index 00000000..93c71b46 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/datasets/__init__.py @@ -0,0 +1,2 @@ +from .dataset import ImageInstanceSegmentationCocoDataset +from .transforms import build_preprocess_transform diff --git 
a/modelscope/models/cv/image_instance_segmentation/datasets/dataset.py b/modelscope/models/cv/image_instance_segmentation/datasets/dataset.py new file mode 100644 index 00000000..d9e1b348 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/datasets/dataset.py @@ -0,0 +1,332 @@ +import os.path as osp + +import numpy as np +from pycocotools.coco import COCO +from torch.utils.data import Dataset + + +class ImageInstanceSegmentationCocoDataset(Dataset): + """Coco-style dataset for image instance segmentation. + + Args: + ann_file (str): Annotation file path. + classes (Sequence[str], optional): Specify classes to load. + If is None, ``cls.CLASSES`` will be used. Default: None. + data_root (str, optional): Data root for ``ann_file``, + ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified. + test_mode (bool, optional): If set True, annotation will not be loaded. + filter_empty_gt (bool, optional): If set true, images without bounding + boxes of the dataset's classes will be filtered out. This option + only works when `test_mode=False`, i.e., we never filter images + during tests. + """ + + CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') + + def __init__(self, + ann_file, + classes=None, + data_root=None, + img_prefix='', + seg_prefix=None, + test_mode=False, + filter_empty_gt=True): + self.ann_file = ann_file + self.data_root = data_root + self.img_prefix = img_prefix + self.seg_prefix = seg_prefix + self.test_mode = test_mode + self.filter_empty_gt = filter_empty_gt + self.CLASSES = self.get_classes(classes) + + # join paths if data_root is specified + if self.data_root is not None: + if not osp.isabs(self.ann_file): + self.ann_file = osp.join(self.data_root, self.ann_file) + if not (self.img_prefix is None or osp.isabs(self.img_prefix)): + self.img_prefix = osp.join(self.data_root, self.img_prefix) + if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)): + self.seg_prefix = osp.join(self.data_root, self.seg_prefix) + + # load annotations + self.data_infos = self.load_annotations(self.ann_file) + + # filter images too small and containing no annotations + if not test_mode: + valid_inds = self._filter_imgs() + self.data_infos = [self.data_infos[i] for i in valid_inds] + # set group flag for the sampler + self._set_group_flag() + + self.preprocessor = None + + def __len__(self): + """Total number of samples of data.""" + return len(self.data_infos) + + def load_annotations(self, ann_file): + """Load annotation from COCO style annotation file. + + Args: + ann_file (str): Path of annotation file. + + Returns: + list[dict]: Annotation info from COCO api. 
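+
+        Each returned entry is the raw COCO image record (``id``,
+        ``file_name``, ``width``, ``height``, ...) extended with the
+        ``filename``, ``ann_file`` and ``classes`` bookkeeping keys.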
+ """ + + self.coco = COCO(ann_file) + # The order of returned `cat_ids` will not + # change with the order of the CLASSES + self.cat_ids = self.coco.getCatIds(catNms=self.CLASSES) + + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.img_ids = self.coco.getImgIds() + data_infos = [] + total_ann_ids = [] + for i in self.img_ids: + info = self.coco.loadImgs([i])[0] + info['filename'] = info['file_name'] + info['ann_file'] = ann_file + info['classes'] = self.CLASSES + data_infos.append(info) + ann_ids = self.coco.getAnnIds(imgIds=[i]) + total_ann_ids.extend(ann_ids) + assert len(set(total_ann_ids)) == len( + total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!" + return data_infos + + def get_ann_info(self, idx): + """Get COCO annotation by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + img_id = self.data_infos[idx]['id'] + ann_ids = self.coco.getAnnIds(imgIds=[img_id]) + ann_info = self.coco.loadAnns(ann_ids) + return self._parse_ann_info(self.data_infos[idx], ann_info) + + def get_cat_ids(self, idx): + """Get COCO category ids by index. + + Args: + idx (int): Index of data. + + Returns: + list[int]: All categories in the image of specified index. + """ + + img_id = self.data_infos[idx]['id'] + ann_ids = self.coco.getAnnIds(imgIds=[img_id]) + ann_info = self.coco.loadAnns(ann_ids) + return [ann['category_id'] for ann in ann_info] + + def pre_pipeline(self, results): + """Prepare results dict for pipeline.""" + results['img_prefix'] = self.img_prefix + results['seg_prefix'] = self.seg_prefix + results['bbox_fields'] = [] + results['mask_fields'] = [] + results['seg_fields'] = [] + + def _filter_imgs(self, min_size=32): + """Filter images too small or without ground truths.""" + valid_inds = [] + # obtain images that contain annotation + ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) + # obtain images that contain annotations of the required categories + ids_in_cat = set() + for i, class_id in enumerate(self.cat_ids): + ids_in_cat |= set(self.coco.catToImgs[class_id]) + # merge the image id sets of the two conditions and use the merged set + # to filter out images if self.filter_empty_gt=True + ids_in_cat &= ids_with_ann + + valid_img_ids = [] + for i, img_info in enumerate(self.data_infos): + img_id = self.img_ids[i] + if self.filter_empty_gt and img_id not in ids_in_cat: + continue + if min(img_info['width'], img_info['height']) >= min_size: + valid_inds.append(i) + valid_img_ids.append(img_id) + self.img_ids = valid_img_ids + return valid_inds + + def _parse_ann_info(self, img_info, ann_info): + """Parse bbox and mask annotation. + + Args: + ann_info (list[dict]): Annotation info of an image. + + Returns: + dict: A dict containing the following keys: bboxes, bboxes_ignore,\ + labels, masks, seg_map. "masks" are raw annotations and not \ + decoded into binary masks. 
+ """ + gt_bboxes = [] + gt_labels = [] + gt_bboxes_ignore = [] + gt_masks_ann = [] + for i, ann in enumerate(ann_info): + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if ann['area'] <= 0 or w < 1 or h < 1: + continue + if ann['category_id'] not in self.cat_ids: + continue + bbox = [x1, y1, x1 + w, y1 + h] + if ann.get('iscrowd', False): + gt_bboxes_ignore.append(bbox) + else: + gt_bboxes.append(bbox) + gt_labels.append(self.cat2label[ann['category_id']]) + gt_masks_ann.append(ann.get('segmentation', None)) + + if gt_bboxes: + gt_bboxes = np.array(gt_bboxes, dtype=np.float32) + gt_labels = np.array(gt_labels, dtype=np.int64) + else: + gt_bboxes = np.zeros((0, 4), dtype=np.float32) + gt_labels = np.array([], dtype=np.int64) + + if gt_bboxes_ignore: + gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) + else: + gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) + + seg_map = img_info['filename'].replace('jpg', 'png') + + ann = dict( + bboxes=gt_bboxes, + labels=gt_labels, + bboxes_ignore=gt_bboxes_ignore, + masks=gt_masks_ann, + seg_map=seg_map) + + return ann + + def _set_group_flag(self): + """Set flag according to image aspect ratio. + + Images with aspect ratio greater than 1 will be set as group 1, + otherwise group 0. + """ + self.flag = np.zeros(len(self), dtype=np.uint8) + for i in range(len(self)): + img_info = self.data_infos[i] + if img_info['width'] / img_info['height'] > 1: + self.flag[i] = 1 + + def _rand_another(self, idx): + """Get another random index from the same group as the given index.""" + pool = np.where(self.flag == self.flag[idx])[0] + return np.random.choice(pool) + + def __getitem__(self, idx): + """Get training/test data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training/test data (with annotation if `test_mode` is set \ + True). + """ + + if self.test_mode: + return self.prepare_test_img(idx) + while True: + data = self.prepare_train_img(idx) + if data is None: + idx = self._rand_another(idx) + continue + return data + + def prepare_train_img(self, idx): + """Get training data and annotations after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys \ + introduced by pipeline. + """ + + img_info = self.data_infos[idx] + ann_info = self.get_ann_info(idx) + results = dict(img_info=img_info, ann_info=ann_info) + self.pre_pipeline(results) + if self.preprocessor is None: + return results + self.preprocessor.train() + return self.preprocessor(results) + + def prepare_test_img(self, idx): + """Get testing data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Testing data after pipeline with new keys introduced by \ + pipeline. + """ + + img_info = self.data_infos[idx] + results = dict(img_info=img_info) + self.pre_pipeline(results) + if self.preprocessor is None: + return results + self.preprocessor.eval() + results = self.preprocessor(results) + return results + + @classmethod + def get_classes(cls, classes=None): + """Get class names of current dataset. + + Args: + classes (Sequence[str] | None): If classes is None, use + default CLASSES defined by builtin dataset. If classes is + a tuple or list, override the CLASSES defined by the dataset. + + Returns: + tuple[str] or list[str]: Names of categories of the dataset. 
+ """ + if classes is None: + return cls.CLASSES + + if isinstance(classes, (tuple, list)): + class_names = classes + else: + raise ValueError(f'Unsupported type {type(classes)} of classes.') + + return class_names + + def to_torch_dataset(self, preprocessors=None): + self.preprocessor = preprocessors + return self diff --git a/modelscope/models/cv/image_instance_segmentation/datasets/transforms.py b/modelscope/models/cv/image_instance_segmentation/datasets/transforms.py new file mode 100644 index 00000000..abc30c77 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/datasets/transforms.py @@ -0,0 +1,109 @@ +import os.path as osp + +import numpy as np + +from modelscope.fileio import File + + +def build_preprocess_transform(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + if type == 'LoadImageFromFile': + return LoadImageFromFile(**cfg) + elif type == 'LoadAnnotations': + from mmdet.datasets.pipelines import LoadAnnotations + return LoadAnnotations(**cfg) + elif type == 'Resize': + if 'img_scale' in cfg: + if isinstance(cfg.img_scale[0], list): + elems = [] + for elem in cfg.img_scale: + elems.append(tuple(elem)) + cfg.img_scale = elems + else: + cfg.img_scale = tuple(cfg.img_scale) + from mmdet.datasets.pipelines import Resize + return Resize(**cfg) + elif type == 'RandomFlip': + from mmdet.datasets.pipelines import RandomFlip + return RandomFlip(**cfg) + elif type == 'Normalize': + from mmdet.datasets.pipelines import Normalize + return Normalize(**cfg) + elif type == 'Pad': + from mmdet.datasets.pipelines import Pad + return Pad(**cfg) + elif type == 'DefaultFormatBundle': + from mmdet.datasets.pipelines import DefaultFormatBundle + return DefaultFormatBundle(**cfg) + elif type == 'ImageToTensor': + from mmdet.datasets.pipelines import ImageToTensor + return ImageToTensor(**cfg) + elif type == 'Collect': + from mmdet.datasets.pipelines import Collect + return Collect(**cfg) + else: + raise ValueError(f'preprocess transform \'{type}\' is not supported.') + + +class LoadImageFromFile: + """Load an image from file. + + Required keys are "img_prefix" and "img_info" (a dict that must contain the + key "filename"). Added or updated keys are "filename", "img", "img_shape", + "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), + "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def __init__(self, to_float32=False, mode='rgb'): + self.to_float32 = to_float32 + self.mode = mode + + from mmcv import imfrombytes + + self.imfrombytes = imfrombytes + + def __call__(self, results): + """Call functions to load image and get image meta information. + + Args: + results (dict): Result dict from :obj:`ImageInstanceSegmentationDataset`. + + Returns: + dict: The dict contains loaded image and meta information. 
+ """ + + if results['img_prefix'] is not None: + filename = osp.join(results['img_prefix'], + results['img_info']['filename']) + else: + filename = results['img_info']['filename'] + + img_bytes = File.read(filename) + + img = self.imfrombytes(img_bytes, 'color', 'bgr', backend='pillow') + + if self.to_float32: + img = img.astype(np.float32) + + results['filename'] = filename + results['ori_filename'] = results['img_info']['filename'] + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + results['img_fields'] = ['img'] + results['ann_file'] = results['img_info']['ann_file'] + results['classes'] = results['img_info']['classes'] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32}, ' + f"mode='{self.mode}'") + return repr_str diff --git a/modelscope/models/cv/image_instance_segmentation/model.py b/modelscope/models/cv/image_instance_segmentation/model.py new file mode 100644 index 00000000..2be59623 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/model.py @@ -0,0 +1,49 @@ +import os +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.image_instance_segmentation import \ + CascadeMaskRCNNSwin +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks + + +@MODELS.register_module( + Tasks.image_segmentation, module_name=Models.cascade_mask_rcnn_swin) +class CascadeMaskRCNNSwinModel(TorchModel): + + def __init__(self, model_dir=None, *args, **kwargs): + """ + Args: + model_dir (str): model directory. + + """ + super(CascadeMaskRCNNSwinModel, self).__init__( + model_dir=model_dir, *args, **kwargs) + + if 'backbone' not in kwargs: + config_path = os.path.join(model_dir, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + model_cfg = cfg.model + kwargs.update(model_cfg) + + self.model = CascadeMaskRCNNSwin(model_dir=model_dir, **kwargs) + + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.model.to(self.device) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + output = self.model(**input) + return output + + def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: + + return input + + def compute_loss(self, outputs: Dict[str, Any], labels): + pass diff --git a/modelscope/models/cv/image_instance_segmentation/postprocess_utils.py b/modelscope/models/cv/image_instance_segmentation/postprocess_utils.py new file mode 100644 index 00000000..43e52292 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/postprocess_utils.py @@ -0,0 +1,203 @@ +import itertools + +import cv2 +import numpy as np +import pycocotools.mask as maskUtils +import torch + +from modelscope.outputs import OutputKeys + + +def get_seg_bboxes(bboxes, labels, segms=None, class_names=None, score_thr=0.): + assert bboxes.ndim == 2, \ + f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.' + assert labels.ndim == 1, \ + f' labels ndim should be 1, but its ndim is {labels.ndim}.' + assert bboxes.shape[0] == labels.shape[0], \ + 'bboxes.shape[0] and labels.shape[0] should have the same length.' + assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \ + f' bboxes.shape[1] should be 4 or 5, but its {bboxes.shape[1]}.' 
+
+    if score_thr > 0:
+        assert bboxes.shape[1] == 5
+        scores = bboxes[:, -1]
+        inds = scores > score_thr
+        bboxes = bboxes[inds, :]
+        labels = labels[inds]
+        if segms is not None:
+            segms = segms[inds, ...]
+
+    bboxes_names = []
+    for i, (bbox, label) in enumerate(zip(bboxes, labels)):
+        label_name = class_names[
+            label] if class_names is not None else f'class {label}'
+        bbox = [0 if b < 0 else b for b in list(bbox)]
+        bbox.append(label_name)
+        bbox.append(segms[i].astype(bool))
+        bboxes_names.append(bbox)
+
+    return bboxes_names
+
+
+def get_img_seg_results(det_rawdata=None,
+                        class_names=None,
+                        score_thr=0.3,
+                        is_decode=True):
+    """Gather all detected boxes of one image.
+
+    score_thr: classification probability threshold.
+    output format: [[x1, y1, x2, y2, prob, cls_name, mask], ...]
+    """
+    assert det_rawdata is not None, 'det_rawdata should be not None.'
+    assert class_names is not None, 'class_names should be not None.'
+
+    if isinstance(det_rawdata, tuple):
+        bbox_result, segm_result = det_rawdata
+        if isinstance(segm_result, tuple):
+            segm_result = segm_result[0]  # ms rcnn
+    else:
+        bbox_result, segm_result = det_rawdata, None
+    bboxes = np.vstack(bbox_result)
+    labels = [
+        np.full(bbox.shape[0], i, dtype=np.int32)
+        for i, bbox in enumerate(bbox_result)
+    ]
+    labels = np.concatenate(labels)
+
+    segms = None
+    if segm_result is not None and len(labels) > 0:  # non empty
+        segms = list(itertools.chain(*segm_result))
+        if is_decode:
+            segms = maskUtils.decode(segms)
+            segms = segms.transpose(2, 0, 1)
+        if isinstance(segms[0], torch.Tensor):
+            segms = torch.stack(segms, dim=0).detach().cpu().numpy()
+        else:
+            segms = np.stack(segms, axis=0)
+
+    bboxes_names = get_seg_bboxes(
+        bboxes,
+        labels,
+        segms=segms,
+        class_names=class_names,
+        score_thr=score_thr)
+
+    return bboxes_names
+
+
+def get_img_ins_seg_result(img_seg_result=None,
+                           class_names=None,
+                           score_thr=0.3):
+    assert img_seg_result is not None, 'img_seg_result should be not None.'
+    assert class_names is not None, 'class_names should be not None.'
+
+    img_seg_result = get_img_seg_results(
+        det_rawdata=(img_seg_result[0], img_seg_result[1]),
+        class_names=class_names,
+        score_thr=score_thr,
+        is_decode=False)
+
+    results_dict = {
+        OutputKeys.BOXES: [],
+        OutputKeys.MASKS: [],
+        OutputKeys.LABELS: [],
+        OutputKeys.SCORES: []
+    }
+    for seg_result in img_seg_result:
+
+        # use builtin int/float; np.int and np.float were removed in NumPy 1.24
+        box = {
+            'x': int(seg_result[0]),
+            'y': int(seg_result[1]),
+            'w': int(seg_result[2] - seg_result[0]),
+            'h': int(seg_result[3] - seg_result[1])
+        }
+        score = float(seg_result[4])
+        category = seg_result[5]
+
+        mask = np.array(seg_result[6], order='F', dtype='uint8')
+        mask = mask.astype(float)
+
+        results_dict[OutputKeys.BOXES].append(box)
+        results_dict[OutputKeys.MASKS].append(mask)
+        results_dict[OutputKeys.SCORES].append(score)
+        results_dict[OutputKeys.LABELS].append(category)
+
+    return results_dict
+
+
+def show_result(
+    img,
+    result,
+    out_file='result.jpg',
+    show_box=True,
+    show_label=True,
+    show_score=True,
+    alpha=0.5,
+    fontScale=0.5,
+    fontFace=cv2.FONT_HERSHEY_COMPLEX_SMALL,
+    thickness=1,
+):
+
+    assert isinstance(img, (str, np.ndarray)), \
+        f'img must be str or np.ndarray, but got {type(img)}.'
+ + if isinstance(img, str): + img = cv2.imread(img) + if len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = img.astype(np.float32) + + labels = result[OutputKeys.LABELS] + scores = result[OutputKeys.SCORES] + boxes = result[OutputKeys.BOXES] + masks = result[OutputKeys.MASKS] + + for label, score, box, mask in zip(labels, scores, boxes, masks): + + random_color = np.array([ + np.random.random() * 255.0, + np.random.random() * 255.0, + np.random.random() * 255.0 + ]) + + x1 = int(box['x']) + y1 = int(box['y']) + w = int(box['w']) + h = int(box['h']) + x2 = x1 + w + y2 = y1 + h + + if show_box: + cv2.rectangle( + img, (x1, y1), (x2, y2), random_color, thickness=thickness) + if show_label or show_score: + if show_label and show_score: + text = '{}|{}'.format(label, round(float(score), 2)) + elif show_label: + text = '{}'.format(label) + else: + text = '{}'.format(round(float(score), 2)) + + retval, baseLine = cv2.getTextSize( + text, + fontFace=fontFace, + fontScale=fontScale, + thickness=thickness) + cv2.rectangle( + img, (x1, y1 - retval[1] - baseLine), (x1 + retval[0], y1), + thickness=-1, + color=(0, 0, 0)) + cv2.putText( + img, + text, (x1, y1 - baseLine), + fontScale=fontScale, + fontFace=fontFace, + thickness=thickness, + color=random_color) + + idx = np.nonzero(mask) + img[idx[0], idx[1], :] *= 1.0 - alpha + img[idx[0], idx[1], :] += alpha * random_color + + cv2.imwrite(out_file, img) diff --git a/modelscope/outputs.py b/modelscope/outputs.py index f3a40824..da770b70 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -13,6 +13,7 @@ class OutputKeys(object): POSES = 'poses' CAPTION = 'caption' BOXES = 'boxes' + MASKS = 'masks' TEXT = 'text' POLYGONS = 'polygons' OUTPUT = 'output' diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 072a5a00..58730d9a 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -76,6 +76,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_daflow_virtual-tryon_base'), Tasks.image_colorization: (Pipelines.image_colorization, 'damo/cv_unet_image-colorization'), + Tasks.image_segmentation: + (Pipelines.image_instance_segmentation, + 'damo/cv_swin-b_image-instance-segmentation_coco'), Tasks.style_transfer: (Pipelines.style_transfer, 'damo/cv_aams_style-transfer_damo'), Tasks.face_image_generation: (Pipelines.face_image_generation, diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 006cb92c..85453fef 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -11,6 +11,7 @@ try: from .image_colorization_pipeline import ImageColorizationPipeline from .image_super_resolution_pipeline import ImageSuperResolutionPipeline from .face_image_generation_pipeline import FaceImageGenerationPipeline + from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline except ModuleNotFoundError as e: if str(e) == "No module named 'torch'": pass diff --git a/modelscope/pipelines/cv/image_instance_segmentation_pipeline.py b/modelscope/pipelines/cv/image_instance_segmentation_pipeline.py new file mode 100644 index 00000000..1034fb64 --- /dev/null +++ b/modelscope/pipelines/cv/image_instance_segmentation_pipeline.py @@ -0,0 +1,105 @@ +import os +from typing import Any, Dict, Optional, Union + +import cv2 +import numpy as np +import torch +from PIL import Image + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_instance_segmentation.model import \ + 
    CascadeMaskRCNNSwinModel
+from modelscope.models.cv.image_instance_segmentation.postprocess_utils import \
+    get_img_ins_seg_result
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import (ImageInstanceSegmentationPreprocessor,
+                                      build_preprocessor, load_image)
+from modelscope.utils.config import Config
+from modelscope.utils.constant import Fields, ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.image_segmentation,
+    module_name=Pipelines.image_instance_segmentation)
+class ImageInstanceSegmentationPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[CascadeMaskRCNNSwinModel, str],
+                 preprocessor: Optional[
+                     ImageInstanceSegmentationPreprocessor] = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create an image instance segmentation pipeline for prediction.
+
+        Args:
+            model (CascadeMaskRCNNSwinModel | str): a model instance
+            preprocessor (ImageInstanceSegmentationPreprocessor | None): a preprocessor instance
+        """
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+        if preprocessor is None:
+            config_path = os.path.join(self.model.model_dir,
+                                       ModelFile.CONFIGURATION)
+            cfg = Config.from_file(config_path)
+            self.preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
+        else:
+            self.preprocessor = preprocessor
+
+        self.preprocessor.eval()
+        self.model.eval()
+
+    def _collate_fn(self, data):
+        # no batch collation is required here
+        return data
+
+    def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
+        filename = None
+        img = None
+        if isinstance(input, str):
+            filename = input
+            img = np.array(load_image(input))
+            img = img[:, :, ::-1]  # convert to bgr
+        elif isinstance(input, Image.Image):
+            img = np.array(input.convert('RGB'))
+            img = img[:, :, ::-1]  # convert to bgr
+        elif isinstance(input, np.ndarray):
+            if len(input.shape) == 2:
+                img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
+            else:
+                img = input  # assume a 3-channel BGR array, as in the branches above
+        else:
+            raise TypeError(f'input should be either str, PIL.Image,'
+                            f' np.ndarray, but got {type(input)}')
+
+        result = {
+            'img': img,
+            'img_shape': img.shape,
+            'ori_shape': img.shape,
+            'img_fields': ['img'],
+            'img_prefix': '',
+            'img_info': {
+                'filename': filename,
+                'ann_file': None,
+                'classes': None
+            },
+        }
+        result = self.preprocessor(result)
+
+        # stack into a batch of size 1
+        result['img'] = torch.stack([result['img']], dim=0)
+        result['img_metas'] = [result['img_metas'].data]
+
+        return result
+
+    def forward(self, input: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            output = self.model(input)
+        return output
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        result = get_img_ins_seg_result(
+            img_seg_result=inputs['eval_result'][0],
+            class_names=self.model.model.classes)
+        return result
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index 38b67276..c3cbfb4f 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -20,6 +20,7 @@ try:
     from .space.dialog_modeling_preprocessor import * # noqa F403
     from .space.dialog_state_tracking_preprocessor import * # noqa F403
    from .image import ImageColorEnhanceFinetunePreprocessor
+    from .image import ImageInstanceSegmentationPreprocessor
 except ModuleNotFoundError as e:
     if str(e) == "No module named 'tensorflow'":
         print(TENSORFLOW_IMPORT_ERROR.format('tts'))
diff --git a/modelscope/preprocessors/image.py
b/modelscope/preprocessors/image.py index 4c911f97..f6f93319 100644 --- a/modelscope/preprocessors/image.py +++ b/modelscope/preprocessors/image.py @@ -136,3 +136,72 @@ class ImageColorEnhanceFinetunePreprocessor(Preprocessor): """ return data + + +@PREPROCESSORS.register_module( + Fields.cv, + module_name=Preprocessors.image_instance_segmentation_preprocessor) +class ImageInstanceSegmentationPreprocessor(Preprocessor): + + def __init__(self, *args, **kwargs): + """image instance segmentation preprocessor in the fine-tune scenario + """ + + super().__init__(*args, **kwargs) + + self.training = kwargs.pop('training', True) + self.preprocessor_train_cfg = kwargs.pop('train', None) + self.preprocessor_test_cfg = kwargs.pop('val', None) + + self.train_transforms = [] + self.test_transforms = [] + + from modelscope.models.cv.image_instance_segmentation.datasets import \ + build_preprocess_transform + + if self.preprocessor_train_cfg is not None: + if isinstance(self.preprocessor_train_cfg, dict): + self.preprocessor_train_cfg = [self.preprocessor_train_cfg] + for cfg in self.preprocessor_train_cfg: + transform = build_preprocess_transform(cfg) + self.train_transforms.append(transform) + + if self.preprocessor_test_cfg is not None: + if isinstance(self.preprocessor_test_cfg, dict): + self.preprocessor_test_cfg = [self.preprocessor_test_cfg] + for cfg in self.preprocessor_test_cfg: + transform = build_preprocess_transform(cfg) + self.test_transforms.append(transform) + + def train(self): + self.training = True + return + + def eval(self): + self.training = False + return + + @type_assert(object, object) + def __call__(self, results: Dict[str, Any]): + """process the raw input data + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + Dict[str, Any] | None: the preprocessed data + """ + + if self.training: + transforms = self.train_transforms + else: + transforms = self.test_transforms + + for t in transforms: + + results = t(results) + + if results is None: + return None + + return results diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py index 74d96e59..e5dde881 100644 --- a/modelscope/trainers/__init__.py +++ b/modelscope/trainers/__init__.py @@ -1,4 +1,5 @@ from .base import DummyTrainer from .builder import build_trainer +from .cv import ImageInstanceSegmentationTrainer from .nlp import SequenceClassificationTrainer from .trainer import EpochBasedTrainer diff --git a/modelscope/trainers/cv/__init__.py b/modelscope/trainers/cv/__init__.py new file mode 100644 index 00000000..07b1646d --- /dev/null +++ b/modelscope/trainers/cv/__init__.py @@ -0,0 +1,2 @@ +from .image_instance_segmentation_trainer import \ + ImageInstanceSegmentationTrainer diff --git a/modelscope/trainers/cv/image_instance_segmentation_trainer.py b/modelscope/trainers/cv/image_instance_segmentation_trainer.py new file mode 100644 index 00000000..aa8cc9e3 --- /dev/null +++ b/modelscope/trainers/cv/image_instance_segmentation_trainer.py @@ -0,0 +1,27 @@ +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.trainer import EpochBasedTrainer + + +@TRAINERS.register_module(module_name='image-instance-segmentation') +class ImageInstanceSegmentationTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def collate_fn(self, data): + # we skip this func due to some special data type, e.g., BitmapMasks + return data + + def train(self, *args, **kwargs): + super().train(*args, **kwargs) + + def evaluate(self, 
*args, **kwargs): + metric_values = super().evaluate(*args, **kwargs) + return metric_values + + def prediction_step(self, model, inputs): + pass + + def to_task_dataset(self, datasets, mode, preprocessor=None): + # wait for dataset interface to become stable... + return datasets.to_torch_dataset(preprocessor) diff --git a/tests/pipelines/test_image_instance_segmentation.py b/tests/pipelines/test_image_instance_segmentation.py new file mode 100644 index 00000000..37b92266 --- /dev/null +++ b/tests/pipelines/test_image_instance_segmentation.py @@ -0,0 +1,60 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.cv.image_instance_segmentation.model import \ + CascadeMaskRCNNSwinModel +from modelscope.outputs import OutputKeys +from modelscope.pipelines import ImageInstanceSegmentationPipeline, pipeline +from modelscope.preprocessors import build_preprocessor +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +class ImageInstanceSegmentationTest(unittest.TestCase): + model_id = 'damo/cv_swin-b_image-instance-segmentation_coco' + image = 'data/test/images/image_instance_segmentation.jpg' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + config_path = os.path.join(model.model_dir, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv) + pipeline_ins = pipeline( + task=Tasks.image_segmentation, + model=model, + preprocessor=preprocessor) + print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.image_segmentation, model=self.model_id) + print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.image_segmentation) + print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv) + model = CascadeMaskRCNNSwinModel(cache_path) + pipeline1 = ImageInstanceSegmentationPipeline( + model, preprocessor=preprocessor) + pipeline2 = pipeline( + Tasks.image_segmentation, model=model, preprocessor=preprocessor) + print(f'pipeline1:{pipeline1(input=self.image)[OutputKeys.LABELS]}') + print(f'pipeline2: {pipeline2(input=self.image)[OutputKeys.LABELS]}') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_instance_segmentation_trainer.py b/tests/trainers/test_image_instance_segmentation_trainer.py new file mode 100644 index 00000000..fb9eb3c4 --- /dev/null +++ b/tests/trainers/test_image_instance_segmentation_trainer.py @@ -0,0 +1,117 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import shutil +import tempfile +import unittest +import zipfile +from functools import partial + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.cv.image_instance_segmentation import \ + CascadeMaskRCNNSwinModel +from modelscope.models.cv.image_instance_segmentation.datasets import \ + ImageInstanceSegmentationCocoDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestImageInstanceSegmentationTrainer(unittest.TestCase): + + model_id = 'damo/cv_swin-b_image-instance-segmentation_coco' + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + cache_path = snapshot_download(self.model_id) + config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + + data_root = cfg.dataset.data_root + classes = tuple(cfg.dataset.classes) + max_epochs = cfg.train.max_epochs + samples_per_gpu = cfg.train.dataloader.batch_size_per_gpu + + if data_root is None: + # use default toy data + dataset_path = os.path.join(cache_path, 'toydata.zip') + with zipfile.ZipFile(dataset_path, 'r') as zipf: + zipf.extractall(cache_path) + data_root = cache_path + '/toydata/' + classes = ('Cat', 'Dog') + + self.train_dataset = ImageInstanceSegmentationCocoDataset( + data_root + 'annotations/instances_train.json', + classes=classes, + data_root=data_root, + img_prefix=data_root + 'images/train/', + seg_prefix=None, + test_mode=False) + + self.eval_dataset = ImageInstanceSegmentationCocoDataset( + data_root + 'annotations/instances_val.json', + classes=classes, + data_root=data_root, + img_prefix=data_root + 'images/val/', + seg_prefix=None, + test_mode=True) + + from mmcv.parallel import collate + + self.collate_fn = partial(collate, samples_per_gpu=samples_per_gpu) + + self.max_epochs = max_epochs + + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + data_collator=self.collate_fn, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name='image-instance-segmentation', default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + cache_path = snapshot_download(self.model_id) + model = CascadeMaskRCNNSwinModel.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + data_collator=self.collate_fn, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name='image-instance-segmentation', default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in 
range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main()
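
A minimal sketch of driving the new `LoadImageFromFile` transform on its own, using the LFS test image added in this change; the `img_prefix`/`img_info` keys mirror what `ImageInstanceSegmentationCocoDataset` feeds into the preprocessing pipeline.

from modelscope.models.cv.image_instance_segmentation.datasets.transforms import \
    LoadImageFromFile

# Feed the loader the same keys the dataset provides; 'ann_file' and
# 'classes' are passed through untouched and only matter for evaluation.
loader = LoadImageFromFile(to_float32=True)
results = loader({
    'img_prefix': 'data/test/images',
    'img_info': {
        'filename': 'image_instance_segmentation.jpg',
        'ann_file': None,
        'classes': None,
    },
})
print(results['filename'], results['img'].shape, results['img_fields'])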
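A sketch of building the registered `ImageInstanceSegmentationPreprocessor` from a model's configuration file, assuming the same `preprocessor` config section that the pipeline and tests read; `eval()` selects the 'val' transform list and `train()` the 'train' list.

import os

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile

model_dir = snapshot_download('damo/cv_swin-b_image-instance-segmentation_coco')
cfg = Config.from_file(os.path.join(model_dir, ModelFile.CONFIGURATION))

preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
preprocessor.eval()  # use the 'val' transforms; preprocessor.train() switches back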
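End-to-end inference with the new pipeline, followed by visualization through the `show_result` helper from postprocess_utils; the output file name is arbitrary.

from modelscope.models.cv.image_instance_segmentation.postprocess_utils import \
    show_result
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

img_path = 'data/test/images/image_instance_segmentation.jpg'
segmentor = pipeline(
    Tasks.image_segmentation,
    model='damo/cv_swin-b_image-instance-segmentation_coco')
result = segmentor(img_path)
print(result[OutputKeys.LABELS])

# Overlay boxes, class labels, scores and masks on the input image.
show_result(img_path, result, out_file='instance_seg_vis.jpg', alpha=0.5)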
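The pipeline's postprocess step returns parallel lists under `OutputKeys.BOXES`/`MASKS`/`LABELS`/`SCORES`; below is a small helper (hypothetical, not part of this change) for keeping only high-confidence instances.

from modelscope.outputs import OutputKeys


def filter_instances(result, score_thr=0.5):
    """Keep only instances whose score is at least `score_thr`.

    `result` is the dict produced by get_img_ins_seg_result, i.e. parallel
    lists under the BOXES, MASKS, LABELS and SCORES keys.
    """
    keep = [i for i, s in enumerate(result[OutputKeys.SCORES]) if s >= score_thr]
    keys = (OutputKeys.BOXES, OutputKeys.MASKS,
            OutputKeys.LABELS, OutputKeys.SCORES)
    return {k: [result[k][i] for i in keep] for k in keys}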
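`get_img_seg_results` can decode pycocotools run-length-encoded masks when `is_decode=True`; a self-contained round trip on toy data shows the RLE representation it expects.

import numpy as np
import pycocotools.mask as maskUtils

# pycocotools expects a Fortran-ordered uint8 mask.
mask = np.zeros((4, 6), dtype=np.uint8, order='F')
mask[1:3, 2:5] = 1

rle = maskUtils.encode(mask)      # {'size': [4, 6], 'counts': ...}
decoded = maskUtils.decode(rle)   # back to a (4, 6) uint8 array
assert (decoded == mask).all()
print(maskUtils.area(rle))        # 6 foreground pixels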
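Fine-tuning with the new `ImageInstanceSegmentationTrainer`, condensed from the trainer test above; the data paths, class names, batch size and work_dir below are placeholders for any COCO-style dataset layout.

from functools import partial

from mmcv.parallel import collate

from modelscope.models.cv.image_instance_segmentation.datasets import \
    ImageInstanceSegmentationCocoDataset
from modelscope.trainers import build_trainer

data_root = './toydata/'    # placeholder dataset root
classes = ('Cat', 'Dog')    # placeholder class names

train_dataset = ImageInstanceSegmentationCocoDataset(
    data_root + 'annotations/instances_train.json',
    classes=classes,
    data_root=data_root,
    img_prefix=data_root + 'images/train/',
    seg_prefix=None,
    test_mode=False)
eval_dataset = ImageInstanceSegmentationCocoDataset(
    data_root + 'annotations/instances_val.json',
    classes=classes,
    data_root=data_root,
    img_prefix=data_root + 'images/val/',
    seg_prefix=None,
    test_mode=True)

trainer = build_trainer(
    name='image-instance-segmentation',
    default_args=dict(
        model='damo/cv_swin-b_image-instance-segmentation_coco',
        data_collator=partial(collate, samples_per_gpu=2),
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        work_dir='./work_dir'))
trainer.train()
print(trainer.evaluate())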