diff --git a/data/test/images/image_semantic_segmentation.jpg b/data/test/images/image_semantic_segmentation.jpg new file mode 100644 index 00000000..2a8d826b --- /dev/null +++ b/data/test/images/image_semantic_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59b1da30af12f76b691990363e0d221050a59cf53fc4a97e776bcb00228c6c2a +size 245864 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 1fba50b3..8e21c00b 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -23,6 +23,8 @@ class Models(object): panoptic_segmentation = 'swinL-panoptic-segmentation' image_reid_person = 'passvitb' video_summarization = 'pgl-video-summarization' + swinL_semantic_segmentation = 'swinL-semantic-segmentation' + vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' # nlp models bert = 'bert' @@ -117,6 +119,7 @@ class Pipelines(object): video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' image_panoptic_segmentation = 'image-panoptic-segmentation' video_summarization = 'googlenet_pgl_video_summarization' + image_semantic_segmentation = 'image-semantic-segmentation' image_reid_person = 'passvitb-image-reid-person' # nlp tasks diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 3af7a1b6..227be2c7 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -4,8 +4,8 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, face_generation, image_classification, image_color_enhance, image_colorization, image_denoise, image_instance_segmentation, image_panoptic_segmentation, image_portrait_enhancement, - image_reid_person, image_to_image_generation, - image_to_image_translation, object_detection, - product_retrieval_embedding, salient_detection, - super_resolution, video_single_object_tracking, - video_summarization, virual_tryon) + image_reid_person, image_semantic_segmentation, + image_to_image_generation, image_to_image_translation, + object_detection, product_retrieval_embedding, + salient_detection, super_resolution, + video_single_object_tracking, video_summarization, virual_tryon) diff --git a/modelscope/models/cv/image_semantic_segmentation/__init__.py b/modelscope/models/cv/image_semantic_segmentation/__init__.py new file mode 100644 index 00000000..598d7c21 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
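+# NOTE: SemanticSegmentation is exposed through LazyImportModule below so that +# the heavy mmcv/mmdet imports in semantic_seg_model are deferred until first use.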
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .semantic_seg_model import SemanticSegmentation + +else: + _import_structure = { + 'semantic_seg_model': ['SemanticSegmentation'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_semantic_segmentation/pan_merge/__init__.py b/modelscope/models/cv/image_semantic_segmentation/pan_merge/__init__.py new file mode 100644 index 00000000..2a75f318 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/pan_merge/__init__.py @@ -0,0 +1 @@ +from .maskformer_semantic_head import MaskFormerSemanticHead diff --git a/modelscope/models/cv/image_semantic_segmentation/pan_merge/base_panoptic_fusion_head.py b/modelscope/models/cv/image_semantic_segmentation/pan_merge/base_panoptic_fusion_head.py new file mode 100644 index 00000000..05e68d89 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/pan_merge/base_panoptic_fusion_head.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +from mmcv.runner import BaseModule +from mmdet.models.builder import build_loss + + +class BasePanopticFusionHead(BaseModule, metaclass=ABCMeta): + """Base class for panoptic heads.""" + + def __init__(self, + num_things_classes=80, + num_stuff_classes=53, + test_cfg=None, + loss_panoptic=None, + init_cfg=None, + **kwargs): + super(BasePanopticFusionHead, self).__init__(init_cfg) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = num_things_classes + num_stuff_classes + self.test_cfg = test_cfg + + if loss_panoptic: + self.loss_panoptic = build_loss(loss_panoptic) + else: + self.loss_panoptic = None + + @property + def with_loss(self): + """bool: whether the panoptic head contains loss function.""" + return self.loss_panoptic is not None + + @abstractmethod + def forward_train(self, gt_masks=None, gt_semantic_seg=None, **kwargs): + """Forward function during training.""" + + @abstractmethod + def simple_test(self, + img_metas, + det_labels, + mask_preds, + seg_preds, + det_bboxes, + cfg=None, + **kwargs): + """Test without augmentation.""" diff --git a/modelscope/models/cv/image_semantic_segmentation/pan_merge/maskformer_semantic_head.py b/modelscope/models/cv/image_semantic_segmentation/pan_merge/maskformer_semantic_head.py new file mode 100644 index 00000000..6769ebaf --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/pan_merge/maskformer_semantic_head.py @@ -0,0 +1,57 @@ +import torch +import torch.nn.functional as F +from mmdet.models.builder import HEADS + +from .base_panoptic_fusion_head import BasePanopticFusionHead + + +@HEADS.register_module() +class MaskFormerSemanticHead(BasePanopticFusionHead): + + def __init__(self, + num_things_classes=80, + num_stuff_classes=53, + test_cfg=None, + loss_panoptic=None, + init_cfg=None, + **kwargs): + super().__init__(num_things_classes, num_stuff_classes, test_cfg, + loss_panoptic, init_cfg, **kwargs) + + def forward_train(self, **kwargs): + """MaskFormerFusionHead has no training loss.""" + return dict() + + def simple_test(self, + mask_cls_results, + mask_pred_results, + img_metas, + rescale=False, + **kwargs): + results = [] + for mask_cls_result, mask_pred_result, meta in zip( + mask_cls_results, mask_pred_results, img_metas): + # 
remove padding + img_height, img_width = meta['img_shape'][:2] + mask_pred_result = mask_pred_result[:, :img_height, :img_width] + + if rescale: + # return result in original resolution + ori_height, ori_width = meta['ori_shape'][:2] + mask_pred_result = F.interpolate( + mask_pred_result[:, None], + size=(ori_height, ori_width), + mode='bilinear', + align_corners=False)[:, 0] + + # semantic inference + cls_score = F.softmax(mask_cls_result, dim=-1)[..., :-1] + mask_pred = mask_pred_result.sigmoid() + seg_mask = torch.einsum('qc,qhw->chw', cls_score, mask_pred) + # still need softmax and argmax + seg_logit = F.softmax(seg_mask, dim=0) + seg_pred = seg_logit.argmax(dim=0) + seg_pred = seg_pred.cpu().numpy() + results.append(seg_pred) + + return results diff --git a/modelscope/models/cv/image_semantic_segmentation/semantic_seg_model.py b/modelscope/models/cv/image_semantic_segmentation/semantic_seg_model.py new file mode 100644 index 00000000..60acf28f --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/semantic_seg_model.py @@ -0,0 +1,76 @@ +import os.path as osp + +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.image_semantic_segmentation import (pan_merge, + vit_adapter) +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks + + +@MODELS.register_module( + Tasks.image_segmentation, module_name=Models.swinL_semantic_segmentation) +@MODELS.register_module( + Tasks.image_segmentation, + module_name=Models.vitadapter_semantic_segmentation) +class SemanticSegmentation(TorchModel): + + def __init__(self, model_dir: str, **kwargs): + """str -- model file root.""" + super().__init__(model_dir, **kwargs) + + from mmcv.runner import load_checkpoint + import mmcv + from mmdet.models import build_detector + + config = osp.join(model_dir, 'mmcv_config.py') + cfg = mmcv.Config.fromfile(config) + if 'pretrained' in cfg.model: + cfg.model.pretrained = None + elif 'init_cfg' in cfg.model.backbone: + cfg.model.backbone.init_cfg = None + + # build model + cfg.model.train_cfg = None + self.model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) + + # load model + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + _ = load_checkpoint(self.model, model_path, map_location='cpu') + + self.CLASSES = cfg['CLASSES'] # list + self.PALETTE = cfg['PALETTE'] # list + + self.num_classes = len(self.CLASSES) + self.cfg = cfg + + def forward(self, Inputs): + return self.model(**Inputs) + + def postprocess(self, Inputs): + semantic_result = Inputs[0] + + ids = np.unique(semantic_result)[::-1] + legal_indices = ids != self.model.num_classes # for VOID label + ids = ids[legal_indices] + + segms = (semantic_result[None] == ids[:, None, None]) + masks = [it.astype(np.int) for it in segms] + labels_txt = np.array(self.CLASSES)[ids].tolist() + + results = { + OutputKeys.MASKS: masks, + OutputKeys.LABELS: labels_txt, + OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))] + } + return results + + def inference(self, data): + with torch.no_grad(): + results = self.model(return_loss=False, rescale=True, **data) + + return results diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/__init__.py new file mode 100644 index 00000000..82eec1c6 --- /dev/null +++ 
b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/__init__.py @@ -0,0 +1,3 @@ +from .models import backbone, decode_heads, segmentors +from .utils import (ResizeToMultiple, add_prefix, build_pixel_sampler, + seg_resize) diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/__init__.py new file mode 100644 index 00000000..ae5c5acf --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/__init__.py @@ -0,0 +1,3 @@ +from .backbone import BASEBEiT, BEiTAdapter +from .decode_heads import Mask2FormerHeadFromMMSeg +from .segmentors import EncoderDecoderMask2Former diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/__init__.py new file mode 100644 index 00000000..ab4258c1 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/__init__.py @@ -0,0 +1,4 @@ +from .base import BASEBEiT +from .beit_adapter import BEiTAdapter + +__all__ = ['BEiTAdapter', 'BASEBEiT'] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/adapter_modules.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/adapter_modules.py new file mode 100644 index 00000000..03080342 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/adapter_modules.py @@ -0,0 +1,523 @@ +# The implementation refers to the VitAdapter +# available at +# https://github.com/czczup/ViT-Adapter.git + +import logging +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmdet.models.utils.transformer import MultiScaleDeformableAttention +from timm.models.layers import DropPath + +_logger = logging.getLogger(__name__) + + +def get_reference_points(spatial_shapes, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace( + 0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace( + 0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] + return reference_points + + +def deform_inputs(x): + bs, c, h, w = x.shape + spatial_shapes = torch.as_tensor([(h // 8, w // 8), (h // 16, w // 16), + (h // 32, w // 32)], + dtype=torch.long, + device=x.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // 16, w // 16)], x.device) + deform_inputs1 = [reference_points, spatial_shapes, level_start_index] + + spatial_shapes = torch.as_tensor([(h // 16, w // 16)], + dtype=torch.long, + device=x.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // 8, w // 8), + (h // 16, w // 16), + (h // 32, w // 32)], x.device) + deform_inputs2 = [reference_points, spatial_shapes, level_start_index] + + return deform_inputs1, deform_inputs2 + + +class ConvFFN(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + 
out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + n = N // 21 + x1 = x[:, 0:16 * n, :].transpose(1, 2).view(B, C, H * 2, + W * 2).contiguous() + x2 = x[:, 16 * n:20 * n, :].transpose(1, 2).view(B, C, H, + W).contiguous() + x3 = x[:, 20 * n:, :].transpose(1, 2).view(B, C, H // 2, + W // 2).contiguous() + x1 = self.dwconv(x1).flatten(2).transpose(1, 2) + x2 = self.dwconv(x2).flatten(2).transpose(1, 2) + x3 = self.dwconv(x3).flatten(2).transpose(1, 2) + x = torch.cat([x1, x2, x3], dim=1) + return x + + +class Extractor(nn.Module): + + def __init__(self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + with_cffn=True, + cffn_ratio=0.25, + drop=0., + drop_path=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), + with_cp=False): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MultiScaleDeformableAttention( + embed_dims=dim, + num_heads=num_heads, + num_levels=n_levels, + num_points=n_points, + batch_first=True) + + # modify to fit the deform_ratio + value_proj_in_features = self.attn.value_proj.weight.shape[0] + value_proj_out_features = int(value_proj_in_features * deform_ratio) + self.attn.value_proj = nn.Linear(value_proj_in_features, + value_proj_out_features) + self.attn.output_proj = nn.Linear(value_proj_out_features, + value_proj_in_features) + + self.with_cffn = with_cffn + self.with_cp = with_cp + if with_cffn: + self.ffn = ConvFFN( + in_features=dim, + hidden_features=int(dim * cffn_ratio), + drop=drop) + self.ffn_norm = norm_layer(dim) + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, query, reference_points, feat, spatial_shapes, + level_start_index, H, W): + + def _inner_forward(query, feat): + attn = self.attn( + query=self.query_norm(query), + key=None, + value=self.feat_norm(feat), + identity=None, + query_pos=None, + key_padding_mask=None, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index) + + query = query + attn + + if self.with_cffn: + query = query + self.drop_path( + self.ffn(self.ffn_norm(query), H, W)) + return query + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class Injector(nn.Module): + + def __init__(self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + init_values=0., + with_cp=False): + super().__init__() + self.with_cp = with_cp + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MultiScaleDeformableAttention( + embed_dims=dim, + num_heads=num_heads, + num_levels=n_levels, + num_points=n_points, + batch_first=True) + + # modify to fit the deform_ratio + value_proj_in_features = self.attn.value_proj.weight.shape[0] + value_proj_out_features = int(value_proj_in_features * deform_ratio) + self.attn.value_proj = nn.Linear(value_proj_in_features, + value_proj_out_features) + self.attn.output_proj = nn.Linear(value_proj_out_features, + value_proj_in_features) + + self.gamma = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True) + + def forward(self, query, reference_points, feat, spatial_shapes, + level_start_index): + + def _inner_forward(query, feat): + input_query = self.query_norm(query) + input_value = self.feat_norm(feat) + attn = self.attn( + query=input_query, + key=None, + value=input_value, + identity=None, + query_pos=None, + key_padding_mask=None, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index) + return query + self.gamma * attn + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class InteractionBlock(nn.Module): + + def __init__(self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0., + drop_path=0., + with_cffn=True, + cffn_ratio=0.25, + init_values=0., + deform_ratio=1.0, + extra_extractor=False, + with_cp=False): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp) + if extra_extractor: + self.extra_extractors = nn.Sequential(*[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp) for _ in range(2) + ]) + else: + self.extra_extractors = None + + def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H, W): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + 
spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2]) + for idx, blk in enumerate(blocks): + x = blk(x, H, W) + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + return x, c + + +class InteractionBlockWithCls(nn.Module): + + def __init__(self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0., + drop_path=0., + with_cffn=True, + cffn_ratio=0.25, + init_values=0., + deform_ratio=1.0, + extra_extractor=False, + with_cp=False): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp) + if extra_extractor: + self.extra_extractors = nn.Sequential(*[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp) for _ in range(2) + ]) + else: + self.extra_extractors = None + + def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H, W): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2]) + x = torch.cat((cls, x), dim=1) + for idx, blk in enumerate(blocks): + x = blk(x, H, W) + cls, x = x[:, :1, ], x[:, 1:, ] + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + return x, c, cls + + +class SpatialPriorModule(nn.Module): + + def __init__(self, inplanes=64, embed_dim=384, with_cp=False): + super().__init__() + self.with_cp = with_cp + + self.stem = nn.Sequential(*[ + nn.Conv2d( + 3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d( + inplanes, + inplanes, + kernel_size=3, + stride=1, + padding=1, + bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d( + inplanes, + inplanes, + kernel_size=3, + stride=1, + padding=1, + bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + ]) + self.conv2 = nn.Sequential(*[ + nn.Conv2d( + inplanes, + 2 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias=False), + nn.SyncBatchNorm(2 * inplanes), + nn.ReLU(inplace=True) + ]) + self.conv3 = nn.Sequential(*[ + nn.Conv2d( + 2 * inplanes, + 4 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True) + ]) + 
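# conv4 applies one more stride-2 stage, completing the spatial prior pyramid + # (stem -> 1/4, conv2 -> 1/8, conv3 -> 1/16, conv4 -> 1/32). +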
self.conv4 = nn.Sequential(*[ + nn.Conv2d( + 4 * inplanes, + 4 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True) + ]) + self.fc1 = nn.Conv2d( + inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc2 = nn.Conv2d( + 2 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True) + self.fc3 = nn.Conv2d( + 4 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True) + self.fc4 = nn.Conv2d( + 4 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True) + + def forward(self, x): + + def _inner_forward(x): + c1 = self.stem(x) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + c1 = self.fc1(c1) + c2 = self.fc2(c2) + c3 = self.fc3(c3) + c4 = self.fc4(c4) + + bs, dim, _, _ = c1.shape + + c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s + c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s + c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s + + return c1, c2, c3, c4 + + if self.with_cp and x.requires_grad: + outs = cp.checkpoint(_inner_forward, x) + else: + outs = _inner_forward(x) + return outs diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/__init__.py new file mode 100644 index 00000000..40b0fa89 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/__init__.py @@ -0,0 +1,3 @@ +from .beit import BASEBEiT + +__all__ = ['BASEBEiT'] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/beit.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/beit.py new file mode 100644 index 00000000..a5811fb9 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/beit.py @@ -0,0 +1,476 @@ +# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) +# Github source: https://github.com/microsoft/unilm/tree/master/beit +# This implementation refers to +# https://github.com/czczup/ViT-Adapter.git +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.runner import _load_checkpoint +from mmdet.models.builder import BACKBONES +from mmdet.utils import get_root_logger +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # commit dropout for the original BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): 
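+ # BEiT-style multi-head self-attention with an optional learned relative + # position bias over the patch window (including cls-token entries).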
+ + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + window_size=None, + attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] + - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, + coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum( + -1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + self.register_buffer('relative_position_index', + relative_position_index) + + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + (self.q_bias, + torch.zeros_like(self.v_bias, + requires_grad=False), self.v_bias)) + + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) 
+ x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + init_values=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + window_size=None, + attn_head_dim=None, + with_cp=False): + super().__init__() + self.with_cp = with_cp + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + if init_values is not None: + self.gamma_1 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, H, W, rel_pos_bias=None): + + def _inner_forward(x): + if self.gamma_1 is None: + x = x + self.drop_path( + self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn( + self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * ( + img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, Hp, Wp + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, + backbone, + img_size=224, + feature_size=None, + in_chans=3, + embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. 
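+ # Run a dummy forward in eval mode to infer the CNN feature map size and + # channel dim, then restore the backbone's original train/eval mode.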
+ training = backbone.training + if training: + backbone.eval() + o = self.backbone( + torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] + - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, + 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer('relative_position_index', + relative_position_index) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +@BACKBONES.register_module() +class BASEBEiT(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, + img_size=512, + patch_size=16, + in_chans=3, + num_classes=80, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + hybrid_backbone=None, + norm_layer=None, + init_values=None, + use_checkpoint=False, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + pretrained=None, + with_cp=False): + super().__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.norm_layer = norm_layer + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.drop_path_rate = drop_path_rate + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, + img_size=img_size, + in_chans=in_chans, + embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + 
in_chans=in_chans, + embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias( + window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.use_checkpoint = use_checkpoint + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + with_cp=with_cp, + init_values=init_values, + window_size=self.patch_embed.patch_shape + if use_rel_pos_bias else None) for i in range(depth) + ]) + + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + self.init_weights(pretrained) + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + checkpoint = _load_checkpoint( + init_cfg['checkpoint'], logger=logger, map_location='cpu') + state_dict = self.resize_rel_pos_embed(checkpoint) + self.load_state_dict(state_dict, False) + + def fix_init_weight(self): + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/beit_adapter.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/beit_adapter.py new file mode 100644 index 00000000..02a4968e --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/beit_adapter.py @@ -0,0 +1,169 @@ +# The implementation refers to the VitAdapter +# available at +# https://github.com/czczup/ViT-Adapter.git +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmdet.models.builder import BACKBONES +from mmdet.models.utils.transformer import MultiScaleDeformableAttention +from timm.models.layers import DropPath, trunc_normal_ +from torch.nn.init import normal_ + +from .adapter_modules import InteractionBlockWithCls as InteractionBlock +from .adapter_modules import SpatialPriorModule, deform_inputs +from .base.beit import BASEBEiT + +_logger = logging.getLogger(__name__) + + +@BACKBONES.register_module() +class BEiTAdapter(BASEBEiT): + + def __init__(self, + pretrain_size=224, + conv_inplane=64, + n_points=4, + deform_num_heads=6, + init_values=0., + cffn_ratio=0.25, + deform_ratio=1.0, + with_cffn=True, + interaction_indexes=None, 
+ add_vit_feature=True, + with_cp=False, + *args, + **kwargs): + + super().__init__( + init_values=init_values, with_cp=with_cp, *args, **kwargs) + + self.num_block = len(self.blocks) + self.pretrain_size = (pretrain_size, pretrain_size) + self.flags = [ + i for i in range(-1, self.num_block, self.num_block // 4) + ][1:] + self.interaction_indexes = interaction_indexes + self.add_vit_feature = add_vit_feature + embed_dim = self.embed_dim + + self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) + self.spm = SpatialPriorModule( + inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False) + self.interactions = nn.Sequential(*[ + InteractionBlock( + dim=embed_dim, + num_heads=deform_num_heads, + n_points=n_points, + init_values=init_values, + drop_path=self.drop_path_rate, + norm_layer=self.norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + extra_extractor=True if i == len(interaction_indexes) + - 1 else False, + with_cp=with_cp) for i in range(len(interaction_indexes)) + ]) + + self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) + self.norm1 = nn.SyncBatchNorm(embed_dim) + self.norm2 = nn.SyncBatchNorm(embed_dim) + self.norm3 = nn.SyncBatchNorm(embed_dim) + self.norm4 = nn.SyncBatchNorm(embed_dim) + + self.up.apply(self._init_weights) + self.spm.apply(self._init_weights) + self.interactions.apply(self._init_weights) + self.apply(self._init_deform_weights) + normal_(self.level_embed) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def _get_pos_embed(self, pos_embed, H, W): + pos_embed = pos_embed.reshape(1, self.pretrain_size[0] // 16, + self.pretrain_size[1] // 16, + -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). 
\ + reshape(1, -1, H * W).permute(0, 2, 1) + return pos_embed + + def _init_deform_weights(self, m): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + + def _add_level_embed(self, c2, c3, c4): + c2 = c2 + self.level_embed[0] + c3 = c3 + self.level_embed[1] + c4 = c4 + self.level_embed[2] + return c2, c3, c4 + + def forward(self, x): + deform_inputs1, deform_inputs2 = deform_inputs(x) + + # SPM forward + c1, c2, c3, c4 = self.spm(x) + c2, c3, c4 = self._add_level_embed(c2, c3, c4) + c = torch.cat([c2, c3, c4], dim=1) + + # Patch Embedding forward + x, H, W = self.patch_embed(x) + bs, n, dim = x.shape + cls = self.cls_token.expand( + bs, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + + if self.pos_embed is not None: + pos_embed = self._get_pos_embed(self.pos_embed, H, W) + x = x + pos_embed + x = self.pos_drop(x) + + # Interaction + outs = list() + for i, layer in enumerate(self.interactions): + indexes = self.interaction_indexes[i] + x, c, cls = layer(x, c, cls, + self.blocks[indexes[0]:indexes[-1] + 1], + deform_inputs1, deform_inputs2, H, W) + outs.append(x.transpose(1, 2).view(bs, dim, H, W).contiguous()) + + # Split & Reshape + c2 = c[:, 0:c2.size(1), :] + c3 = c[:, c2.size(1):c2.size(1) + c3.size(1), :] + c4 = c[:, c2.size(1) + c3.size(1):, :] + + c2 = c2.transpose(1, 2).view(bs, dim, H * 2, W * 2).contiguous() + c3 = c3.transpose(1, 2).view(bs, dim, H, W).contiguous() + c4 = c4.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous() + c1 = self.up(c2) + c1 + + if self.add_vit_feature: + x1, x2, x3, x4 = outs + x1 = F.interpolate( + x1, scale_factor=4, mode='bilinear', align_corners=False) + x2 = F.interpolate( + x2, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.interpolate( + x4, scale_factor=0.5, mode='bilinear', align_corners=False) + c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 + + # Final Norm + f1 = self.norm1(c1) + f2 = self.norm2(c2) + f3 = self.norm3(c3) + f4 = self.norm4(c4) + return [f1, f2, f3, f4] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/__init__.py new file mode 100644 index 00000000..9367806f --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/__init__.py @@ -0,0 +1,3 @@ +from .mask2former_head_from_mmseg import Mask2FormerHeadFromMMSeg + +__all__ = ['Mask2FormerHeadFromMMSeg'] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/base_decode_head.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/base_decode_head.py new file mode 100644 index 00000000..36660520 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/base_decode_head.py @@ -0,0 +1,267 @@ +# The implementation refers to the VitAdapter +# available at +# https://github.com/czczup/ViT-Adapter.git +from abc import ABCMeta, abstractmethod + +import torch +import torch.nn as nn +from mmcv.runner import BaseModule, auto_fp16, force_fp32 +from mmdet.models.builder import build_loss +from mmdet.models.losses import accuracy + +from ...utils import build_pixel_sampler, seg_resize + + +class BaseDecodeHead(BaseModule, metaclass=ABCMeta): + """Base class for BaseDecodeHead. + + Args: + in_channels (int|Sequence[int]): Input channels. + channels (int): Channels after modules, before conv_seg. + num_classes (int): Number of classes. 
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1. + conv_cfg (dict|None): Config of conv layers. Default: None. + norm_cfg (dict|None): Config of norm layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + in_index (int|Sequence[int]): Input feature index. Default: -1 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + Default: None. + loss_decode (dict | Sequence[dict]): Config of decode loss. + The `loss_name` is property of corresponding loss function which + could be shown in training log. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + e.g. dict(type='CrossEntropyLoss'), + [dict(type='CrossEntropyLoss', loss_name='loss_ce'), + dict(type='DiceLoss', loss_name='loss_dice')] + Default: dict(type='CrossEntropyLoss'). + ignore_index (int | None): The label index to be ignored. When using + masked BCE loss, ignore_index should be set to None. Default: 255. + sampler (dict|None): The config of segmentation map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + in_channels, + channels, + *, + num_classes, + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + in_index=-1, + input_transform=None, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + ignore_index=255, + sampler=None, + align_corners=False, + init_cfg=dict( + type='Normal', std=0.01, override=dict(name='conv_seg'))): + super(BaseDecodeHead, self).__init__(init_cfg) + self._init_inputs(in_channels, in_index, input_transform) + self.channels = channels + self.num_classes = num_classes + self.dropout_ratio = dropout_ratio + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.in_index = in_index + + self.ignore_index = ignore_index + self.align_corners = align_corners + + if isinstance(loss_decode, dict): + self.loss_decode = build_loss(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(build_loss(loss)) + else: + raise TypeError(f'loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}') + + if sampler is not None: + self.sampler = build_pixel_sampler(sampler, context=self) + else: + self.sampler = None + + self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + self.fp16_enabled = False + + def extra_repr(self): + """Extra repr.""" + s = f'input_transform={self.input_transform}, ' \ + f'ignore_index={self.ignore_index}, ' \ + f'align_corners={self.align_corners}' + return s + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. 
So in_channels and in_index must be of type int. + When input_transform + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + seg_resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + @auto_fp16() + @abstractmethod + def forward(self, inputs): + """Placeholder of forward function.""" + pass + + def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs) + losses = self.losses(seg_logits, gt_semantic_seg) + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. 
+ """ + return self.forward(inputs) + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output + + @force_fp32(apply_to=('seg_logit', )) + def losses(self, seg_logit, seg_label): + """Compute segmentation loss.""" + loss = dict() + seg_logit = seg_resize( + input=seg_logit, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logit, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + else: + loss[loss_decode.loss_name] += loss_decode( + seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + + loss['acc_seg'] = accuracy( + seg_logit, seg_label, ignore_index=self.ignore_index) + return loss diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/mask2former_head_from_mmseg.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/mask2former_head_from_mmseg.py new file mode 100644 index 00000000..ad8b1586 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/mask2former_head_from_mmseg.py @@ -0,0 +1,581 @@ +# The implementation refers to the VitAdapter +# available at +# https://github.com/czczup/ViT-Adapter.git + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init +from mmcv.cnn.bricks.transformer import (build_positional_encoding, + build_transformer_layer_sequence) +from mmcv.ops import point_sample +from mmcv.runner import ModuleList, force_fp32 +from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean +from mmdet.models.builder import HEADS, build_loss +from mmdet.models.utils import get_uncertain_point_coords_with_randomness + +from .base_decode_head import BaseDecodeHead + + +@HEADS.register_module() +class Mask2FormerHeadFromMMSeg(BaseDecodeHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_things_classes (int): Number of things. + num_stuff_classes (int): Number of stuff. + num_queries (int): Number of query in Transformer decoder. + pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel + decoder. Defaults to None. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of tranformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder. Defaults to None. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder position encoding. Defaults to None. + loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification + loss. Defaults to None. 
+ loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss. + Defaults to None. + loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss. + Defaults to None. + train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of + Mask2Former head. + test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of + Mask2Former head. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels, + feat_channels, + out_channels, + num_things_classes=80, + num_stuff_classes=53, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=None, + enforce_decoder_input_project=False, + transformer_decoder=None, + positional_encoding=None, + loss_cls=None, + loss_mask=None, + loss_dice=None, + train_cfg=None, + test_cfg=None, + init_cfg=None, + **kwargs): + super(Mask2FormerHeadFromMMSeg, self).__init__( + in_channels=in_channels, + channels=feat_channels, + num_classes=(num_things_classes + num_stuff_classes), + init_cfg=init_cfg, + input_transform='multiple_select', + **kwargs) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.transformerlayers. \ + attn_cfgs.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update( + in_channels=in_channels, + feat_channels=feat_channels, + out_channels=out_channels) + self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1] + self.transformer_decoder = build_transformer_layer_sequence( + transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if (self.decoder_embed_dims != feat_channels + or enforce_decoder_input_project): + self.decoder_input_projs.append( + Conv2d( + feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = build_positional_encoding( + positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, + feat_channels) + + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels)) + self.conv_seg = None # fix a bug here (conv_seg is not used) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + self.sampler = build_sampler(self.train_cfg.sampler, context=self) + self.num_points = self.train_cfg.get('num_points', 12544) + self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = build_loss(loss_cls) + self.loss_mask = 
build_loss(loss_mask)
+        self.loss_dice = build_loss(loss_dice)
+
+    def init_weights(self):
+        for m in self.decoder_input_projs:
+            if isinstance(m, Conv2d):
+                caffe2_xavier_init(m, bias=0)
+
+        self.pixel_decoder.init_weights()
+
+        for p in self.transformer_decoder.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_normal_(p)
+
+    def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list,
+                    gt_masks_list, img_metas):
+        """Compute classification and mask targets for all images for a decoder
+        layer.
+
+        Args:
+            cls_scores_list (list[Tensor]): Mask score logits from a single
+                decoder layer for all images. Each with shape [num_queries,
+                cls_out_channels].
+            mask_preds_list (list[Tensor]): Mask logits from a single decoder
+                layer for all images. Each with shape [num_queries, h, w].
+            gt_labels_list (list[Tensor]): Ground truth class indices for all
+                images. Each with shape (n, ), where n is the sum of the
+                number of stuff types and the number of instances in an image.
+            gt_masks_list (list[Tensor]): Ground truth mask for each image,
+                each with shape (n, h, w).
+            img_metas (list[dict]): List of image meta information.
+
+        Returns:
+            tuple[list[Tensor]]: a tuple containing the following targets.
+
+                - labels_list (list[Tensor]): Labels of all images.
+                    Each with shape [num_queries, ].
+                - label_weights_list (list[Tensor]): Label weights of all
+                    images. Each with shape [num_queries, ].
+                - mask_targets_list (list[Tensor]): Mask targets of all images.
+                    Each with shape [num_queries, h, w].
+                - mask_weights_list (list[Tensor]): Mask weights of all images.
+                    Each with shape [num_queries, ].
+                - num_total_pos (int): Number of positive samples in all
+                    images.
+                - num_total_neg (int): Number of negative samples in all
+                    images.
+        """
+        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
+         pos_inds_list,
+         neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list,
+                                      mask_preds_list, gt_labels_list,
+                                      gt_masks_list, img_metas)
+
+        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+        return (labels_list, label_weights_list, mask_targets_list,
+                mask_weights_list, num_total_pos, num_total_neg)
+
+    def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks,
+                           img_metas):
+        """Compute classification and mask targets for one image.
+
+        Args:
+            cls_score (Tensor): Mask score logits from a single decoder layer
+                for one image. Shape (num_queries, cls_out_channels).
+            mask_pred (Tensor): Mask logits for a single decoder layer for one
+                image. Shape (num_queries, h, w).
+            gt_labels (Tensor): Ground truth class indices for one image with
+                shape (num_gts, ).
+            gt_masks (Tensor): Ground truth mask for each image, each with
+                shape (num_gts, h, w).
+            img_metas (dict): Image information.
+
+        Returns:
+            tuple[Tensor]: A tuple containing the following for one image.
+
+                - labels (Tensor): Labels of each image. \
+                    shape (num_queries, ).
+                - label_weights (Tensor): Label weights of each image. \
+                    shape (num_queries, ).
+                - mask_targets (Tensor): Mask targets of each image. \
+                    shape (num_queries, h, w).
+                - mask_weights (Tensor): Mask weights of each image. \
+                    shape (num_queries, ).
+                - pos_inds (Tensor): Sampled positive indices for each \
+                    image.
+                - neg_inds (Tensor): Sampled negative indices for each \
+                    image.
+ """ + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), + device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample( + mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, + 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample( + gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, + 1)).squeeze(1) + + # assign and sample + assign_result = self.assigner.assign(cls_score, mask_points_pred, + gt_labels, gt_points_masks, + img_metas) + sampling_result = self.sampler.sample(assign_result, mask_pred, + gt_masks) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((self.num_queries, )) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries, )) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, + neg_inds) + + def loss_single(self, cls_scores, mask_preds, gt_labels_list, + gt_masks_list, img_metas): + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + gt_labels_list (list[Tensor]): Ground truth class indices for each + image, each with shape (num_gts, ). + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (num_gts, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. 
+ """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + num_total_pos, + num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list, + gt_labels_list, gt_masks_list, + img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_tensor(self.class_weight) + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, self.num_points, + self.oversample_ratio, self.importance_sample_ratio) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample( + mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample( + mask_preds.unsqueeze(1), points_coords).squeeze(1) + + # dice loss + loss_dice = self.loss_dice( + mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_queries, num_points) -> (num_queries * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1, 1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks * self.num_points) + + return loss_cls, loss_mask, loss_dice + + @force_fp32(apply_to=('all_cls_scores', 'all_mask_preds')) + def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, + gt_masks_list, img_metas): + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape [num_decoder, batch_size, num_queries, + cls_out_channels]. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape [num_decoder, batch_size, num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (n, ). n is the sum of number of stuff type + and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image with + shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + num_dec_layers = len(all_cls_scores) + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self.loss_single, all_cls_scores, all_mask_preds, + all_gt_labels_list, all_gt_masks_list, img_metas_list) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_mask'] = losses_mask[-1] + loss_dict['loss_dice'] = losses_dice[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip( + losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i + loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i + num_dec_layer += 1 + return loss_dict + + def forward_head(self, decoder_out, mask_feature, attn_mask_target_size): + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (num_queries, batch_size, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + decoder_out = decoder_out.transpose(0, 1) + # shape (num_queries, batch_size, c) + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) + attn_mask = F.interpolate( + mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward(self, feats, img_metas): + """Forward function. + + Args: + feats (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + + Returns: + tuple: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). 
+ """ + batch_size = len(img_metas) + mask_features, multi_scale_memorys = self.pixel_decoder(feats) + # multi_scale_memorys (from low resolution to high resolution) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + decoder_input = decoder_input.flatten(2).permute(2, 0, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + mask = decoder_input.new_zeros( + (batch_size, ) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.flatten( + 2).permute(2, 0, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + # shape (num_queries, c) -> (num_queries, batch_size, c) + query_feat = self.query_feat.weight.unsqueeze(1).repeat( + (1, batch_size, 1)) + query_embed = self.query_embed.weight.unsqueeze(1).repeat( + (1, batch_size, 1)) + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where( + attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + attn_masks = [attn_mask, None] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + attn_masks=attn_masks, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None) + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[ + (i + 1) % self.num_transformer_feat_level].shape[-2:]) + + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + return cls_pred_list, mask_pred_list + + def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, + gt_masks): + """Forward function for training mode. + + Args: + x (list[Tensor]): Multi-level features from the upstream network, + each is a 4D-tensor. + img_metas (list[Dict]): List of image information. + gt_semantic_seg (list[tensor]):Each element is the ground truth + of semantic segmentation with the shape (N, H, W). + train_cfg (dict): The training config, which not been used in + maskformer. + gt_labels (list[Tensor]): Each element is ground truth labels of + each box, shape (num_gts,). + gt_masks (list[BitmapMasks]): Each element is masks of instances + of a image, shape (num_gts, h, w). + + Returns: + losses (dict[str, Tensor]): a dictionary of loss components + """ + + # forward + all_cls_scores, all_mask_preds = self(x, img_metas) + + # loss + losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, + img_metas) + + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Test segment without test-time aumengtation. + + Only the output of last decoder layers was used. 
+ + Args: + inputs (list[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + test_cfg (dict): Testing config. + + Returns: + seg_mask (Tensor): Predicted semantic segmentation logits. + """ + all_cls_scores, all_mask_preds = self(inputs, img_metas) + cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1] + ori_h, ori_w, _ = img_metas[0]['ori_shape'] + + # semantic inference + cls_score = F.softmax(cls_score, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + seg_mask = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred) + return seg_mask diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/__init__.py new file mode 100644 index 00000000..1f2c8b04 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/__init__.py @@ -0,0 +1,3 @@ +from .encoder_decoder_mask2former import EncoderDecoderMask2Former + +__all__ = ['EncoderDecoderMask2Former'] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/base_segmentor.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/base_segmentor.py new file mode 100644 index 00000000..8bd8fa3f --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/base_segmentor.py @@ -0,0 +1,314 @@ +# The implementation refers to the VitAdapter +# available at +# https://github.com/czczup/ViT-Adapter.git +import warnings +from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import mmcv +import numpy as np +import torch +import torch.distributed as dist +from mmcv.runner import BaseModule, auto_fp16 + + +class BaseSegmentor(BaseModule, metaclass=ABCMeta): + """Base class for segmentors.""" + + def __init__(self, init_cfg=None): + super(BaseSegmentor, self).__init__(init_cfg) + self.fp16_enabled = False + + @property + def with_neck(self): + """bool: whether the segmentor has neck""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_auxiliary_head(self): + """bool: whether the segmentor has auxiliary head""" + return hasattr(self, + 'auxiliary_head') and self.auxiliary_head is not None + + @property + def with_decode_head(self): + """bool: whether the segmentor has decode head""" + return hasattr(self, 'decode_head') and self.decode_head is not None + + @abstractmethod + def extract_feat(self, imgs): + """Placeholder for extract features from images.""" + pass + + @abstractmethod + def encode_decode(self, img, img_metas): + """Placeholder for encode images with backbone and decode into a + semantic segmentation map of the same size as input.""" + pass + + @abstractmethod + def forward_train(self, imgs, img_metas, **kwargs): + """Placeholder for Forward function for training.""" + pass + + @abstractmethod + def simple_test(self, img, img_meta, **kwargs): + """Placeholder for single image test.""" + pass + + @abstractmethod + def aug_test(self, imgs, img_metas, **kwargs): + """Placeholder for augmentation test.""" + pass + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) 
and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: + if not isinstance(var, list): + raise TypeError(f'{name} must be a list, but got ' + f'{type(var)}') + + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f'num of augmentations ({len(imgs)}) != ' + f'num of image meta ({len(img_metas)})') + + # all images in the same aug batch all of the same ori_shape and pad + # shape + def tensor_to_tuple(input_tensor): + return tuple(input_tensor.cpu().numpy()) + + for img_meta in img_metas: + ori_shapes = [_['ori_shape'] for _ in img_meta] + if isinstance(ori_shapes[0], torch.Tensor): + assert all( + tensor_to_tuple(shape) == tensor_to_tuple(ori_shapes[0]) + for shape in ori_shapes) + else: + assert all(shape == ori_shapes[0] for shape in ori_shapes) + + img_shapes = [_['img_shape'] for _ in img_meta] + if isinstance(img_shapes[0], torch.Tensor): + assert all( + tensor_to_tuple(shape) == tensor_to_tuple(img_shapes[0]) + for shape in img_shapes) + else: + assert all(shape == img_shapes[0] for shape in img_shapes) + + pad_shapes = [_['pad_shape'] for _ in img_meta] + if isinstance(pad_shapes[0], torch.Tensor): + assert all( + tensor_to_tuple(shape) == tensor_to_tuple(pad_shapes[0]) + for shape in pad_shapes) + else: + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + @auto_fp16(apply_to=('img', )) + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data_batch) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, + log_vars=log_vars, + num_samples=len(data_batch['img_metas'])) + + return outputs + + def val_step(self, data_batch, optimizer=None, **kwargs): + """The iteration step during validation. 
+ + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + losses = self(**data_batch) + loss, log_vars = self._parse_losses(losses) + + log_vars_ = dict() + for loss_name, loss_value in log_vars.items(): + k = loss_name + '_val' + log_vars_[k] = loss_value + + outputs = dict( + loss=loss, + log_vars=log_vars_, + num_samples=len(data_batch['img_metas'])) + + return outputs + + @staticmethod + def _parse_losses(losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + # If the loss_vars has different length, raise assertion error + # to prevent GPUs from infinite waiting. + if dist.is_available() and dist.is_initialized(): + log_var_length = torch.tensor(len(log_vars), device=loss.device) + dist.all_reduce(log_var_length) + message = (f'rank {dist.get_rank()}' + + f' len(log_vars): {len(log_vars)}' + ' keys: ' + + ','.join(log_vars.keys()) + '\n') + assert log_var_length == len(log_vars) * dist.get_world_size(), \ + 'loss log variables are different across GPUs!\n' + message + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def show_result(self, + img, + result, + palette=None, + win_name='', + show=False, + wait_time=0, + out_file=None, + opacity=0.5): + """Draw `result` over `img`. + + Args: + img (str or Tensor): The image to be displayed. + result (Tensor): The semantic segmentation results to draw over + `img`. + palette (list[list[int]]] | np.ndarray | None): The palette of + segmentation map. If None is given, random palette will be + generated. Default: None + win_name (str): The window name. + wait_time (int): Value of waitKey param. + Default: 0. + show (bool): Whether to show the image. + Default: False. + out_file (str or None): The filename to write the image. + Default: None. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + Returns: + img (Tensor): Only if not `show` or `out_file` + """ + img = mmcv.imread(img) + img = img.copy() + seg = result[0] + if palette is None: + if self.PALETTE is None: + # Get random state before set seed, + # and restore random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. 
+ # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette + palette = np.random.randint( + 0, 255, size=(len(self.CLASSES), 3)) + np.random.set_state(state) + else: + palette = self.PALETTE + palette = np.array(palette) + assert palette.shape[0] == len(self.CLASSES) + assert palette.shape[1] == 3 + assert len(palette.shape) == 2 + assert 0 < opacity <= 1.0 + color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) + for label, color in enumerate(palette): + color_seg[seg == label, :] = color + # convert to BGR + color_seg = color_seg[..., ::-1] + + img = img * (1 - opacity) + color_seg * opacity + img = img.astype(np.uint8) + # if out_file specified, do not show image in window + if out_file is not None: + show = False + + if show: + mmcv.imshow(img, win_name, wait_time) + if out_file is not None: + mmcv.imwrite(img, out_file) + + if not (show or out_file): + warnings.warn('show==False and out_file is not specified, only ' + 'result image will be returned') + return img diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/encoder_decoder_mask2former.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/encoder_decoder_mask2former.py new file mode 100644 index 00000000..9287e8aa --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/encoder_decoder_mask2former.py @@ -0,0 +1,303 @@ +# The implementation refers to the VitAdapter +# available at +# https://github.com/czczup/ViT-Adapter.git +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmdet.models import builder +from mmdet.models.builder import DETECTORS + +from ...utils import add_prefix, seg_resize +from .base_segmentor import BaseSegmentor + + +@DETECTORS.register_module() +class EncoderDecoderMask2Former(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. 
+ """ + + def __init__(self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(EncoderDecoderMask2Former, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get('pretrained') is None, \ + 'both backbone and segmentor set pretrained weight' + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + if neck is not None: + self.neck = builder.build_neck(neck) + decode_head.update(train_cfg=train_cfg) + decode_head.update(test_cfg=test_cfg) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + + def _init_auxiliary_head(self, auxiliary_head): + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(builder.build_head(head_cfg)) + else: + self.auxiliary_head = builder.build_head(auxiliary_head) + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + out = seg_resize( + input=out, + size=img.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, + **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(x, img_metas, + gt_semantic_seg, **kwargs) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return seg_logits + + def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.forward_train(x, img_metas, + gt_semantic_seg, + self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.forward_train( + x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def forward_dummy(self, img): + """Dummy forward function.""" + seg_logit = self.encode_decode(img, None) + + return seg_logit + + def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 
+ For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, img_metas, + gt_semantic_seg, + **kwargs) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train( + x, img_metas, gt_semantic_seg) + losses.update(loss_aux) + + return losses + + # TODO refactor + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + num_classes = self.num_classes + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + crop_seg_logit = self.encode_decode(crop_img, img_meta) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy( + count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + + def tensor_to_tuple(input_tensor): + return tuple(input_tensor.cpu().numpy()) + + if rescale: + preds = seg_resize( + preds, + size=tensor_to_tuple(img_meta[0]['ori_shape'])[:2] + if isinstance(img_meta[0]['ori_shape'], torch.Tensor) else + img_meta[0]['ori_shape'], + mode='bilinear', + align_corners=self.align_corners, + warning=False) + return preds + + def whole_inference(self, img, img_meta, rescale): + """Inference with full image.""" + + seg_logit = self.encode_decode(img, img_meta) + if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[2:] + else: + size = img_meta[0]['ori_shape'][:2] + seg_logit = seg_resize( + seg_logit, + size=size, + mode='bilinear', + align_corners=self.align_corners, + warning=False) + + return seg_logit + + def inference(self, img, img_meta, rescale): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output segmentation map. 
+ """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = img_meta[0]['ori_shape'] + + def tensor_to_tuple(input_tensor): + return tuple(input_tensor.cpu().numpy()) + + if isinstance(ori_shape, torch.Tensor): + assert all( + tensor_to_tuple(_['ori_shape']) == tensor_to_tuple(ori_shape) + for _ in img_meta) + else: + assert all(_['ori_shape'] == ori_shape for _ in img_meta) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(img, img_meta, rescale) + else: + seg_logit = self.whole_inference(img, img_meta, rescale) + output = F.softmax(seg_logit, dim=1) + flip = img_meta[0]['flip'] + if flip: + flip_direction = img_meta[0]['flip_direction'] + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + output = output.flip(dims=(3, )) + elif flip_direction == 'vertical': + output = output.flip(dims=(2, )) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + seg_logit = self.inference(img, img_meta, rescale) + seg_pred = seg_logit.argmax(dim=1) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(0) + return seg_pred + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) + seg_logit += cur_seg_logit + seg_logit /= len(imgs) + seg_pred = seg_logit.argmax(dim=1) + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/__init__.py new file mode 100644 index 00000000..dec8a5f2 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/__init__.py @@ -0,0 +1,7 @@ +from .builder import build_pixel_sampler +from .data_process_func import ResizeToMultiple +from .seg_func import add_prefix, seg_resize + +__all__ = [ + 'seg_resize', 'add_prefix', 'build_pixel_sampler', 'ResizeToMultiple' +] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/builder.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/builder.py new file mode 100644 index 00000000..63d77fea --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/builder.py @@ -0,0 +1,11 @@ +# The implementation refers to the VitAdapter +# available at +# https://github.com/czczup/ViT-Adapter.git +from mmcv.utils import Registry, build_from_cfg + +PIXEL_SAMPLERS = Registry('pixel sampler') + + +def build_pixel_sampler(cfg, **default_args): + """Build pixel sampler for segmentation map.""" + return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/data_process_func.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/data_process_func.py new file mode 100644 index 00000000..194361af --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/data_process_func.py @@ -0,0 +1,60 @@ +# 
Copyright (c) OpenMMLab. All rights reserved. +import mmcv +from mmdet.datasets.builder import PIPELINES + + +@PIPELINES.register_module() +class ResizeToMultiple(object): + """Resize images & seg to multiple of divisor. + + Args: + size_divisor (int): images and gt seg maps need to resize to multiple + of size_divisor. Default: 32. + interpolation (str, optional): The interpolation mode of image resize. + Default: None + """ + + def __init__(self, size_divisor=32, interpolation=None): + self.size_divisor = size_divisor + self.interpolation = interpolation + + def __call__(self, results): + """Call function to resize images, semantic segmentation map to + multiple of size divisor. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape' keys are updated. + """ + # Align image to multiple of size divisor. + img = results['img'] + img = mmcv.imresize_to_multiple( + img, + self.size_divisor, + scale_factor=1, + interpolation=self.interpolation + if self.interpolation else 'bilinear') + + results['img'] = img + results['img_shape'] = img.shape + results['pad_shape'] = img.shape + + # Align segmentation map to multiple of size divisor. + for key in results.get('seg_fields', []): + gt_seg = results[key] + gt_seg = mmcv.imresize_to_multiple( + gt_seg, + self.size_divisor, + scale_factor=1, + interpolation='nearest') + results[key] = gt_seg + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(size_divisor={self.size_divisor}, ' + f'interpolation={self.interpolation})') + return repr_str diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/seg_func.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/seg_func.py new file mode 100644 index 00000000..fba46b81 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/seg_func.py @@ -0,0 +1,48 @@ +# The implementation refers to the VitAdapter +# available at +# https://github.com/czczup/ViT-Adapter.git + +import warnings + +import torch.nn.functional as F + + +def seg_resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > input_w: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. 
+ """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f'{prefix}.{name}'] = value + + return outputs diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index d084a91b..f4b4ae3e 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -26,6 +26,7 @@ if TYPE_CHECKING: from .image_panoptic_segmentation_pipeline import ImagePanopticSegmentationPipeline from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline from .image_reid_person_pipeline import ImageReidPersonPipeline + from .image_semantic_segmentation_pipeline import ImageSemanticSegmentationPipeline from .image_style_transfer_pipeline import ImageStyleTransferPipeline from .image_super_resolution_pipeline import ImageSuperResolutionPipeline from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline @@ -66,6 +67,8 @@ else: 'image_portrait_enhancement_pipeline': ['ImagePortraitEnhancementPipeline'], 'image_reid_person_pipeline': ['ImageReidPersonPipeline'], + 'image_semantic_segmentation_pipeline': + ['ImageSemanticSegmentationPipeline'], 'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'], 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], 'image_to_image_translation_pipeline': diff --git a/modelscope/pipelines/cv/image_semantic_segmentation_pipeline.py b/modelscope/pipelines/cv/image_semantic_segmentation_pipeline.py new file mode 100644 index 00000000..e3e1fd6b --- /dev/null +++ b/modelscope/pipelines/cv/image_semantic_segmentation_pipeline.py @@ -0,0 +1,95 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_segmentation, + module_name=Pipelines.image_semantic_segmentation) +class ImageSemanticSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image semantic segmentation pipeline for prediction + Args: + model: model id on modelscope hub. 
+ """ + super().__init__(model=model, **kwargs) + + logger.info('semantic segmentation model, pipeline init') + + def preprocess(self, input: Input) -> Dict[str, Any]: + from mmdet.datasets.pipelines import Compose + from mmcv.parallel import collate, scatter + from mmdet.datasets import replace_ImageToTensor + + cfg = self.model.cfg + # build the data pipeline + + if isinstance(input, str): + # input is str, file names, pipeline loadimagefromfile + # collect data + data = dict(img_info=dict(filename=input), img_prefix=None) + elif isinstance(input, PIL.Image.Image): # BGR + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + img = np.array(input)[:, :, ::-1] + # collect data + data = dict(img=img) + elif isinstance(input, np.ndarray): + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + if len(input.shape) == 2: + img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + else: + img = input + # collect data + data = dict(img=img) + + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + + # data = dict(img=input) + cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) + test_pipeline = Compose(cfg.data.test.pipeline) + + data = test_pipeline(data) + # copy from mmdet_model collect data + data = collate([data], samples_per_gpu=1) + data['img_metas'] = [ + img_metas.data[0] for img_metas in data['img_metas'] + ] + data['img'] = [img.data[0] for img in data['img']] + if next(self.model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [next(self.model.parameters()).device])[0] + + return data + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + results = self.model.inference(input) + + return results + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + results = self.model.postprocess(inputs) + outputs = { + OutputKeys.MASKS: results[OutputKeys.MASKS], + OutputKeys.LABELS: results[OutputKeys.LABELS], + OutputKeys.SCORES: results[OutputKeys.SCORES] + } + + return outputs diff --git a/modelscope/utils/cv/image_utils.py b/modelscope/utils/cv/image_utils.py index fca0e54f..9ded7ef3 100644 --- a/modelscope/utils/cv/image_utils.py +++ b/modelscope/utils/cv/image_utils.py @@ -153,3 +153,16 @@ def panoptic_seg_masks_to_image(masks): draw_img[mask] = color_mask return draw_img + + +def semantic_seg_masks_to_image(masks): + from mmdet.core.visualization.palette import get_palette + mask_palette = get_palette('coco', 133) + + draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3]) + + for i, mask in enumerate(masks): + color_mask = mask_palette[i] + mask = mask.astype(bool) + draw_img[mask] = color_mask + return draw_img diff --git a/tests/pipelines/test_image_semantic_segmentation.py b/tests/pipelines/test_image_semantic_segmentation.py new file mode 100644 index 00000000..6738976c --- /dev/null +++ b/tests/pipelines/test_image_semantic_segmentation.py @@ -0,0 +1,54 @@ +import unittest + +import cv2 +import PIL + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import semantic_seg_masks_to_image +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + + +class ImageSemanticSegmentationTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_semantic_segmentation_panmerge(self): + input_location = 
'data/test/images/image_semantic_segmentation.jpg' + model_id = 'damo/cv_swinL_semantic-segmentation_cocopanmerge' + segmenter = pipeline(Tasks.image_segmentation, model=model_id) + result = segmenter(input_location) + + draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('test_image_semantic_segmentation_panmerge DONE') + + PIL_array = PIL.Image.open(input_location) + result = segmenter(PIL_array) + + draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('test_image_semantic_segmentation_panmerge_from_PIL DONE') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_semantic_segmentation_vitadapter(self): + input_location = 'data/test/images/image_semantic_segmentation.jpg' + model_id = 'damo/cv_vitadapter_semantic-segmentation_cocostuff164k' + segmenter = pipeline(Tasks.image_segmentation, model=model_id) + result = segmenter(input_location) + + draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('test_image_semantic_segmentation_vitadapter DONE') + + PIL_array = PIL.Image.open(input_location) + result = segmenter(PIL_array) + + draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('test_image_semantic_segmentation_vitadapter_from_PIL DONE') + + +if __name__ == '__main__': + unittest.main()
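Usage sketch (not part of the patch): the snippet below mirrors the new test cases and shows how the added pipeline is expected to be called end to end. The model id, task constant, output keys and helper function are the ones referenced in tests/pipelines/test_image_semantic_segmentation.py and modelscope/utils/cv/image_utils.py above; paths and filenames are illustrative only.

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import semantic_seg_masks_to_image

# build the semantic segmentation pipeline from one of the hub model ids used in the tests
segmenter = pipeline(
    Tasks.image_segmentation,
    model='damo/cv_swinL_semantic-segmentation_cocopanmerge')

# run on a local image; preprocess() also accepts PIL.Image and np.ndarray inputs
result = segmenter('data/test/images/image_semantic_segmentation.jpg')

# the result dict carries OutputKeys.MASKS / LABELS / SCORES;
# render the per-class masks into a color image for inspection
draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS])
cv2.imwrite('result.jpg', draw_img)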