Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9851374master
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:59b1da30af12f76b691990363e0d221050a59cf53fc4a97e776bcb00228c6c2a | |||
size 245864 |
@@ -23,6 +23,8 @@ class Models(object): | |||
panoptic_segmentation = 'swinL-panoptic-segmentation' | |||
image_reid_person = 'passvitb' | |||
video_summarization = 'pgl-video-summarization' | |||
swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
# nlp models | |||
bert = 'bert' | |||
@@ -117,6 +119,7 @@ class Pipelines(object): | |||
video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' | |||
image_panoptic_segmentation = 'image-panoptic-segmentation' | |||
video_summarization = 'googlenet_pgl_video_summarization' | |||
image_semantic_segmentation = 'image-semantic-segmentation' | |||
image_reid_person = 'passvitb-image-reid-person' | |||
# nlp tasks | |||
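Note on how the new registry keys above are consumed: a minimal, hedged usage sketch of building a segmentation pipeline by task. The model id below is a placeholder, not an asset added in this diff; only Tasks.image_segmentation and the output keys come from the code under review.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# '<semantic-segmentation-model-id>' is a placeholder; the hub id is configured
# outside this diff. The pipeline resolves the model/pipeline pair through the
# metainfo keys registered above.
segmentor = pipeline(Tasks.image_segmentation,
                     model='<semantic-segmentation-model-id>')
result = segmentor('demo.jpg')
print(result[OutputKeys.LABELS])      # one class name per returned mask
print(len(result[OutputKeys.MASKS]))  # binary masks, one per predicted class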
@@ -4,8 +4,8 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
face_generation, image_classification, image_color_enhance, | |||
image_colorization, image_denoise, image_instance_segmentation, | |||
image_panoptic_segmentation, image_portrait_enhancement, | |||
image_reid_person, image_to_image_generation, | |||
image_to_image_translation, object_detection, | |||
product_retrieval_embedding, salient_detection, | |||
super_resolution, video_single_object_tracking, | |||
video_summarization, virual_tryon) | |||
image_reid_person, image_semantic_segmentation, | |||
image_to_image_generation, image_to_image_translation, | |||
object_detection, product_retrieval_embedding, | |||
salient_detection, super_resolution, | |||
video_single_object_tracking, video_summarization, virual_tryon) |
@@ -0,0 +1,22 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .semantic_seg_model import SemanticSegmentation | |||
else: | |||
_import_structure = { | |||
'semantic_seg_model': ['SemanticSegmentation'], | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
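A brief note on the lazy-import boilerplate above: at runtime the package is replaced by a LazyImportModule, so the heavy mmdet/mmcv imports behind semantic_seg_model only happen when SemanticSegmentation is first accessed, while TYPE_CHECKING keeps static analysis working. A minimal sketch of the access pattern (the package path is inferred from the imports elsewhere in this diff):

import importlib

# Importing the package is cheap; it only installs the lazy proxy.
pkg = importlib.import_module(
    'modelscope.models.cv.image_semantic_segmentation')

# Touching an attribute listed in _import_structure triggers the real
# submodule import (semantic_seg_model) at this point, not earlier.
model_cls = pkg.SemanticSegmentation
print(model_cls.__name__)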
@@ -0,0 +1 @@ | |||
from .maskformer_semantic_head import MaskFormerSemanticHead |
@@ -0,0 +1,47 @@ | |||
# Copyright (c) OpenMMLab. All rights reserved. | |||
from abc import ABCMeta, abstractmethod | |||
from mmcv.runner import BaseModule | |||
from mmdet.models.builder import build_loss | |||
class BasePanopticFusionHead(BaseModule, metaclass=ABCMeta): | |||
"""Base class for panoptic heads.""" | |||
def __init__(self, | |||
num_things_classes=80, | |||
num_stuff_classes=53, | |||
test_cfg=None, | |||
loss_panoptic=None, | |||
init_cfg=None, | |||
**kwargs): | |||
super(BasePanopticFusionHead, self).__init__(init_cfg) | |||
self.num_things_classes = num_things_classes | |||
self.num_stuff_classes = num_stuff_classes | |||
self.num_classes = num_things_classes + num_stuff_classes | |||
self.test_cfg = test_cfg | |||
if loss_panoptic: | |||
self.loss_panoptic = build_loss(loss_panoptic) | |||
else: | |||
self.loss_panoptic = None | |||
@property | |||
def with_loss(self): | |||
"""bool: whether the panoptic head contains loss function.""" | |||
return self.loss_panoptic is not None | |||
@abstractmethod | |||
def forward_train(self, gt_masks=None, gt_semantic_seg=None, **kwargs): | |||
"""Forward function during training.""" | |||
@abstractmethod | |||
def simple_test(self, | |||
img_metas, | |||
det_labels, | |||
mask_preds, | |||
seg_preds, | |||
det_bboxes, | |||
cfg=None, | |||
**kwargs): | |||
"""Test without augmentation.""" |
@@ -0,0 +1,57 @@ | |||
import torch | |||
import torch.nn.functional as F | |||
from mmdet.models.builder import HEADS | |||
from .base_panoptic_fusion_head import BasePanopticFusionHead | |||
@HEADS.register_module() | |||
class MaskFormerSemanticHead(BasePanopticFusionHead): | |||
def __init__(self, | |||
num_things_classes=80, | |||
num_stuff_classes=53, | |||
test_cfg=None, | |||
loss_panoptic=None, | |||
init_cfg=None, | |||
**kwargs): | |||
super().__init__(num_things_classes, num_stuff_classes, test_cfg, | |||
loss_panoptic, init_cfg, **kwargs) | |||
def forward_train(self, **kwargs): | |||
"""MaskFormerFusionHead has no training loss.""" | |||
return dict() | |||
def simple_test(self, | |||
mask_cls_results, | |||
mask_pred_results, | |||
img_metas, | |||
rescale=False, | |||
**kwargs): | |||
results = [] | |||
for mask_cls_result, mask_pred_result, meta in zip( | |||
mask_cls_results, mask_pred_results, img_metas): | |||
# remove padding | |||
img_height, img_width = meta['img_shape'][:2] | |||
mask_pred_result = mask_pred_result[:, :img_height, :img_width] | |||
if rescale: | |||
# return result in original resolution | |||
ori_height, ori_width = meta['ori_shape'][:2] | |||
mask_pred_result = F.interpolate( | |||
mask_pred_result[:, None], | |||
size=(ori_height, ori_width), | |||
mode='bilinear', | |||
align_corners=False)[:, 0] | |||
# semantic inference | |||
cls_score = F.softmax(mask_cls_result, dim=-1)[..., :-1] | |||
mask_pred = mask_pred_result.sigmoid() | |||
seg_mask = torch.einsum('qc,qhw->chw', cls_score, mask_pred) | |||
# still need softmax and argmax | |||
seg_logit = F.softmax(seg_mask, dim=0) | |||
seg_pred = seg_logit.argmax(dim=0) | |||
seg_pred = seg_pred.cpu().numpy() | |||
results.append(seg_pred) | |||
return results |
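The semantic-inference block above collapses Q per-query mask proposals into C per-class score maps before the final argmax. A self-contained shape sketch with random tensors (only the shapes and the einsum contraction mirror the code; the Q/C/H/W values are arbitrary):

import torch
import torch.nn.functional as F

Q, C, H, W = 100, 150, 32, 32                 # queries, classes, feature height/width
mask_cls_result = torch.randn(Q, C + 1)       # per-query class logits, last column = "no object"
mask_pred_result = torch.randn(Q, H, W)       # per-query mask logits

cls_score = F.softmax(mask_cls_result, dim=-1)[..., :-1]        # (Q, C), drop the no-object column
mask_pred = mask_pred_result.sigmoid()                          # (Q, H, W)
seg_mask = torch.einsum('qc,qhw->chw', cls_score, mask_pred)    # (C, H, W) per-class evidence
seg_pred = F.softmax(seg_mask, dim=0).argmax(dim=0)             # (H, W) label map
print(seg_mask.shape, seg_pred.shape)  # torch.Size([150, 32, 32]) torch.Size([32, 32])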
@@ -0,0 +1,76 @@ | |||
import os.path as osp | |||
import numpy as np | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base.base_torch_model import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.cv.image_semantic_segmentation import (pan_merge, | |||
vit_adapter) | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
@MODELS.register_module( | |||
Tasks.image_segmentation, module_name=Models.swinL_semantic_segmentation) | |||
@MODELS.register_module( | |||
Tasks.image_segmentation, | |||
module_name=Models.vitadapter_semantic_segmentation) | |||
class SemanticSegmentation(TorchModel): | |||
def __init__(self, model_dir: str, **kwargs): | |||
"""str -- model file root.""" | |||
super().__init__(model_dir, **kwargs) | |||
from mmcv.runner import load_checkpoint | |||
import mmcv | |||
from mmdet.models import build_detector | |||
config = osp.join(model_dir, 'mmcv_config.py') | |||
cfg = mmcv.Config.fromfile(config) | |||
if 'pretrained' in cfg.model: | |||
cfg.model.pretrained = None | |||
elif 'init_cfg' in cfg.model.backbone: | |||
cfg.model.backbone.init_cfg = None | |||
# build model | |||
cfg.model.train_cfg = None | |||
self.model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) | |||
# load model | |||
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
_ = load_checkpoint(self.model, model_path, map_location='cpu') | |||
self.CLASSES = cfg['CLASSES'] # list | |||
self.PALETTE = cfg['PALETTE'] # list | |||
self.num_classes = len(self.CLASSES) | |||
self.cfg = cfg | |||
def forward(self, Inputs): | |||
return self.model(**Inputs) | |||
def postprocess(self, Inputs): | |||
semantic_result = Inputs[0] | |||
ids = np.unique(semantic_result)[::-1] | |||
legal_indices = ids != self.model.num_classes # for VOID label | |||
ids = ids[legal_indices] | |||
segms = (semantic_result[None] == ids[:, None, None]) | |||
masks = [it.astype(int) for it in segms]  # np.int is a deprecated alias for the builtin int | |||
labels_txt = np.array(self.CLASSES)[ids].tolist() | |||
results = { | |||
OutputKeys.MASKS: masks, | |||
OutputKeys.LABELS: labels_txt, | |||
OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))] | |||
} | |||
return results | |||
def inference(self, data): | |||
with torch.no_grad(): | |||
results = self.model(return_loss=False, rescale=True, **data) | |||
return results |
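To make the postprocess contract concrete, a self-contained toy version of the mask/label extraction it performs on a single (H, W) label map; the class names and the tiny array are invented for illustration, and the fixed 0.999 score mirrors the placeholder confidence used above:

import numpy as np

CLASSES = ['wall', 'sky', 'tree']                    # toy palette, not a real label set
semantic_result = np.array([[0, 0, 1],
                            [2, 2, 1],
                            [2, 2, 1]])              # stand-in for one predicted label map

ids = np.unique(semantic_result)[::-1]               # present class ids, high to low
masks = [(semantic_result == i).astype(int) for i in ids]
labels = np.array(CLASSES)[ids].tolist()
scores = [0.999] * len(labels)                       # placeholder confidences, as in postprocess
print(labels)        # ['tree', 'sky', 'wall']
print(masks[0])      # binary mask for 'tree'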
@@ -0,0 +1,3 @@ | |||
from .models import backbone, decode_heads, segmentors | |||
from .utils import (ResizeToMultiple, add_prefix, build_pixel_sampler, | |||
seg_resize) |
@@ -0,0 +1,3 @@ | |||
from .backbone import BASEBEiT, BEiTAdapter | |||
from .decode_heads import Mask2FormerHeadFromMMSeg | |||
from .segmentors import EncoderDecoderMask2Former |
@@ -0,0 +1,4 @@ | |||
from .base import BASEBEiT | |||
from .beit_adapter import BEiTAdapter | |||
__all__ = ['BEiTAdapter', 'BASEBEiT'] |
@@ -0,0 +1,523 @@ | |||
# The implementation refers to the VitAdapter | |||
# available at | |||
# https://github.com/czczup/ViT-Adapter.git | |||
import logging | |||
from functools import partial | |||
import torch | |||
import torch.nn as nn | |||
import torch.utils.checkpoint as cp | |||
from mmdet.models.utils.transformer import MultiScaleDeformableAttention | |||
from timm.models.layers import DropPath | |||
_logger = logging.getLogger(__name__) | |||
def get_reference_points(spatial_shapes, device): | |||
reference_points_list = [] | |||
for lvl, (H_, W_) in enumerate(spatial_shapes): | |||
ref_y, ref_x = torch.meshgrid( | |||
torch.linspace( | |||
0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), | |||
torch.linspace( | |||
0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) | |||
ref_y = ref_y.reshape(-1)[None] / H_ | |||
ref_x = ref_x.reshape(-1)[None] / W_ | |||
ref = torch.stack((ref_x, ref_y), -1) | |||
reference_points_list.append(ref) | |||
reference_points = torch.cat(reference_points_list, 1) | |||
reference_points = reference_points[:, :, None] | |||
return reference_points | |||
def deform_inputs(x): | |||
bs, c, h, w = x.shape | |||
spatial_shapes = torch.as_tensor([(h // 8, w // 8), (h // 16, w // 16), | |||
(h // 32, w // 32)], | |||
dtype=torch.long, | |||
device=x.device) | |||
level_start_index = torch.cat((spatial_shapes.new_zeros( | |||
(1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) | |||
reference_points = get_reference_points([(h // 16, w // 16)], x.device) | |||
deform_inputs1 = [reference_points, spatial_shapes, level_start_index] | |||
spatial_shapes = torch.as_tensor([(h // 16, w // 16)], | |||
dtype=torch.long, | |||
device=x.device) | |||
level_start_index = torch.cat((spatial_shapes.new_zeros( | |||
(1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) | |||
reference_points = get_reference_points([(h // 8, w // 8), | |||
(h // 16, w // 16), | |||
(h // 32, w // 32)], x.device) | |||
deform_inputs2 = [reference_points, spatial_shapes, level_start_index] | |||
return deform_inputs1, deform_inputs2 | |||
class ConvFFN(nn.Module): | |||
def __init__(self, | |||
in_features, | |||
hidden_features=None, | |||
out_features=None, | |||
act_layer=nn.GELU, | |||
drop=0.): | |||
super().__init__() | |||
out_features = out_features or in_features | |||
hidden_features = hidden_features or in_features | |||
self.fc1 = nn.Linear(in_features, hidden_features) | |||
self.dwconv = DWConv(hidden_features) | |||
self.act = act_layer() | |||
self.fc2 = nn.Linear(hidden_features, out_features) | |||
self.drop = nn.Dropout(drop) | |||
def forward(self, x, H, W): | |||
x = self.fc1(x) | |||
x = self.dwconv(x, H, W) | |||
x = self.act(x) | |||
x = self.drop(x) | |||
x = self.fc2(x) | |||
x = self.drop(x) | |||
return x | |||
class DWConv(nn.Module): | |||
def __init__(self, dim=768): | |||
super().__init__() | |||
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) | |||
def forward(self, x, H, W): | |||
B, N, C = x.shape | |||
n = N // 21 | |||
x1 = x[:, 0:16 * n, :].transpose(1, 2).view(B, C, H * 2, | |||
W * 2).contiguous() | |||
x2 = x[:, 16 * n:20 * n, :].transpose(1, 2).view(B, C, H, | |||
W).contiguous() | |||
x3 = x[:, 20 * n:, :].transpose(1, 2).view(B, C, H // 2, | |||
W // 2).contiguous() | |||
x1 = self.dwconv(x1).flatten(2).transpose(1, 2) | |||
x2 = self.dwconv(x2).flatten(2).transpose(1, 2) | |||
x3 = self.dwconv(x3).flatten(2).transpose(1, 2) | |||
x = torch.cat([x1, x2, x3], dim=1) | |||
return x | |||
class Extractor(nn.Module): | |||
def __init__(self, | |||
dim, | |||
num_heads=6, | |||
n_points=4, | |||
n_levels=1, | |||
deform_ratio=1.0, | |||
with_cffn=True, | |||
cffn_ratio=0.25, | |||
drop=0., | |||
drop_path=0., | |||
norm_layer=partial(nn.LayerNorm, eps=1e-6), | |||
with_cp=False): | |||
super().__init__() | |||
self.query_norm = norm_layer(dim) | |||
self.feat_norm = norm_layer(dim) | |||
self.attn = MultiScaleDeformableAttention( | |||
embed_dims=dim, | |||
num_heads=num_heads, | |||
num_levels=n_levels, | |||
num_points=n_points, | |||
batch_first=True) | |||
# modify to fit the deform_ratio | |||
value_proj_in_features = self.attn.value_proj.weight.shape[0] | |||
value_proj_out_features = int(value_proj_in_features * deform_ratio) | |||
self.attn.value_proj = nn.Linear(value_proj_in_features, | |||
value_proj_out_features) | |||
self.attn.output_proj = nn.Linear(value_proj_out_features, | |||
value_proj_in_features) | |||
self.with_cffn = with_cffn | |||
self.with_cp = with_cp | |||
if with_cffn: | |||
self.ffn = ConvFFN( | |||
in_features=dim, | |||
hidden_features=int(dim * cffn_ratio), | |||
drop=drop) | |||
self.ffn_norm = norm_layer(dim) | |||
self.drop_path = DropPath( | |||
drop_path) if drop_path > 0. else nn.Identity() | |||
def forward(self, query, reference_points, feat, spatial_shapes, | |||
level_start_index, H, W): | |||
def _inner_forward(query, feat): | |||
attn = self.attn( | |||
query=self.query_norm(query), | |||
key=None, | |||
value=self.feat_norm(feat), | |||
identity=None, | |||
query_pos=None, | |||
key_padding_mask=None, | |||
reference_points=reference_points, | |||
spatial_shapes=spatial_shapes, | |||
level_start_index=level_start_index) | |||
query = query + attn | |||
if self.with_cffn: | |||
query = query + self.drop_path( | |||
self.ffn(self.ffn_norm(query), H, W)) | |||
return query | |||
if self.with_cp and query.requires_grad: | |||
query = cp.checkpoint(_inner_forward, query, feat) | |||
else: | |||
query = _inner_forward(query, feat) | |||
return query | |||
class Injector(nn.Module): | |||
def __init__(self, | |||
dim, | |||
num_heads=6, | |||
n_points=4, | |||
n_levels=1, | |||
deform_ratio=1.0, | |||
norm_layer=partial(nn.LayerNorm, eps=1e-6), | |||
init_values=0., | |||
with_cp=False): | |||
super().__init__() | |||
self.with_cp = with_cp | |||
self.query_norm = norm_layer(dim) | |||
self.feat_norm = norm_layer(dim) | |||
self.attn = MultiScaleDeformableAttention( | |||
embed_dims=dim, | |||
num_heads=num_heads, | |||
num_levels=n_levels, | |||
num_points=n_points, | |||
batch_first=True) | |||
# modify to fit the deform_ratio | |||
value_proj_in_features = self.attn.value_proj.weight.shape[0] | |||
value_proj_out_features = int(value_proj_in_features * deform_ratio) | |||
self.attn.value_proj = nn.Linear(value_proj_in_features, | |||
value_proj_out_features) | |||
self.attn.output_proj = nn.Linear(value_proj_out_features, | |||
value_proj_in_features) | |||
self.gamma = nn.Parameter( | |||
init_values * torch.ones((dim)), requires_grad=True) | |||
def forward(self, query, reference_points, feat, spatial_shapes, | |||
level_start_index): | |||
def _inner_forward(query, feat): | |||
input_query = self.query_norm(query) | |||
input_value = self.feat_norm(feat) | |||
attn = self.attn( | |||
query=input_query, | |||
key=None, | |||
value=input_value, | |||
identity=None, | |||
query_pos=None, | |||
key_padding_mask=None, | |||
reference_points=reference_points, | |||
spatial_shapes=spatial_shapes, | |||
level_start_index=level_start_index) | |||
return query + self.gamma * attn | |||
if self.with_cp and query.requires_grad: | |||
query = cp.checkpoint(_inner_forward, query, feat) | |||
else: | |||
query = _inner_forward(query, feat) | |||
return query | |||
class InteractionBlock(nn.Module): | |||
def __init__(self, | |||
dim, | |||
num_heads=6, | |||
n_points=4, | |||
norm_layer=partial(nn.LayerNorm, eps=1e-6), | |||
drop=0., | |||
drop_path=0., | |||
with_cffn=True, | |||
cffn_ratio=0.25, | |||
init_values=0., | |||
deform_ratio=1.0, | |||
extra_extractor=False, | |||
with_cp=False): | |||
super().__init__() | |||
self.injector = Injector( | |||
dim=dim, | |||
n_levels=3, | |||
num_heads=num_heads, | |||
init_values=init_values, | |||
n_points=n_points, | |||
norm_layer=norm_layer, | |||
deform_ratio=deform_ratio, | |||
with_cp=with_cp) | |||
self.extractor = Extractor( | |||
dim=dim, | |||
n_levels=1, | |||
num_heads=num_heads, | |||
n_points=n_points, | |||
norm_layer=norm_layer, | |||
deform_ratio=deform_ratio, | |||
with_cffn=with_cffn, | |||
cffn_ratio=cffn_ratio, | |||
drop=drop, | |||
drop_path=drop_path, | |||
with_cp=with_cp) | |||
if extra_extractor: | |||
self.extra_extractors = nn.Sequential(*[ | |||
Extractor( | |||
dim=dim, | |||
num_heads=num_heads, | |||
n_points=n_points, | |||
norm_layer=norm_layer, | |||
with_cffn=with_cffn, | |||
cffn_ratio=cffn_ratio, | |||
deform_ratio=deform_ratio, | |||
drop=drop, | |||
drop_path=drop_path, | |||
with_cp=with_cp) for _ in range(2) | |||
]) | |||
else: | |||
self.extra_extractors = None | |||
def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H, W): | |||
x = self.injector( | |||
query=x, | |||
reference_points=deform_inputs1[0], | |||
feat=c, | |||
spatial_shapes=deform_inputs1[1], | |||
level_start_index=deform_inputs1[2]) | |||
for idx, blk in enumerate(blocks): | |||
x = blk(x, H, W) | |||
c = self.extractor( | |||
query=c, | |||
reference_points=deform_inputs2[0], | |||
feat=x, | |||
spatial_shapes=deform_inputs2[1], | |||
level_start_index=deform_inputs2[2], | |||
H=H, | |||
W=W) | |||
if self.extra_extractors is not None: | |||
for extractor in self.extra_extractors: | |||
c = extractor( | |||
query=c, | |||
reference_points=deform_inputs2[0], | |||
feat=x, | |||
spatial_shapes=deform_inputs2[1], | |||
level_start_index=deform_inputs2[2], | |||
H=H, | |||
W=W) | |||
return x, c | |||
class InteractionBlockWithCls(nn.Module): | |||
def __init__(self, | |||
dim, | |||
num_heads=6, | |||
n_points=4, | |||
norm_layer=partial(nn.LayerNorm, eps=1e-6), | |||
drop=0., | |||
drop_path=0., | |||
with_cffn=True, | |||
cffn_ratio=0.25, | |||
init_values=0., | |||
deform_ratio=1.0, | |||
extra_extractor=False, | |||
with_cp=False): | |||
super().__init__() | |||
self.injector = Injector( | |||
dim=dim, | |||
n_levels=3, | |||
num_heads=num_heads, | |||
init_values=init_values, | |||
n_points=n_points, | |||
norm_layer=norm_layer, | |||
deform_ratio=deform_ratio, | |||
with_cp=with_cp) | |||
self.extractor = Extractor( | |||
dim=dim, | |||
n_levels=1, | |||
num_heads=num_heads, | |||
n_points=n_points, | |||
norm_layer=norm_layer, | |||
deform_ratio=deform_ratio, | |||
with_cffn=with_cffn, | |||
cffn_ratio=cffn_ratio, | |||
drop=drop, | |||
drop_path=drop_path, | |||
with_cp=with_cp) | |||
if extra_extractor: | |||
self.extra_extractors = nn.Sequential(*[ | |||
Extractor( | |||
dim=dim, | |||
num_heads=num_heads, | |||
n_points=n_points, | |||
norm_layer=norm_layer, | |||
with_cffn=with_cffn, | |||
cffn_ratio=cffn_ratio, | |||
deform_ratio=deform_ratio, | |||
drop=drop, | |||
drop_path=drop_path, | |||
with_cp=with_cp) for _ in range(2) | |||
]) | |||
else: | |||
self.extra_extractors = None | |||
def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H, W): | |||
x = self.injector( | |||
query=x, | |||
reference_points=deform_inputs1[0], | |||
feat=c, | |||
spatial_shapes=deform_inputs1[1], | |||
level_start_index=deform_inputs1[2]) | |||
x = torch.cat((cls, x), dim=1) | |||
for idx, blk in enumerate(blocks): | |||
x = blk(x, H, W) | |||
cls, x = x[:, :1, ], x[:, 1:, ] | |||
c = self.extractor( | |||
query=c, | |||
reference_points=deform_inputs2[0], | |||
feat=x, | |||
spatial_shapes=deform_inputs2[1], | |||
level_start_index=deform_inputs2[2], | |||
H=H, | |||
W=W) | |||
if self.extra_extractors is not None: | |||
for extractor in self.extra_extractors: | |||
c = extractor( | |||
query=c, | |||
reference_points=deform_inputs2[0], | |||
feat=x, | |||
spatial_shapes=deform_inputs2[1], | |||
level_start_index=deform_inputs2[2], | |||
H=H, | |||
W=W) | |||
return x, c, cls | |||
class SpatialPriorModule(nn.Module): | |||
def __init__(self, inplanes=64, embed_dim=384, with_cp=False): | |||
super().__init__() | |||
self.with_cp = with_cp | |||
self.stem = nn.Sequential(*[ | |||
nn.Conv2d( | |||
3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), | |||
nn.SyncBatchNorm(inplanes), | |||
nn.ReLU(inplace=True), | |||
nn.Conv2d( | |||
inplanes, | |||
inplanes, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=False), | |||
nn.SyncBatchNorm(inplanes), | |||
nn.ReLU(inplace=True), | |||
nn.Conv2d( | |||
inplanes, | |||
inplanes, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=False), | |||
nn.SyncBatchNorm(inplanes), | |||
nn.ReLU(inplace=True), | |||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
]) | |||
self.conv2 = nn.Sequential(*[ | |||
nn.Conv2d( | |||
inplanes, | |||
2 * inplanes, | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
bias=False), | |||
nn.SyncBatchNorm(2 * inplanes), | |||
nn.ReLU(inplace=True) | |||
]) | |||
self.conv3 = nn.Sequential(*[ | |||
nn.Conv2d( | |||
2 * inplanes, | |||
4 * inplanes, | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
bias=False), | |||
nn.SyncBatchNorm(4 * inplanes), | |||
nn.ReLU(inplace=True) | |||
]) | |||
self.conv4 = nn.Sequential(*[ | |||
nn.Conv2d( | |||
4 * inplanes, | |||
4 * inplanes, | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
bias=False), | |||
nn.SyncBatchNorm(4 * inplanes), | |||
nn.ReLU(inplace=True) | |||
]) | |||
self.fc1 = nn.Conv2d( | |||
inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) | |||
self.fc2 = nn.Conv2d( | |||
2 * inplanes, | |||
embed_dim, | |||
kernel_size=1, | |||
stride=1, | |||
padding=0, | |||
bias=True) | |||
self.fc3 = nn.Conv2d( | |||
4 * inplanes, | |||
embed_dim, | |||
kernel_size=1, | |||
stride=1, | |||
padding=0, | |||
bias=True) | |||
self.fc4 = nn.Conv2d( | |||
4 * inplanes, | |||
embed_dim, | |||
kernel_size=1, | |||
stride=1, | |||
padding=0, | |||
bias=True) | |||
def forward(self, x): | |||
def _inner_forward(x): | |||
c1 = self.stem(x) | |||
c2 = self.conv2(c1) | |||
c3 = self.conv3(c2) | |||
c4 = self.conv4(c3) | |||
c1 = self.fc1(c1) | |||
c2 = self.fc2(c2) | |||
c3 = self.fc3(c3) | |||
c4 = self.fc4(c4) | |||
bs, dim, _, _ = c1.shape | |||
c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s | |||
c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s | |||
c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s | |||
return c1, c2, c3, c4 | |||
if self.with_cp and x.requires_grad: | |||
outs = cp.checkpoint(_inner_forward, x) | |||
else: | |||
outs = _inner_forward(x) | |||
return outs |
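A quick sanity check on deform_inputs above: for a 512x512 input, the Injector queries the 1/16 ViT tokens against the three SPM levels (1/8, 1/16, 1/32), while the Extractor queries the concatenated SPM tokens against the single 1/16 level. The sketch below just recomputes the token counts and reference-point shapes those two lists carry; it needs no weights and the input size is only an example:

h = w = 512
ms_levels = [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)]
vit_level = (h // 16, w // 16)

n_ms = sum(H_ * W_ for H_, W_ in ms_levels)   # 4096 + 1024 + 256 = 5376 SPM tokens
n_vit = vit_level[0] * vit_level[1]           # 1024 ViT tokens

# deform_inputs1 (Injector): queries = ViT tokens, values = 3-level SPM tokens
print('injector reference_points shape :', (1, n_vit, 1, 2))
print('injector value length           :', n_ms)
# deform_inputs2 (Extractor): queries = SPM tokens, values = 1/16 ViT tokens
print('extractor reference_points shape:', (1, n_ms, 1, 2))
print('extractor value length          :', n_vit)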
@@ -0,0 +1,3 @@ | |||
from .beit import BASEBEiT | |||
__all__ = ['BASEBEiT'] |
@@ -0,0 +1,476 @@ | |||
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) | |||
# Github source: https://github.com/microsoft/unilm/tree/master/beit | |||
# This implementation refers to | |||
# https://github.com/czczup/ViT-Adapter.git | |||
import math | |||
from functools import partial | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torch.utils.checkpoint as cp | |||
from mmcv.runner import _load_checkpoint | |||
from mmdet.models.builder import BACKBONES | |||
from mmdet.utils import get_root_logger | |||
from timm.models.layers import drop_path, to_2tuple, trunc_normal_ | |||
class DropPath(nn.Module): | |||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of | |||
residual blocks).""" | |||
def __init__(self, drop_prob=None): | |||
super(DropPath, self).__init__() | |||
self.drop_prob = drop_prob | |||
def forward(self, x): | |||
return drop_path(x, self.drop_prob, self.training) | |||
def extra_repr(self) -> str: | |||
return 'p={}'.format(self.drop_prob) | |||
class Mlp(nn.Module): | |||
def __init__(self, | |||
in_features, | |||
hidden_features=None, | |||
out_features=None, | |||
act_layer=nn.GELU, | |||
drop=0.): | |||
super().__init__() | |||
out_features = out_features or in_features | |||
hidden_features = hidden_features or in_features | |||
self.fc1 = nn.Linear(in_features, hidden_features) | |||
self.act = act_layer() | |||
self.fc2 = nn.Linear(hidden_features, out_features) | |||
self.drop = nn.Dropout(drop) | |||
def forward(self, x): | |||
x = self.fc1(x) | |||
x = self.act(x) | |||
# dropout after the activation is intentionally omitted, following the original BERT implementation | |||
x = self.fc2(x) | |||
x = self.drop(x) | |||
return x | |||
class Attention(nn.Module): | |||
def __init__(self, | |||
dim, | |||
num_heads=8, | |||
qkv_bias=False, | |||
qk_scale=None, | |||
attn_drop=0., | |||
proj_drop=0., | |||
window_size=None, | |||
attn_head_dim=None): | |||
super().__init__() | |||
self.num_heads = num_heads | |||
head_dim = dim // num_heads | |||
if attn_head_dim is not None: | |||
head_dim = attn_head_dim | |||
all_head_dim = head_dim * self.num_heads | |||
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights | |||
self.scale = qk_scale or head_dim**-0.5 | |||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) | |||
if qkv_bias: | |||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) | |||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) | |||
else: | |||
self.q_bias = None | |||
self.v_bias = None | |||
if window_size: | |||
self.window_size = window_size | |||
self.num_relative_distance = (2 * window_size[0] | |||
- 1) * (2 * window_size[1] - 1) + 3 | |||
self.relative_position_bias_table = nn.Parameter( | |||
torch.zeros(self.num_relative_distance, | |||
num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |||
# cls to token, token to cls, and cls to cls | |||
# get pair-wise relative position index for each token inside the window | |||
coords_h = torch.arange(window_size[0]) | |||
coords_w = torch.arange(window_size[1]) | |||
coords = torch.stack(torch.meshgrid([coords_h, | |||
coords_w])) # 2, Wh, Ww | |||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||
relative_coords = coords_flatten[:, :, | |||
None] - coords_flatten[:, | |||
None, :] # 2, Wh*Ww, Wh*Ww | |||
relative_coords = relative_coords.permute( | |||
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |||
relative_coords[:, :, | |||
0] += window_size[0] - 1 # shift to start from 0 | |||
relative_coords[:, :, 1] += window_size[1] - 1 | |||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1 | |||
relative_position_index = \ | |||
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) | |||
relative_position_index[1:, 1:] = relative_coords.sum( | |||
-1) # Wh*Ww, Wh*Ww | |||
relative_position_index[0, 0:] = self.num_relative_distance - 3 | |||
relative_position_index[0:, 0] = self.num_relative_distance - 2 | |||
relative_position_index[0, 0] = self.num_relative_distance - 1 | |||
self.register_buffer('relative_position_index', | |||
relative_position_index) | |||
else: | |||
self.window_size = None | |||
self.relative_position_bias_table = None | |||
self.relative_position_index = None | |||
self.attn_drop = nn.Dropout(attn_drop) | |||
self.proj = nn.Linear(all_head_dim, dim) | |||
self.proj_drop = nn.Dropout(proj_drop) | |||
def forward(self, x, rel_pos_bias=None): | |||
B, N, C = x.shape | |||
qkv_bias = None | |||
if self.q_bias is not None: | |||
qkv_bias = torch.cat( | |||
(self.q_bias, | |||
torch.zeros_like(self.v_bias, | |||
requires_grad=False), self.v_bias)) | |||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) | |||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) | |||
q, k, v = qkv[0], qkv[1], qkv[ | |||
2] # make torchscript happy (cannot use tensor as tuple) | |||
q = q * self.scale | |||
attn = (q @ k.transpose(-2, -1)) | |||
if self.relative_position_bias_table is not None: | |||
relative_position_bias = \ | |||
self.relative_position_bias_table[self.relative_position_index.view(-1)].view( | |||
self.window_size[0] * self.window_size[1] + 1, | |||
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH | |||
relative_position_bias = relative_position_bias.permute( | |||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |||
attn = attn + relative_position_bias.unsqueeze(0) | |||
if rel_pos_bias is not None: | |||
attn = attn + rel_pos_bias | |||
attn = attn.softmax(dim=-1) | |||
attn = self.attn_drop(attn) | |||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1) | |||
x = self.proj(x) | |||
x = self.proj_drop(x) | |||
return x | |||
class Block(nn.Module): | |||
def __init__(self, | |||
dim, | |||
num_heads, | |||
mlp_ratio=4., | |||
qkv_bias=False, | |||
qk_scale=None, | |||
drop=0., | |||
attn_drop=0., | |||
drop_path=0., | |||
init_values=None, | |||
act_layer=nn.GELU, | |||
norm_layer=nn.LayerNorm, | |||
window_size=None, | |||
attn_head_dim=None, | |||
with_cp=False): | |||
super().__init__() | |||
self.with_cp = with_cp | |||
self.norm1 = norm_layer(dim) | |||
self.attn = Attention( | |||
dim, | |||
num_heads=num_heads, | |||
qkv_bias=qkv_bias, | |||
qk_scale=qk_scale, | |||
attn_drop=attn_drop, | |||
proj_drop=drop, | |||
window_size=window_size, | |||
attn_head_dim=attn_head_dim) | |||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here | |||
self.drop_path = DropPath( | |||
drop_path) if drop_path > 0. else nn.Identity() | |||
self.norm2 = norm_layer(dim) | |||
mlp_hidden_dim = int(dim * mlp_ratio) | |||
self.mlp = Mlp( | |||
in_features=dim, | |||
hidden_features=mlp_hidden_dim, | |||
act_layer=act_layer, | |||
drop=drop) | |||
if init_values is not None: | |||
self.gamma_1 = nn.Parameter( | |||
init_values * torch.ones((dim)), requires_grad=True) | |||
self.gamma_2 = nn.Parameter( | |||
init_values * torch.ones((dim)), requires_grad=True) | |||
else: | |||
self.gamma_1, self.gamma_2 = None, None | |||
def forward(self, x, H, W, rel_pos_bias=None): | |||
def _inner_forward(x): | |||
if self.gamma_1 is None: | |||
x = x + self.drop_path( | |||
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) | |||
x = x + self.drop_path(self.mlp(self.norm2(x))) | |||
else: | |||
x = x + self.drop_path(self.gamma_1 * self.attn( | |||
self.norm1(x), rel_pos_bias=rel_pos_bias)) | |||
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) | |||
return x | |||
if self.with_cp and x.requires_grad: | |||
x = cp.checkpoint(_inner_forward, x) | |||
else: | |||
x = _inner_forward(x) | |||
return x | |||
class PatchEmbed(nn.Module): | |||
""" Image to Patch Embedding | |||
""" | |||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): | |||
super().__init__() | |||
img_size = to_2tuple(img_size) | |||
patch_size = to_2tuple(patch_size) | |||
num_patches = (img_size[1] // patch_size[1]) * ( | |||
img_size[0] // patch_size[0]) | |||
self.patch_shape = (img_size[0] // patch_size[0], | |||
img_size[1] // patch_size[1]) | |||
self.img_size = img_size | |||
self.patch_size = patch_size | |||
self.num_patches = num_patches | |||
self.proj = nn.Conv2d( | |||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) | |||
def forward(self, x, **kwargs): | |||
B, C, H, W = x.shape | |||
# FIXME look at relaxing size constraints | |||
# assert H == self.img_size[0] and W == self.img_size[1], \ | |||
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." | |||
x = self.proj(x) | |||
Hp, Wp = x.shape[2], x.shape[3] | |||
x = x.flatten(2).transpose(1, 2) | |||
return x, Hp, Wp | |||
class HybridEmbed(nn.Module): | |||
""" CNN Feature Map Embedding | |||
Extract feature map from CNN, flatten, project to embedding dim. | |||
""" | |||
def __init__(self, | |||
backbone, | |||
img_size=224, | |||
feature_size=None, | |||
in_chans=3, | |||
embed_dim=768): | |||
super().__init__() | |||
assert isinstance(backbone, nn.Module) | |||
img_size = to_2tuple(img_size) | |||
self.img_size = img_size | |||
self.backbone = backbone | |||
if feature_size is None: | |||
with torch.no_grad(): | |||
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature | |||
# map for all networks, the feature metadata has reliable channel and stride info, but using | |||
# stride to calc feature dim requires info about padding of each stage that isn't captured. | |||
training = backbone.training | |||
if training: | |||
backbone.eval() | |||
o = self.backbone( | |||
torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] | |||
feature_size = o.shape[-2:] | |||
feature_dim = o.shape[1] | |||
backbone.train(training) | |||
else: | |||
feature_size = to_2tuple(feature_size) | |||
feature_dim = self.backbone.feature_info.channels()[-1] | |||
self.num_patches = feature_size[0] * feature_size[1] | |||
self.proj = nn.Linear(feature_dim, embed_dim) | |||
def forward(self, x): | |||
x = self.backbone(x)[-1] | |||
x = x.flatten(2).transpose(1, 2) | |||
x = self.proj(x) | |||
return x | |||
class RelativePositionBias(nn.Module): | |||
def __init__(self, window_size, num_heads): | |||
super().__init__() | |||
self.window_size = window_size | |||
self.num_relative_distance = (2 * window_size[0] | |||
- 1) * (2 * window_size[1] - 1) + 3 | |||
self.relative_position_bias_table = nn.Parameter( | |||
torch.zeros(self.num_relative_distance, | |||
num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |||
# cls to token, token to cls, and cls to cls | |||
# get pair-wise relative position index for each token inside the window | |||
coords_h = torch.arange(window_size[0]) | |||
coords_w = torch.arange(window_size[1]) | |||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww | |||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||
relative_coords = coords_flatten[:, :, | |||
None] - coords_flatten[:, | |||
None, :] # 2, Wh*Ww, Wh*Ww | |||
relative_coords = relative_coords.permute( | |||
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 | |||
relative_coords[:, :, 1] += window_size[1] - 1 | |||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1 | |||
relative_position_index = \ | |||
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) | |||
relative_position_index[1:, | |||
1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww | |||
relative_position_index[0, 0:] = self.num_relative_distance - 3 | |||
relative_position_index[0:, 0] = self.num_relative_distance - 2 | |||
relative_position_index[0, 0] = self.num_relative_distance - 1 | |||
self.register_buffer('relative_position_index', | |||
relative_position_index) | |||
def forward(self): | |||
relative_position_bias = \ | |||
self.relative_position_bias_table[self.relative_position_index.view(-1)].view( | |||
self.window_size[0] * self.window_size[1] + 1, | |||
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH | |||
return relative_position_bias.permute( | |||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |||
@BACKBONES.register_module() | |||
class BASEBEiT(nn.Module): | |||
""" Vision Transformer with support for patch or hybrid CNN input stage | |||
""" | |||
def __init__(self, | |||
img_size=512, | |||
patch_size=16, | |||
in_chans=3, | |||
num_classes=80, | |||
embed_dim=768, | |||
depth=12, | |||
num_heads=12, | |||
mlp_ratio=4., | |||
qkv_bias=False, | |||
qk_scale=None, | |||
drop_rate=0., | |||
attn_drop_rate=0., | |||
drop_path_rate=0., | |||
hybrid_backbone=None, | |||
norm_layer=None, | |||
init_values=None, | |||
use_checkpoint=False, | |||
use_abs_pos_emb=False, | |||
use_rel_pos_bias=True, | |||
use_shared_rel_pos_bias=False, | |||
pretrained=None, | |||
with_cp=False): | |||
super().__init__() | |||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) | |||
self.norm_layer = norm_layer | |||
self.num_classes = num_classes | |||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models | |||
self.drop_path_rate = drop_path_rate | |||
if hybrid_backbone is not None: | |||
self.patch_embed = HybridEmbed( | |||
hybrid_backbone, | |||
img_size=img_size, | |||
in_chans=in_chans, | |||
embed_dim=embed_dim) | |||
else: | |||
self.patch_embed = PatchEmbed( | |||
img_size=img_size, | |||
patch_size=patch_size, | |||
in_chans=in_chans, | |||
embed_dim=embed_dim) | |||
num_patches = self.patch_embed.num_patches | |||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) | |||
if use_abs_pos_emb: | |||
self.pos_embed = nn.Parameter( | |||
torch.zeros(1, num_patches + 1, embed_dim)) | |||
else: | |||
self.pos_embed = None | |||
self.pos_drop = nn.Dropout(p=drop_rate) | |||
if use_shared_rel_pos_bias: | |||
self.rel_pos_bias = RelativePositionBias( | |||
window_size=self.patch_embed.patch_shape, num_heads=num_heads) | |||
else: | |||
self.rel_pos_bias = None | |||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) | |||
] # stochastic depth decay rule | |||
self.use_rel_pos_bias = use_rel_pos_bias | |||
self.use_checkpoint = use_checkpoint | |||
self.blocks = nn.ModuleList([ | |||
Block( | |||
dim=embed_dim, | |||
num_heads=num_heads, | |||
mlp_ratio=mlp_ratio, | |||
qkv_bias=qkv_bias, | |||
qk_scale=qk_scale, | |||
drop=drop_rate, | |||
attn_drop=attn_drop_rate, | |||
drop_path=dpr[i], | |||
norm_layer=norm_layer, | |||
with_cp=with_cp, | |||
init_values=init_values, | |||
window_size=self.patch_embed.patch_shape | |||
if use_rel_pos_bias else None) for i in range(depth) | |||
]) | |||
trunc_normal_(self.cls_token, std=.02) | |||
self.apply(self._init_weights) | |||
self.init_weights(pretrained) | |||
def init_weights(self, pretrained=None): | |||
"""Initialize the weights in backbone. | |||
Args: | |||
pretrained (str, optional): Path to pre-trained weights. | |||
Defaults to None. | |||
""" | |||
if isinstance(pretrained, str): | |||
logger = get_root_logger() | |||
init_cfg = dict(type='Pretrained', checkpoint=pretrained) | |||
checkpoint = _load_checkpoint( | |||
init_cfg['checkpoint'], logger=logger, map_location='cpu') | |||
state_dict = self.resize_rel_pos_embed(checkpoint) | |||
self.load_state_dict(state_dict, False) | |||
def fix_init_weight(self): | |||
def rescale(param, layer_id): | |||
param.div_(math.sqrt(2.0 * layer_id)) | |||
for layer_id, layer in enumerate(self.blocks): | |||
rescale(layer.attn.proj.weight.data, layer_id + 1) | |||
rescale(layer.mlp.fc2.weight.data, layer_id + 1) | |||
def _init_weights(self, m): | |||
if isinstance(m, nn.Linear): | |||
trunc_normal_(m.weight, std=.02) | |||
if isinstance(m, nn.Linear) and m.bias is not None: | |||
nn.init.constant_(m.bias, 0) | |||
elif isinstance(m, nn.LayerNorm): | |||
nn.init.constant_(m.bias, 0) | |||
nn.init.constant_(m.weight, 1.0) | |||
def get_num_layers(self): | |||
return len(self.blocks) |
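A small worked check of the relative-position-bias bookkeeping above, run on a toy 3x3 window so the index tensor stays readable. For the real model (e.g. a 512 input with patch 16), patch_shape is 32x32 and the table has (2*32-1)**2 + 3 = 3972 rows per head; the last three slots encode cls-to-token, token-to-cls and cls-to-cls:

import torch

Wh, Ww = 3, 3
num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3     # 25 token-token offsets + 3 cls slots

coords = torch.stack(torch.meshgrid(torch.arange(Wh), torch.arange(Ww)))   # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                                  # 2, Wh*Ww
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]              # 2, Wh*Ww, Wh*Ww
rel = rel.permute(1, 2, 0).contiguous()
rel[:, :, 0] += Wh - 1
rel[:, :, 1] += Ww - 1
rel[:, :, 0] *= 2 * Ww - 1

index = torch.zeros((Wh * Ww + 1,) * 2, dtype=rel.dtype)    # +1 row/col for the cls token
index[1:, 1:] = rel.sum(-1)
index[0, 0:] = num_relative_distance - 3
index[0:, 0] = num_relative_distance - 2
index[0, 0] = num_relative_distance - 1
print(index.shape)                                          # torch.Size([10, 10])
print(int(index.max()) == num_relative_distance - 1)        # True: every index hits the table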
@@ -0,0 +1,169 @@ | |||
# The implementation refers to the VitAdapter | |||
# available at | |||
# https://github.com/czczup/ViT-Adapter.git | |||
import logging | |||
import math | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from mmdet.models.builder import BACKBONES | |||
from mmdet.models.utils.transformer import MultiScaleDeformableAttention | |||
from timm.models.layers import DropPath, trunc_normal_ | |||
from torch.nn.init import normal_ | |||
from .adapter_modules import InteractionBlockWithCls as InteractionBlock | |||
from .adapter_modules import SpatialPriorModule, deform_inputs | |||
from .base.beit import BASEBEiT | |||
_logger = logging.getLogger(__name__) | |||
@BACKBONES.register_module() | |||
class BEiTAdapter(BASEBEiT): | |||
def __init__(self, | |||
pretrain_size=224, | |||
conv_inplane=64, | |||
n_points=4, | |||
deform_num_heads=6, | |||
init_values=0., | |||
cffn_ratio=0.25, | |||
deform_ratio=1.0, | |||
with_cffn=True, | |||
interaction_indexes=None, | |||
add_vit_feature=True, | |||
with_cp=False, | |||
*args, | |||
**kwargs): | |||
super().__init__( | |||
init_values=init_values, with_cp=with_cp, *args, **kwargs) | |||
self.num_block = len(self.blocks) | |||
self.pretrain_size = (pretrain_size, pretrain_size) | |||
self.flags = [ | |||
i for i in range(-1, self.num_block, self.num_block // 4) | |||
][1:] | |||
self.interaction_indexes = interaction_indexes | |||
self.add_vit_feature = add_vit_feature | |||
embed_dim = self.embed_dim | |||
self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) | |||
self.spm = SpatialPriorModule( | |||
inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False) | |||
self.interactions = nn.Sequential(*[ | |||
InteractionBlock( | |||
dim=embed_dim, | |||
num_heads=deform_num_heads, | |||
n_points=n_points, | |||
init_values=init_values, | |||
drop_path=self.drop_path_rate, | |||
norm_layer=self.norm_layer, | |||
with_cffn=with_cffn, | |||
cffn_ratio=cffn_ratio, | |||
deform_ratio=deform_ratio, | |||
extra_extractor=True if i == len(interaction_indexes) | |||
- 1 else False, | |||
with_cp=with_cp) for i in range(len(interaction_indexes)) | |||
]) | |||
self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) | |||
self.norm1 = nn.SyncBatchNorm(embed_dim) | |||
self.norm2 = nn.SyncBatchNorm(embed_dim) | |||
self.norm3 = nn.SyncBatchNorm(embed_dim) | |||
self.norm4 = nn.SyncBatchNorm(embed_dim) | |||
self.up.apply(self._init_weights) | |||
self.spm.apply(self._init_weights) | |||
self.interactions.apply(self._init_weights) | |||
self.apply(self._init_deform_weights) | |||
normal_(self.level_embed) | |||
def _init_weights(self, m): | |||
if isinstance(m, nn.Linear): | |||
trunc_normal_(m.weight, std=.02) | |||
if isinstance(m, nn.Linear) and m.bias is not None: | |||
nn.init.constant_(m.bias, 0) | |||
elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): | |||
nn.init.constant_(m.bias, 0) | |||
nn.init.constant_(m.weight, 1.0) | |||
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): | |||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |||
fan_out //= m.groups | |||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) | |||
if m.bias is not None: | |||
m.bias.data.zero_() | |||
def _get_pos_embed(self, pos_embed, H, W): | |||
pos_embed = pos_embed.reshape(1, self.pretrain_size[0] // 16, | |||
self.pretrain_size[1] // 16, | |||
-1).permute(0, 3, 1, 2) | |||
pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \ | |||
reshape(1, -1, H * W).permute(0, 2, 1) | |||
return pos_embed | |||
def _init_deform_weights(self, m): | |||
if isinstance(m, MultiScaleDeformableAttention): | |||
m.init_weights() | |||
def _add_level_embed(self, c2, c3, c4): | |||
c2 = c2 + self.level_embed[0] | |||
c3 = c3 + self.level_embed[1] | |||
c4 = c4 + self.level_embed[2] | |||
return c2, c3, c4 | |||
def forward(self, x): | |||
deform_inputs1, deform_inputs2 = deform_inputs(x) | |||
# SPM forward | |||
c1, c2, c3, c4 = self.spm(x) | |||
c2, c3, c4 = self._add_level_embed(c2, c3, c4) | |||
c = torch.cat([c2, c3, c4], dim=1) | |||
# Patch Embedding forward | |||
x, H, W = self.patch_embed(x) | |||
bs, n, dim = x.shape | |||
cls = self.cls_token.expand( | |||
bs, -1, -1) # stole cls_tokens impl from Phil Wang, thanks | |||
if self.pos_embed is not None: | |||
pos_embed = self._get_pos_embed(self.pos_embed, H, W) | |||
x = x + pos_embed | |||
x = self.pos_drop(x) | |||
# Interaction | |||
outs = list() | |||
for i, layer in enumerate(self.interactions): | |||
indexes = self.interaction_indexes[i] | |||
x, c, cls = layer(x, c, cls, | |||
self.blocks[indexes[0]:indexes[-1] + 1], | |||
deform_inputs1, deform_inputs2, H, W) | |||
outs.append(x.transpose(1, 2).view(bs, dim, H, W).contiguous()) | |||
# Split & Reshape | |||
c2 = c[:, 0:c2.size(1), :] | |||
c3 = c[:, c2.size(1):c2.size(1) + c3.size(1), :] | |||
c4 = c[:, c2.size(1) + c3.size(1):, :] | |||
c2 = c2.transpose(1, 2).view(bs, dim, H * 2, W * 2).contiguous() | |||
c3 = c3.transpose(1, 2).view(bs, dim, H, W).contiguous() | |||
c4 = c4.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous() | |||
c1 = self.up(c2) + c1 | |||
if self.add_vit_feature: | |||
x1, x2, x3, x4 = outs | |||
x1 = F.interpolate( | |||
x1, scale_factor=4, mode='bilinear', align_corners=False) | |||
x2 = F.interpolate( | |||
x2, scale_factor=2, mode='bilinear', align_corners=False) | |||
x4 = F.interpolate( | |||
x4, scale_factor=0.5, mode='bilinear', align_corners=False) | |||
c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 | |||
# Final Norm | |||
f1 = self.norm1(c1) | |||
f2 = self.norm2(c2) | |||
f3 = self.norm3(c3) | |||
f4 = self.norm4(c4) | |||
return [f1, f2, f3, f4] |
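For reference, the feature pyramid BEiTAdapter.forward returns for a 512x512 input with patch_size 16: four maps at strides 4, 8, 16 and 32, all with embed_dim channels. The shapes below are only a worked expectation for that example size (embed_dim 768 is illustrative, not a value fixed by this diff):

D = 768            # embed_dim; depends on the BEiT variant actually configured
H = W = 512 // 16  # 32x32 patch grid at stride 16

expected = {
    'f1': (1, D, H * 4, W * 4),    # stride 4:  c1 + up(c2) (+ upsampled x1)
    'f2': (1, D, H * 2, W * 2),    # stride 8
    'f3': (1, D, H, W),            # stride 16
    'f4': (1, D, H // 2, W // 2),  # stride 32
}
for name, shape in expected.items():
    print(name, shape)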
@@ -0,0 +1,3 @@ | |||
from .mask2former_head_from_mmseg import Mask2FormerHeadFromMMSeg | |||
__all__ = ['Mask2FormerHeadFromMMSeg'] |
@@ -0,0 +1,267 @@ | |||
# The implementation refers to the VitAdapter | |||
# available at | |||
# https://github.com/czczup/ViT-Adapter.git | |||
from abc import ABCMeta, abstractmethod | |||
import torch | |||
import torch.nn as nn | |||
from mmcv.runner import BaseModule, auto_fp16, force_fp32 | |||
from mmdet.models.builder import build_loss | |||
from mmdet.models.losses import accuracy | |||
from ...utils import build_pixel_sampler, seg_resize | |||
class BaseDecodeHead(BaseModule, metaclass=ABCMeta): | |||
"""Base class for BaseDecodeHead. | |||
Args: | |||
in_channels (int|Sequence[int]): Input channels. | |||
channels (int): Channels after modules, before conv_seg. | |||
num_classes (int): Number of classes. | |||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1. | |||
conv_cfg (dict|None): Config of conv layers. Default: None. | |||
norm_cfg (dict|None): Config of norm layers. Default: None. | |||
act_cfg (dict): Config of activation layers. | |||
Default: dict(type='ReLU') | |||
in_index (int|Sequence[int]): Input feature index. Default: -1 | |||
input_transform (str|None): Transformation type of input features. | |||
Options: 'resize_concat', 'multiple_select', None. | |||
'resize_concat': Multiple feature maps will be resized to the | |||
same size as the first one and then concatenated together. | |||
Usually used in the FCN head of HRNet. | |||
'multiple_select': Multiple feature maps will be bundled into | |||
a list and passed into the decode head. | |||
None: Only one select feature map is allowed. | |||
Default: None. | |||
loss_decode (dict | Sequence[dict]): Config of decode loss. | |||
The `loss_name` is a property of the corresponding loss function, | |||
and is shown in the training log. For a loss item to be included | |||
in the backward graph, its name must start with the prefix | |||
`loss_`. Defaults to 'loss_ce'. | |||
e.g. dict(type='CrossEntropyLoss'), | |||
[dict(type='CrossEntropyLoss', loss_name='loss_ce'), | |||
dict(type='DiceLoss', loss_name='loss_dice')] | |||
Default: dict(type='CrossEntropyLoss'). | |||
ignore_index (int | None): The label index to be ignored. When using | |||
masked BCE loss, ignore_index should be set to None. Default: 255. | |||
sampler (dict|None): The config of segmentation map sampler. | |||
Default: None. | |||
align_corners (bool): align_corners argument of F.interpolate. | |||
Default: False. | |||
init_cfg (dict or list[dict], optional): Initialization config dict. | |||
""" | |||
def __init__(self, | |||
in_channels, | |||
channels, | |||
*, | |||
num_classes, | |||
dropout_ratio=0.1, | |||
conv_cfg=None, | |||
norm_cfg=None, | |||
act_cfg=dict(type='ReLU'), | |||
in_index=-1, | |||
input_transform=None, | |||
loss_decode=dict( | |||
type='CrossEntropyLoss', | |||
use_sigmoid=False, | |||
loss_weight=1.0), | |||
ignore_index=255, | |||
sampler=None, | |||
align_corners=False, | |||
init_cfg=dict( | |||
type='Normal', std=0.01, override=dict(name='conv_seg'))): | |||
super(BaseDecodeHead, self).__init__(init_cfg) | |||
self._init_inputs(in_channels, in_index, input_transform) | |||
self.channels = channels | |||
self.num_classes = num_classes | |||
self.dropout_ratio = dropout_ratio | |||
self.conv_cfg = conv_cfg | |||
self.norm_cfg = norm_cfg | |||
self.act_cfg = act_cfg | |||
self.in_index = in_index | |||
self.ignore_index = ignore_index | |||
self.align_corners = align_corners | |||
if isinstance(loss_decode, dict): | |||
self.loss_decode = build_loss(loss_decode) | |||
elif isinstance(loss_decode, (list, tuple)): | |||
self.loss_decode = nn.ModuleList() | |||
for loss in loss_decode: | |||
self.loss_decode.append(build_loss(loss)) | |||
else: | |||
raise TypeError(f'loss_decode must be a dict or sequence of dict,\ | |||
but got {type(loss_decode)}') | |||
if sampler is not None: | |||
self.sampler = build_pixel_sampler(sampler, context=self) | |||
else: | |||
self.sampler = None | |||
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) | |||
if dropout_ratio > 0: | |||
self.dropout = nn.Dropout2d(dropout_ratio) | |||
else: | |||
self.dropout = None | |||
self.fp16_enabled = False | |||
def extra_repr(self): | |||
"""Extra repr.""" | |||
s = f'input_transform={self.input_transform}, ' \ | |||
f'ignore_index={self.ignore_index}, ' \ | |||
f'align_corners={self.align_corners}' | |||
return s | |||
def _init_inputs(self, in_channels, in_index, input_transform): | |||
"""Check and initialize input transforms. | |||
The in_channels, in_index and input_transform must match. | |||
Specifically, when input_transform is None, only a single feature map | |||
will be selected, so in_channels and in_index must be of type int. | |||
When input_transform is not None, in_channels and in_index must be | |||
sequences (list or tuple) of the same length. | |||
Args: | |||
in_channels (int|Sequence[int]): Input channels. | |||
in_index (int|Sequence[int]): Input feature index. | |||
input_transform (str|None): Transformation type of input features. | |||
Options: 'resize_concat', 'multiple_select', None. | |||
'resize_concat': Multiple feature maps will be resized to the | |||
same size as the first one and then concatenated together. | |||
Usually used in the FCN head of HRNet. | |||
'multiple_select': Multiple feature maps will be bundled into | |||
a list and passed into the decode head. | |||
None: Only one select feature map is allowed. | |||
""" | |||
if input_transform is not None: | |||
assert input_transform in ['resize_concat', 'multiple_select'] | |||
self.input_transform = input_transform | |||
self.in_index = in_index | |||
if input_transform is not None: | |||
assert isinstance(in_channels, (list, tuple)) | |||
assert isinstance(in_index, (list, tuple)) | |||
assert len(in_channels) == len(in_index) | |||
if input_transform == 'resize_concat': | |||
self.in_channels = sum(in_channels) | |||
else: | |||
self.in_channels = in_channels | |||
else: | |||
assert isinstance(in_channels, int) | |||
assert isinstance(in_index, int) | |||
self.in_channels = in_channels | |||
def _transform_inputs(self, inputs): | |||
"""Transform inputs for decoder. | |||
Args: | |||
inputs (list[Tensor]): List of multi-level img features. | |||
Returns: | |||
Tensor: The transformed inputs | |||
""" | |||
if self.input_transform == 'resize_concat': | |||
inputs = [inputs[i] for i in self.in_index] | |||
upsampled_inputs = [ | |||
seg_resize( | |||
input=x, | |||
size=inputs[0].shape[2:], | |||
mode='bilinear', | |||
align_corners=self.align_corners) for x in inputs | |||
] | |||
inputs = torch.cat(upsampled_inputs, dim=1) | |||
elif self.input_transform == 'multiple_select': | |||
inputs = [inputs[i] for i in self.in_index] | |||
else: | |||
inputs = inputs[self.in_index] | |||
return inputs | |||
@auto_fp16() | |||
@abstractmethod | |||
def forward(self, inputs): | |||
"""Placeholder of forward function.""" | |||
pass | |||
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): | |||
"""Forward function for training. | |||
Args: | |||
inputs (list[Tensor]): List of multi-level img features. | |||
img_metas (list[dict]): List of image info dict where each dict | |||
has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
For details on the values of these keys see | |||
`mmseg/datasets/pipelines/formatting.py:Collect`. | |||
gt_semantic_seg (Tensor): Semantic segmentation masks | |||
used if the architecture supports semantic segmentation task. | |||
train_cfg (dict): The training config. | |||
Returns: | |||
dict[str, Tensor]: a dictionary of loss components | |||
""" | |||
seg_logits = self.forward(inputs) | |||
losses = self.losses(seg_logits, gt_semantic_seg) | |||
return losses | |||
def forward_test(self, inputs, img_metas, test_cfg): | |||
"""Forward function for testing. | |||
Args: | |||
inputs (list[Tensor]): List of multi-level img features. | |||
img_metas (list[dict]): List of image info dict where each dict | |||
has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
For details on the values of these keys see | |||
`mmseg/datasets/pipelines/formatting.py:Collect`. | |||
test_cfg (dict): The testing config. | |||
Returns: | |||
Tensor: Output segmentation map. | |||
""" | |||
return self.forward(inputs) | |||
def cls_seg(self, feat): | |||
"""Classify each pixel.""" | |||
if self.dropout is not None: | |||
feat = self.dropout(feat) | |||
output = self.conv_seg(feat) | |||
return output | |||
@force_fp32(apply_to=('seg_logit', )) | |||
def losses(self, seg_logit, seg_label): | |||
"""Compute segmentation loss.""" | |||
loss = dict() | |||
seg_logit = seg_resize( | |||
input=seg_logit, | |||
size=seg_label.shape[2:], | |||
mode='bilinear', | |||
align_corners=self.align_corners) | |||
if self.sampler is not None: | |||
seg_weight = self.sampler.sample(seg_logit, seg_label) | |||
else: | |||
seg_weight = None | |||
seg_label = seg_label.squeeze(1) | |||
if not isinstance(self.loss_decode, nn.ModuleList): | |||
losses_decode = [self.loss_decode] | |||
else: | |||
losses_decode = self.loss_decode | |||
for loss_decode in losses_decode: | |||
if loss_decode.loss_name not in loss: | |||
loss[loss_decode.loss_name] = loss_decode( | |||
seg_logit, | |||
seg_label, | |||
weight=seg_weight, | |||
ignore_index=self.ignore_index) | |||
else: | |||
loss[loss_decode.loss_name] += loss_decode( | |||
seg_logit, | |||
seg_label, | |||
weight=seg_weight, | |||
ignore_index=self.ignore_index) | |||
loss['acc_seg'] = accuracy( | |||
seg_logit, seg_label, ignore_index=self.ignore_index) | |||
return loss |
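To make the two input_transform modes of _transform_inputs concrete, a standalone toy reproduction with random feature maps; seg_resize is stood in for by F.interpolate here, since the util itself is not part of this hunk:

import torch
import torch.nn.functional as F

feats = [torch.randn(1, 16, 64, 64),   # e.g. stride 8
         torch.randn(1, 32, 32, 32),   # e.g. stride 16
         torch.randn(1, 64, 16, 16)]   # e.g. stride 32
in_index = [0, 1, 2]

# 'resize_concat': upsample every selected map to the first one's size, concat channels
selected = [feats[i] for i in in_index]
resized = [F.interpolate(x, size=selected[0].shape[2:], mode='bilinear',
                         align_corners=False) for x in selected]
print(torch.cat(resized, dim=1).shape)       # torch.Size([1, 112, 64, 64])

# 'multiple_select': hand the selected maps to the head as a list
print([tuple(feats[i].shape) for i in in_index])

# input_transform=None: a single map picked by an int in_index
print(tuple(feats[-1].shape))                # (1, 64, 16, 16)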
@@ -0,0 +1,581 @@ | |||
# The implementation refers to the VitAdapter | |||
# available at | |||
# https://github.com/czczup/ViT-Adapter.git | |||
import copy | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init | |||
from mmcv.cnn.bricks.transformer import (build_positional_encoding, | |||
build_transformer_layer_sequence) | |||
from mmcv.ops import point_sample | |||
from mmcv.runner import ModuleList, force_fp32 | |||
from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean | |||
from mmdet.models.builder import HEADS, build_loss | |||
from mmdet.models.utils import get_uncertain_point_coords_with_randomness | |||
from .base_decode_head import BaseDecodeHead | |||
@HEADS.register_module() | |||
class Mask2FormerHeadFromMMSeg(BaseDecodeHead): | |||
"""Implements the Mask2Former head. | |||
See `Masked-attention Mask Transformer for Universal Image | |||
Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details. | |||
Args: | |||
in_channels (list[int]): Number of channels in the input feature map. | |||
feat_channels (int): Number of channels for features. | |||
out_channels (int): Number of channels for output. | |||
num_things_classes (int): Number of thing classes. | |||
num_stuff_classes (int): Number of stuff classes. | |||
num_queries (int): Number of queries in the Transformer decoder. | |||
pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel | |||
decoder. Defaults to None. | |||
enforce_decoder_input_project (bool, optional): Whether to add | |||
a layer to change the embed_dim of the transformer encoder in | |||
pixel decoder to the embed_dim of transformer decoder. | |||
Defaults to False. | |||
transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for | |||
transformer decoder. Defaults to None. | |||
positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for | |||
transformer decoder position encoding. Defaults to None. | |||
loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification | |||
loss. Defaults to None. | |||
loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss. | |||
Defaults to None. | |||
loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss. | |||
Defaults to None. | |||
train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of | |||
Mask2Former head. | |||
test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of | |||
Mask2Former head. | |||
init_cfg (dict or list[dict], optional): Initialization config dict. | |||
Defaults to None. | |||
""" | |||
def __init__(self, | |||
in_channels, | |||
feat_channels, | |||
out_channels, | |||
num_things_classes=80, | |||
num_stuff_classes=53, | |||
num_queries=100, | |||
num_transformer_feat_level=3, | |||
pixel_decoder=None, | |||
enforce_decoder_input_project=False, | |||
transformer_decoder=None, | |||
positional_encoding=None, | |||
loss_cls=None, | |||
loss_mask=None, | |||
loss_dice=None, | |||
train_cfg=None, | |||
test_cfg=None, | |||
init_cfg=None, | |||
**kwargs): | |||
super(Mask2FormerHeadFromMMSeg, self).__init__( | |||
in_channels=in_channels, | |||
channels=feat_channels, | |||
num_classes=(num_things_classes + num_stuff_classes), | |||
init_cfg=init_cfg, | |||
input_transform='multiple_select', | |||
**kwargs) | |||
self.num_things_classes = num_things_classes | |||
self.num_stuff_classes = num_stuff_classes | |||
self.num_classes = self.num_things_classes + self.num_stuff_classes | |||
self.num_queries = num_queries | |||
self.num_transformer_feat_level = num_transformer_feat_level | |||
self.num_heads = transformer_decoder.transformerlayers. \ | |||
attn_cfgs.num_heads | |||
self.num_transformer_decoder_layers = transformer_decoder.num_layers | |||
assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level | |||
pixel_decoder_ = copy.deepcopy(pixel_decoder) | |||
pixel_decoder_.update( | |||
in_channels=in_channels, | |||
feat_channels=feat_channels, | |||
out_channels=out_channels) | |||
self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1] | |||
self.transformer_decoder = build_transformer_layer_sequence( | |||
transformer_decoder) | |||
self.decoder_embed_dims = self.transformer_decoder.embed_dims | |||
self.decoder_input_projs = ModuleList() | |||
# from low resolution to high resolution | |||
for _ in range(num_transformer_feat_level): | |||
if (self.decoder_embed_dims != feat_channels | |||
or enforce_decoder_input_project): | |||
self.decoder_input_projs.append( | |||
Conv2d( | |||
feat_channels, self.decoder_embed_dims, kernel_size=1)) | |||
else: | |||
self.decoder_input_projs.append(nn.Identity()) | |||
self.decoder_positional_encoding = build_positional_encoding( | |||
positional_encoding) | |||
self.query_embed = nn.Embedding(self.num_queries, feat_channels) | |||
self.query_feat = nn.Embedding(self.num_queries, feat_channels) | |||
# from low resolution to high resolution | |||
self.level_embed = nn.Embedding(self.num_transformer_feat_level, | |||
feat_channels) | |||
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) | |||
self.mask_embed = nn.Sequential( | |||
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), | |||
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), | |||
nn.Linear(feat_channels, out_channels)) | |||
        self.conv_seg = None  # conv_seg from the base class is not used by this head | |||
self.test_cfg = test_cfg | |||
self.train_cfg = train_cfg | |||
if train_cfg: | |||
self.assigner = build_assigner(self.train_cfg.assigner) | |||
self.sampler = build_sampler(self.train_cfg.sampler, context=self) | |||
self.num_points = self.train_cfg.get('num_points', 12544) | |||
self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) | |||
self.importance_sample_ratio = self.train_cfg.get( | |||
'importance_sample_ratio', 0.75) | |||
self.class_weight = loss_cls.class_weight | |||
self.loss_cls = build_loss(loss_cls) | |||
self.loss_mask = build_loss(loss_mask) | |||
self.loss_dice = build_loss(loss_dice) | |||
def init_weights(self): | |||
for m in self.decoder_input_projs: | |||
if isinstance(m, Conv2d): | |||
caffe2_xavier_init(m, bias=0) | |||
self.pixel_decoder.init_weights() | |||
for p in self.transformer_decoder.parameters(): | |||
if p.dim() > 1: | |||
nn.init.xavier_normal_(p) | |||
def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, | |||
gt_masks_list, img_metas): | |||
"""Compute classification and mask targets for all images for a decoder | |||
layer. | |||
Args: | |||
cls_scores_list (list[Tensor]): Mask score logits from a single | |||
decoder layer for all images. Each with shape [num_queries, | |||
cls_out_channels]. | |||
mask_preds_list (list[Tensor]): Mask logits from a single decoder | |||
layer for all images. Each with shape [num_queries, h, w]. | |||
gt_labels_list (list[Tensor]): Ground truth class indices for all | |||
                images. Each with shape (n, ), where n is the sum of the | |||
                number of stuff types and the number of instances in an image. | |||
gt_masks_list (list[Tensor]): Ground truth mask for each image, | |||
each with shape (n, h, w). | |||
img_metas (list[dict]): List of image meta information. | |||
Returns: | |||
tuple[list[Tensor]]: a tuple containing the following targets. | |||
- labels_list (list[Tensor]): Labels of all images. | |||
Each with shape [num_queries, ]. | |||
- label_weights_list (list[Tensor]): Label weights of all | |||
                images. Each with shape [num_queries, ]. | |||
- mask_targets_list (list[Tensor]): Mask targets of all images. | |||
Each with shape [num_queries, h, w]. | |||
- mask_weights_list (list[Tensor]): Mask weights of all images. | |||
Each with shape [num_queries, ]. | |||
- num_total_pos (int): Number of positive samples in all | |||
images. | |||
- num_total_neg (int): Number of negative samples in all | |||
images. | |||
""" | |||
(labels_list, label_weights_list, mask_targets_list, mask_weights_list, | |||
pos_inds_list, | |||
neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list, | |||
mask_preds_list, gt_labels_list, | |||
gt_masks_list, img_metas) | |||
num_total_pos = sum((inds.numel() for inds in pos_inds_list)) | |||
num_total_neg = sum((inds.numel() for inds in neg_inds_list)) | |||
return (labels_list, label_weights_list, mask_targets_list, | |||
mask_weights_list, num_total_pos, num_total_neg) | |||
def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, | |||
img_metas): | |||
"""Compute classification and mask targets for one image. | |||
Args: | |||
cls_score (Tensor): Mask score logits from a single decoder layer | |||
for one image. Shape (num_queries, cls_out_channels). | |||
mask_pred (Tensor): Mask logits for a single decoder layer for one | |||
image. Shape (num_queries, h, w). | |||
gt_labels (Tensor): Ground truth class indices for one image with | |||
shape (num_gts, ). | |||
gt_masks (Tensor): Ground truth mask for each image, each with | |||
shape (num_gts, h, w). | |||
            img_metas (dict): Image information. | |||
Returns: | |||
tuple[Tensor]: A tuple containing the following for one image. | |||
- labels (Tensor): Labels of each image. \ | |||
shape (num_queries, ). | |||
- label_weights (Tensor): Label weights of each image. \ | |||
shape (num_queries, ). | |||
- mask_targets (Tensor): Mask targets of each image. \ | |||
shape (num_queries, h, w). | |||
- mask_weights (Tensor): Mask weights of each image. \ | |||
shape (num_queries, ). | |||
- pos_inds (Tensor): Sampled positive indices for each \ | |||
image. | |||
- neg_inds (Tensor): Sampled negative indices for each \ | |||
image. | |||
""" | |||
# sample points | |||
num_queries = cls_score.shape[0] | |||
num_gts = gt_labels.shape[0] | |||
point_coords = torch.rand((1, self.num_points, 2), | |||
device=cls_score.device) | |||
# shape (num_queries, num_points) | |||
mask_points_pred = point_sample( | |||
mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, | |||
1)).squeeze(1) | |||
# shape (num_gts, num_points) | |||
gt_points_masks = point_sample( | |||
gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, | |||
1)).squeeze(1) | |||
# assign and sample | |||
assign_result = self.assigner.assign(cls_score, mask_points_pred, | |||
gt_labels, gt_points_masks, | |||
img_metas) | |||
sampling_result = self.sampler.sample(assign_result, mask_pred, | |||
gt_masks) | |||
pos_inds = sampling_result.pos_inds | |||
neg_inds = sampling_result.neg_inds | |||
# label target | |||
labels = gt_labels.new_full((self.num_queries, ), | |||
self.num_classes, | |||
dtype=torch.long) | |||
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] | |||
label_weights = gt_labels.new_ones((self.num_queries, )) | |||
# mask target | |||
mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] | |||
mask_weights = mask_pred.new_zeros((self.num_queries, )) | |||
mask_weights[pos_inds] = 1.0 | |||
return (labels, label_weights, mask_targets, mask_weights, pos_inds, | |||
neg_inds) | |||
def loss_single(self, cls_scores, mask_preds, gt_labels_list, | |||
gt_masks_list, img_metas): | |||
"""Loss function for outputs from a single decoder layer. | |||
Args: | |||
cls_scores (Tensor): Mask score logits from a single decoder layer | |||
for all images. Shape (batch_size, num_queries, | |||
                cls_out_channels). Note `cls_out_channels` should include | |||
background. | |||
mask_preds (Tensor): Mask logits for a pixel decoder for all | |||
images. Shape (batch_size, num_queries, h, w). | |||
gt_labels_list (list[Tensor]): Ground truth class indices for each | |||
image, each with shape (num_gts, ). | |||
gt_masks_list (list[Tensor]): Ground truth mask for each image, | |||
each with shape (num_gts, h, w). | |||
img_metas (list[dict]): List of image meta information. | |||
Returns: | |||
tuple[Tensor]: Loss components for outputs from a single \ | |||
decoder layer. | |||
""" | |||
num_imgs = cls_scores.size(0) | |||
cls_scores_list = [cls_scores[i] for i in range(num_imgs)] | |||
mask_preds_list = [mask_preds[i] for i in range(num_imgs)] | |||
(labels_list, label_weights_list, mask_targets_list, mask_weights_list, | |||
num_total_pos, | |||
num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list, | |||
gt_labels_list, gt_masks_list, | |||
img_metas) | |||
# shape (batch_size, num_queries) | |||
labels = torch.stack(labels_list, dim=0) | |||
# shape (batch_size, num_queries) | |||
label_weights = torch.stack(label_weights_list, dim=0) | |||
# shape (num_total_gts, h, w) | |||
mask_targets = torch.cat(mask_targets_list, dim=0) | |||
# shape (batch_size, num_queries) | |||
mask_weights = torch.stack(mask_weights_list, dim=0) | |||
        # classification loss | |||
# shape (batch_size * num_queries, ) | |||
cls_scores = cls_scores.flatten(0, 1) | |||
labels = labels.flatten(0, 1) | |||
label_weights = label_weights.flatten(0, 1) | |||
class_weight = cls_scores.new_tensor(self.class_weight) | |||
loss_cls = self.loss_cls( | |||
cls_scores, | |||
labels, | |||
label_weights, | |||
avg_factor=class_weight[labels].sum()) | |||
num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) | |||
num_total_masks = max(num_total_masks, 1) | |||
# extract positive ones | |||
# shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) | |||
mask_preds = mask_preds[mask_weights > 0] | |||
if mask_targets.shape[0] == 0: | |||
# zero match | |||
loss_dice = mask_preds.sum() | |||
loss_mask = mask_preds.sum() | |||
return loss_cls, loss_mask, loss_dice | |||
with torch.no_grad(): | |||
points_coords = get_uncertain_point_coords_with_randomness( | |||
mask_preds.unsqueeze(1), None, self.num_points, | |||
self.oversample_ratio, self.importance_sample_ratio) | |||
# shape (num_total_gts, h, w) -> (num_total_gts, num_points) | |||
mask_point_targets = point_sample( | |||
mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) | |||
# shape (num_queries, h, w) -> (num_queries, num_points) | |||
mask_point_preds = point_sample( | |||
mask_preds.unsqueeze(1), points_coords).squeeze(1) | |||
# dice loss | |||
loss_dice = self.loss_dice( | |||
mask_point_preds, mask_point_targets, avg_factor=num_total_masks) | |||
# mask loss | |||
# shape (num_queries, num_points) -> (num_queries * num_points, ) | |||
mask_point_preds = mask_point_preds.reshape(-1, 1) | |||
# shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) | |||
mask_point_targets = mask_point_targets.reshape(-1) | |||
loss_mask = self.loss_mask( | |||
mask_point_preds, | |||
mask_point_targets, | |||
avg_factor=num_total_masks * self.num_points) | |||
return loss_cls, loss_mask, loss_dice | |||
@force_fp32(apply_to=('all_cls_scores', 'all_mask_preds')) | |||
def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, | |||
gt_masks_list, img_metas): | |||
"""Loss function. | |||
Args: | |||
all_cls_scores (Tensor): Classification scores for all decoder | |||
layers with shape [num_decoder, batch_size, num_queries, | |||
cls_out_channels]. | |||
all_mask_preds (Tensor): Mask scores for all decoder layers with | |||
shape [num_decoder, batch_size, num_queries, h, w]. | |||
gt_labels_list (list[Tensor]): Ground truth class indices for each | |||
                image with shape (n, ), where n is the sum of the number of | |||
                stuff types and the number of instances in an image. | |||
gt_masks_list (list[Tensor]): Ground truth mask for each image with | |||
shape (n, h, w). | |||
img_metas (list[dict]): List of image meta information. | |||
Returns: | |||
dict[str, Tensor]: A dictionary of loss components. | |||
""" | |||
num_dec_layers = len(all_cls_scores) | |||
all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] | |||
all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] | |||
img_metas_list = [img_metas for _ in range(num_dec_layers)] | |||
losses_cls, losses_mask, losses_dice = multi_apply( | |||
self.loss_single, all_cls_scores, all_mask_preds, | |||
all_gt_labels_list, all_gt_masks_list, img_metas_list) | |||
loss_dict = dict() | |||
# loss from the last decoder layer | |||
loss_dict['loss_cls'] = losses_cls[-1] | |||
loss_dict['loss_mask'] = losses_mask[-1] | |||
loss_dict['loss_dice'] = losses_dice[-1] | |||
# loss from other decoder layers | |||
num_dec_layer = 0 | |||
for loss_cls_i, loss_mask_i, loss_dice_i in zip( | |||
losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): | |||
loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i | |||
loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i | |||
loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i | |||
num_dec_layer += 1 | |||
return loss_dict | |||
def forward_head(self, decoder_out, mask_feature, attn_mask_target_size): | |||
"""Forward for head part which is called after every decoder layer. | |||
Args: | |||
decoder_out (Tensor): in shape (num_queries, batch_size, c). | |||
mask_feature (Tensor): in shape (batch_size, c, h, w). | |||
attn_mask_target_size (tuple[int, int]): target attention | |||
mask size. | |||
Returns: | |||
            tuple: A tuple containing three elements. | |||
- cls_pred (Tensor): Classification scores in shape \ | |||
(batch_size, num_queries, cls_out_channels). \ | |||
                    Note `cls_out_channels` should include background. | |||
- mask_pred (Tensor): Mask scores in shape \ | |||
                (batch_size, num_queries, h, w). | |||
- attn_mask (Tensor): Attention mask in shape \ | |||
(batch_size * num_heads, num_queries, h, w). | |||
""" | |||
decoder_out = self.transformer_decoder.post_norm(decoder_out) | |||
decoder_out = decoder_out.transpose(0, 1) | |||
        # shape (batch_size, num_queries, c) | |||
cls_pred = self.cls_embed(decoder_out) | |||
        # shape (batch_size, num_queries, c) | |||
mask_embed = self.mask_embed(decoder_out) | |||
        # shape (batch_size, num_queries, h, w) | |||
mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) | |||
attn_mask = F.interpolate( | |||
mask_pred, | |||
attn_mask_target_size, | |||
mode='bilinear', | |||
align_corners=False) | |||
        # shape (batch_size, num_queries, h, w) -> | |||
        # (batch_size * num_heads, num_queries, h * w) | |||
attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( | |||
(1, self.num_heads, 1, 1)).flatten(0, 1) | |||
attn_mask = attn_mask.sigmoid() < 0.5 | |||
attn_mask = attn_mask.detach() | |||
return cls_pred, mask_pred, attn_mask | |||
def forward(self, feats, img_metas): | |||
"""Forward function. | |||
Args: | |||
            feats (list[Tensor]): Multi-scale features from the | |||
upstream network, each is a 4D-tensor. | |||
img_metas (list[dict]): List of image information. | |||
Returns: | |||
            tuple: A tuple containing two elements. | |||
- cls_pred_list (list[Tensor)]: Classification logits \ | |||
for each decoder layer. Each is a 3D-tensor with shape \ | |||
(batch_size, num_queries, cls_out_channels). \ | |||
                Note `cls_out_channels` should include background. | |||
- mask_pred_list (list[Tensor]): Mask logits for each \ | |||
decoder layer. Each with shape (batch_size, num_queries, \ | |||
h, w). | |||
""" | |||
batch_size = len(img_metas) | |||
mask_features, multi_scale_memorys = self.pixel_decoder(feats) | |||
# multi_scale_memorys (from low resolution to high resolution) | |||
decoder_inputs = [] | |||
decoder_positional_encodings = [] | |||
for i in range(self.num_transformer_feat_level): | |||
decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) | |||
# shape (batch_size, c, h, w) -> (h*w, batch_size, c) | |||
decoder_input = decoder_input.flatten(2).permute(2, 0, 1) | |||
level_embed = self.level_embed.weight[i].view(1, 1, -1) | |||
decoder_input = decoder_input + level_embed | |||
# shape (batch_size, c, h, w) -> (h*w, batch_size, c) | |||
mask = decoder_input.new_zeros( | |||
(batch_size, ) + multi_scale_memorys[i].shape[-2:], | |||
dtype=torch.bool) | |||
decoder_positional_encoding = self.decoder_positional_encoding( | |||
mask) | |||
decoder_positional_encoding = decoder_positional_encoding.flatten( | |||
2).permute(2, 0, 1) | |||
decoder_inputs.append(decoder_input) | |||
decoder_positional_encodings.append(decoder_positional_encoding) | |||
# shape (num_queries, c) -> (num_queries, batch_size, c) | |||
query_feat = self.query_feat.weight.unsqueeze(1).repeat( | |||
(1, batch_size, 1)) | |||
query_embed = self.query_embed.weight.unsqueeze(1).repeat( | |||
(1, batch_size, 1)) | |||
cls_pred_list = [] | |||
mask_pred_list = [] | |||
cls_pred, mask_pred, attn_mask = self.forward_head( | |||
query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) | |||
cls_pred_list.append(cls_pred) | |||
mask_pred_list.append(mask_pred) | |||
for i in range(self.num_transformer_decoder_layers): | |||
level_idx = i % self.num_transformer_feat_level | |||
            # if a mask is all True (all background), then set it all False. | |||
attn_mask[torch.where( | |||
attn_mask.sum(-1) == attn_mask.shape[-1])] = False | |||
# cross_attn + self_attn | |||
layer = self.transformer_decoder.layers[i] | |||
attn_masks = [attn_mask, None] | |||
query_feat = layer( | |||
query=query_feat, | |||
key=decoder_inputs[level_idx], | |||
value=decoder_inputs[level_idx], | |||
query_pos=query_embed, | |||
key_pos=decoder_positional_encodings[level_idx], | |||
attn_masks=attn_masks, | |||
query_key_padding_mask=None, | |||
# here we do not apply masking on padded region | |||
key_padding_mask=None) | |||
cls_pred, mask_pred, attn_mask = self.forward_head( | |||
query_feat, mask_features, multi_scale_memorys[ | |||
(i + 1) % self.num_transformer_feat_level].shape[-2:]) | |||
cls_pred_list.append(cls_pred) | |||
mask_pred_list.append(mask_pred) | |||
return cls_pred_list, mask_pred_list | |||
def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, | |||
gt_masks): | |||
"""Forward function for training mode. | |||
Args: | |||
x (list[Tensor]): Multi-level features from the upstream network, | |||
each is a 4D-tensor. | |||
img_metas (list[Dict]): List of image information. | |||
            gt_semantic_seg (list[Tensor]): Each element is the ground truth | |||
                of semantic segmentation with the shape (N, H, W). | |||
gt_labels (list[Tensor]): Each element is ground truth labels of | |||
each box, shape (num_gts,). | |||
gt_masks (list[BitmapMasks]): Each element is masks of instances | |||
                of an image, shape (num_gts, h, w). | |||
Returns: | |||
losses (dict[str, Tensor]): a dictionary of loss components | |||
""" | |||
# forward | |||
all_cls_scores, all_mask_preds = self(x, img_metas) | |||
# loss | |||
losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, | |||
img_metas) | |||
return losses | |||
def forward_test(self, inputs, img_metas, test_cfg): | |||
"""Test segment without test-time aumengtation. | |||
Only the output of last decoder layers was used. | |||
Args: | |||
inputs (list[Tensor]): Multi-level features from the | |||
upstream network, each is a 4D-tensor. | |||
img_metas (list[dict]): List of image information. | |||
test_cfg (dict): Testing config. | |||
Returns: | |||
seg_mask (Tensor): Predicted semantic segmentation logits. | |||
""" | |||
all_cls_scores, all_mask_preds = self(inputs, img_metas) | |||
cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1] | |||
ori_h, ori_w, _ = img_metas[0]['ori_shape'] | |||
# semantic inference | |||
cls_score = F.softmax(cls_score, dim=-1)[..., :-1] | |||
mask_pred = mask_pred.sigmoid() | |||
seg_mask = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred) | |||
return seg_mask |
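In `forward_test` above, the final einsum fuses per-query class probabilities with per-query mask probabilities into per-class segmentation logits. A shape-only sketch with random tensors (batch, query, class and spatial sizes are arbitrary):

# Shape-only sketch of the semantic-inference einsum; all sizes are arbitrary.
import torch
import torch.nn.functional as F

batch, num_queries, num_classes, h, w = 2, 100, 171, 32, 32
cls_score = F.softmax(torch.randn(batch, num_queries, num_classes + 1), dim=-1)[..., :-1]
mask_pred = torch.randn(batch, num_queries, h, w).sigmoid()
seg_mask = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
print(seg_mask.shape)  # torch.Size([2, 171, 32, 32])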
@@ -0,0 +1,3 @@ | |||
from .encoder_decoder_mask2former import EncoderDecoderMask2Former | |||
__all__ = ['EncoderDecoderMask2Former'] |
@@ -0,0 +1,314 @@ | |||
# The implementation refers to ViT-Adapter, | |||
# available at | |||
# https://github.com/czczup/ViT-Adapter.git | |||
import warnings | |||
from abc import ABCMeta, abstractmethod | |||
from collections import OrderedDict | |||
import mmcv | |||
import numpy as np | |||
import torch | |||
import torch.distributed as dist | |||
from mmcv.runner import BaseModule, auto_fp16 | |||
class BaseSegmentor(BaseModule, metaclass=ABCMeta): | |||
"""Base class for segmentors.""" | |||
def __init__(self, init_cfg=None): | |||
super(BaseSegmentor, self).__init__(init_cfg) | |||
self.fp16_enabled = False | |||
@property | |||
def with_neck(self): | |||
"""bool: whether the segmentor has neck""" | |||
return hasattr(self, 'neck') and self.neck is not None | |||
@property | |||
def with_auxiliary_head(self): | |||
"""bool: whether the segmentor has auxiliary head""" | |||
return hasattr(self, | |||
'auxiliary_head') and self.auxiliary_head is not None | |||
@property | |||
def with_decode_head(self): | |||
"""bool: whether the segmentor has decode head""" | |||
return hasattr(self, 'decode_head') and self.decode_head is not None | |||
@abstractmethod | |||
def extract_feat(self, imgs): | |||
"""Placeholder for extract features from images.""" | |||
pass | |||
@abstractmethod | |||
def encode_decode(self, img, img_metas): | |||
"""Placeholder for encode images with backbone and decode into a | |||
semantic segmentation map of the same size as input.""" | |||
pass | |||
@abstractmethod | |||
def forward_train(self, imgs, img_metas, **kwargs): | |||
"""Placeholder for Forward function for training.""" | |||
pass | |||
@abstractmethod | |||
def simple_test(self, img, img_meta, **kwargs): | |||
"""Placeholder for single image test.""" | |||
pass | |||
@abstractmethod | |||
def aug_test(self, imgs, img_metas, **kwargs): | |||
"""Placeholder for augmentation test.""" | |||
pass | |||
def forward_test(self, imgs, img_metas, **kwargs): | |||
""" | |||
Args: | |||
imgs (List[Tensor]): the outer list indicates test-time | |||
augmentations and inner Tensor should have a shape NxCxHxW, | |||
which contains all images in the batch. | |||
img_metas (List[List[dict]]): the outer list indicates test-time | |||
augs (multiscale, flip, etc.) and the inner list indicates | |||
images in a batch. | |||
""" | |||
for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: | |||
if not isinstance(var, list): | |||
raise TypeError(f'{name} must be a list, but got ' | |||
f'{type(var)}') | |||
num_augs = len(imgs) | |||
if num_augs != len(img_metas): | |||
raise ValueError(f'num of augmentations ({len(imgs)}) != ' | |||
f'num of image meta ({len(img_metas)})') | |||
        # all images in the same aug batch should have the same ori_shape and | |||
        # pad_shape | |||
def tensor_to_tuple(input_tensor): | |||
return tuple(input_tensor.cpu().numpy()) | |||
for img_meta in img_metas: | |||
ori_shapes = [_['ori_shape'] for _ in img_meta] | |||
if isinstance(ori_shapes[0], torch.Tensor): | |||
assert all( | |||
tensor_to_tuple(shape) == tensor_to_tuple(ori_shapes[0]) | |||
for shape in ori_shapes) | |||
else: | |||
assert all(shape == ori_shapes[0] for shape in ori_shapes) | |||
img_shapes = [_['img_shape'] for _ in img_meta] | |||
if isinstance(img_shapes[0], torch.Tensor): | |||
assert all( | |||
tensor_to_tuple(shape) == tensor_to_tuple(img_shapes[0]) | |||
for shape in img_shapes) | |||
else: | |||
assert all(shape == img_shapes[0] for shape in img_shapes) | |||
pad_shapes = [_['pad_shape'] for _ in img_meta] | |||
if isinstance(pad_shapes[0], torch.Tensor): | |||
assert all( | |||
tensor_to_tuple(shape) == tensor_to_tuple(pad_shapes[0]) | |||
for shape in pad_shapes) | |||
else: | |||
assert all(shape == pad_shapes[0] for shape in pad_shapes) | |||
if num_augs == 1: | |||
return self.simple_test(imgs[0], img_metas[0], **kwargs) | |||
else: | |||
return self.aug_test(imgs, img_metas, **kwargs) | |||
@auto_fp16(apply_to=('img', )) | |||
def forward(self, img, img_metas, return_loss=True, **kwargs): | |||
"""Calls either :func:`forward_train` or :func:`forward_test` depending | |||
on whether ``return_loss`` is ``True``. | |||
Note this setting will change the expected inputs. When | |||
``return_loss=True``, img and img_meta are single-nested (i.e. Tensor | |||
        and List[dict]), and when ``return_loss=False``, img and img_meta | |||
should be double nested (i.e. List[Tensor], List[List[dict]]), with | |||
the outer list indicating test time augmentations. | |||
""" | |||
if return_loss: | |||
return self.forward_train(img, img_metas, **kwargs) | |||
else: | |||
return self.forward_test(img, img_metas, **kwargs) | |||
def train_step(self, data_batch, optimizer, **kwargs): | |||
"""The iteration step during training. | |||
This method defines an iteration step during training, except for the | |||
back propagation and optimizer updating, which are done in an optimizer | |||
hook. Note that in some complicated cases or models, the whole process | |||
including back propagation and optimizer updating is also defined in | |||
this method, such as GAN. | |||
Args: | |||
            data_batch (dict): The output of the dataloader. | |||
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of | |||
runner is passed to ``train_step()``. This argument is unused | |||
and reserved. | |||
Returns: | |||
dict: It should contain at least 3 keys: ``loss``, ``log_vars``, | |||
``num_samples``. | |||
``loss`` is a tensor for back propagation, which can be a | |||
weighted sum of multiple losses. | |||
``log_vars`` contains all the variables to be sent to the | |||
logger. | |||
``num_samples`` indicates the batch size (when the model is | |||
DDP, it means the batch size on each GPU), which is used for | |||
averaging the logs. | |||
""" | |||
losses = self(**data_batch) | |||
loss, log_vars = self._parse_losses(losses) | |||
outputs = dict( | |||
loss=loss, | |||
log_vars=log_vars, | |||
num_samples=len(data_batch['img_metas'])) | |||
return outputs | |||
def val_step(self, data_batch, optimizer=None, **kwargs): | |||
"""The iteration step during validation. | |||
        This method shares the same signature as :func:`train_step`, but is used | |||
        during val epochs. Note that the evaluation after training epochs is | |||
        not implemented with this method, but with an evaluation hook. | |||
""" | |||
losses = self(**data_batch) | |||
loss, log_vars = self._parse_losses(losses) | |||
log_vars_ = dict() | |||
for loss_name, loss_value in log_vars.items(): | |||
k = loss_name + '_val' | |||
log_vars_[k] = loss_value | |||
outputs = dict( | |||
loss=loss, | |||
log_vars=log_vars_, | |||
num_samples=len(data_batch['img_metas'])) | |||
return outputs | |||
@staticmethod | |||
def _parse_losses(losses): | |||
"""Parse the raw outputs (losses) of the network. | |||
Args: | |||
            losses (dict): Raw output of the network, which usually contains | |||
losses and other necessary information. | |||
Returns: | |||
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor | |||
which may be a weighted sum of all losses, log_vars contains | |||
all the variables to be sent to the logger. | |||
""" | |||
log_vars = OrderedDict() | |||
for loss_name, loss_value in losses.items(): | |||
if isinstance(loss_value, torch.Tensor): | |||
log_vars[loss_name] = loss_value.mean() | |||
elif isinstance(loss_value, list): | |||
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) | |||
else: | |||
raise TypeError( | |||
f'{loss_name} is not a tensor or list of tensors') | |||
loss = sum(_value for _key, _value in log_vars.items() | |||
if 'loss' in _key) | |||
        # If log_vars has different lengths across GPUs, raise an assertion | |||
        # error to prevent GPUs from waiting on each other indefinitely. | |||
if dist.is_available() and dist.is_initialized(): | |||
log_var_length = torch.tensor(len(log_vars), device=loss.device) | |||
dist.all_reduce(log_var_length) | |||
message = (f'rank {dist.get_rank()}' | |||
+ f' len(log_vars): {len(log_vars)}' + ' keys: ' | |||
+ ','.join(log_vars.keys()) + '\n') | |||
assert log_var_length == len(log_vars) * dist.get_world_size(), \ | |||
'loss log variables are different across GPUs!\n' + message | |||
log_vars['loss'] = loss | |||
for loss_name, loss_value in log_vars.items(): | |||
# reduce loss when distributed training | |||
if dist.is_available() and dist.is_initialized(): | |||
loss_value = loss_value.data.clone() | |||
dist.all_reduce(loss_value.div_(dist.get_world_size())) | |||
log_vars[loss_name] = loss_value.item() | |||
return loss, log_vars | |||
def show_result(self, | |||
img, | |||
result, | |||
palette=None, | |||
win_name='', | |||
show=False, | |||
wait_time=0, | |||
out_file=None, | |||
opacity=0.5): | |||
"""Draw `result` over `img`. | |||
Args: | |||
img (str or Tensor): The image to be displayed. | |||
result (Tensor): The semantic segmentation results to draw over | |||
`img`. | |||
            palette (list[list[int]] | np.ndarray | None): The palette of | |||
                the segmentation map. If None is given, a random palette will | |||
                be generated. Default: None. | |||
win_name (str): The window name. | |||
wait_time (int): Value of waitKey param. | |||
Default: 0. | |||
show (bool): Whether to show the image. | |||
Default: False. | |||
out_file (str or None): The filename to write the image. | |||
Default: None. | |||
opacity(float): Opacity of painted segmentation map. | |||
Default 0.5. | |||
Must be in (0, 1] range. | |||
Returns: | |||
            img (ndarray): The drawn image, returned only if `show` is False | |||
                and `out_file` is None. | |||
""" | |||
img = mmcv.imread(img) | |||
img = img.copy() | |||
seg = result[0] | |||
if palette is None: | |||
if self.PALETTE is None: | |||
# Get random state before set seed, | |||
# and restore random state later. | |||
# It will prevent loss of randomness, as the palette | |||
# may be different in each iteration if not specified. | |||
# See: https://github.com/open-mmlab/mmdetection/issues/5844 | |||
state = np.random.get_state() | |||
np.random.seed(42) | |||
# random palette | |||
palette = np.random.randint( | |||
0, 255, size=(len(self.CLASSES), 3)) | |||
np.random.set_state(state) | |||
else: | |||
palette = self.PALETTE | |||
palette = np.array(palette) | |||
assert palette.shape[0] == len(self.CLASSES) | |||
assert palette.shape[1] == 3 | |||
assert len(palette.shape) == 2 | |||
assert 0 < opacity <= 1.0 | |||
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) | |||
for label, color in enumerate(palette): | |||
color_seg[seg == label, :] = color | |||
# convert to BGR | |||
color_seg = color_seg[..., ::-1] | |||
img = img * (1 - opacity) + color_seg * opacity | |||
img = img.astype(np.uint8) | |||
# if out_file specified, do not show image in window | |||
if out_file is not None: | |||
show = False | |||
if show: | |||
mmcv.imshow(img, win_name, wait_time) | |||
if out_file is not None: | |||
mmcv.imwrite(img, out_file) | |||
if not (show or out_file): | |||
warnings.warn('show==False and out_file is not specified, only ' | |||
'result image will be returned') | |||
return img |
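`_parse_losses` reduces the raw loss dict to one scalar by summing every entry whose key contains 'loss', while keeping all entries for logging. A toy sketch of that reduction without the distributed all_reduce branch (the values are made up):

# Toy sketch of the _parse_losses reduction; values are made up and the
# distributed synchronisation is omitted.
from collections import OrderedDict
import torch

raw = {'loss_cls': torch.tensor(0.7), 'loss_mask': torch.tensor(1.2), 'acc_seg': torch.tensor(93.5)}
log_vars = OrderedDict((name, value.mean()) for name, value in raw.items())
total = sum(value for name, value in log_vars.items() if 'loss' in name)
log_vars['loss'] = total
print(round(total.item(), 4))  # ~1.9, the scalar used for backprop
print({name: value.item() for name, value in log_vars.items()})  # goes to the logger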
@@ -0,0 +1,303 @@ | |||
# The implementation refers to ViT-Adapter, | |||
# available at | |||
# https://github.com/czczup/ViT-Adapter.git | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from mmdet.models import builder | |||
from mmdet.models.builder import DETECTORS | |||
from ...utils import add_prefix, seg_resize | |||
from .base_segmentor import BaseSegmentor | |||
@DETECTORS.register_module() | |||
class EncoderDecoderMask2Former(BaseSegmentor): | |||
"""Encoder Decoder segmentors. | |||
EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. | |||
    Note that auxiliary_head is only used for deep supervision during training | |||
    and can be discarded during inference. | |||
""" | |||
def __init__(self, | |||
backbone, | |||
decode_head, | |||
neck=None, | |||
auxiliary_head=None, | |||
train_cfg=None, | |||
test_cfg=None, | |||
pretrained=None, | |||
init_cfg=None): | |||
super(EncoderDecoderMask2Former, self).__init__(init_cfg) | |||
if pretrained is not None: | |||
assert backbone.get('pretrained') is None, \ | |||
'both backbone and segmentor set pretrained weight' | |||
backbone.pretrained = pretrained | |||
self.backbone = builder.build_backbone(backbone) | |||
if neck is not None: | |||
self.neck = builder.build_neck(neck) | |||
decode_head.update(train_cfg=train_cfg) | |||
decode_head.update(test_cfg=test_cfg) | |||
self._init_decode_head(decode_head) | |||
self._init_auxiliary_head(auxiliary_head) | |||
self.train_cfg = train_cfg | |||
self.test_cfg = test_cfg | |||
assert self.with_decode_head | |||
def _init_decode_head(self, decode_head): | |||
"""Initialize ``decode_head``""" | |||
self.decode_head = builder.build_head(decode_head) | |||
self.align_corners = self.decode_head.align_corners | |||
self.num_classes = self.decode_head.num_classes | |||
def _init_auxiliary_head(self, auxiliary_head): | |||
"""Initialize ``auxiliary_head``""" | |||
if auxiliary_head is not None: | |||
if isinstance(auxiliary_head, list): | |||
self.auxiliary_head = nn.ModuleList() | |||
for head_cfg in auxiliary_head: | |||
self.auxiliary_head.append(builder.build_head(head_cfg)) | |||
else: | |||
self.auxiliary_head = builder.build_head(auxiliary_head) | |||
def extract_feat(self, img): | |||
"""Extract features from images.""" | |||
x = self.backbone(img) | |||
if self.with_neck: | |||
x = self.neck(x) | |||
return x | |||
def encode_decode(self, img, img_metas): | |||
"""Encode images with backbone and decode into a semantic segmentation | |||
map of the same size as input.""" | |||
x = self.extract_feat(img) | |||
out = self._decode_head_forward_test(x, img_metas) | |||
out = seg_resize( | |||
input=out, | |||
size=img.shape[2:], | |||
mode='bilinear', | |||
align_corners=self.align_corners) | |||
return out | |||
def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, | |||
**kwargs): | |||
"""Run forward function and calculate loss for decode head in | |||
training.""" | |||
losses = dict() | |||
loss_decode = self.decode_head.forward_train(x, img_metas, | |||
gt_semantic_seg, **kwargs) | |||
losses.update(add_prefix(loss_decode, 'decode')) | |||
return losses | |||
def _decode_head_forward_test(self, x, img_metas): | |||
"""Run forward function and calculate loss for decode head in | |||
inference.""" | |||
seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) | |||
return seg_logits | |||
def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): | |||
"""Run forward function and calculate loss for auxiliary head in | |||
training.""" | |||
losses = dict() | |||
if isinstance(self.auxiliary_head, nn.ModuleList): | |||
for idx, aux_head in enumerate(self.auxiliary_head): | |||
loss_aux = aux_head.forward_train(x, img_metas, | |||
gt_semantic_seg, | |||
self.train_cfg) | |||
losses.update(add_prefix(loss_aux, f'aux_{idx}')) | |||
else: | |||
loss_aux = self.auxiliary_head.forward_train( | |||
x, img_metas, gt_semantic_seg, self.train_cfg) | |||
losses.update(add_prefix(loss_aux, 'aux')) | |||
return losses | |||
def forward_dummy(self, img): | |||
"""Dummy forward function.""" | |||
seg_logit = self.encode_decode(img, None) | |||
return seg_logit | |||
def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs): | |||
"""Forward function for training. | |||
Args: | |||
img (Tensor): Input images. | |||
img_metas (list[dict]): List of image info dict where each dict | |||
has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
For details on the values of these keys see | |||
`mmseg/datasets/pipelines/formatting.py:Collect`. | |||
gt_semantic_seg (Tensor): Semantic segmentation masks | |||
used if the architecture supports semantic segmentation task. | |||
Returns: | |||
dict[str, Tensor]: a dictionary of loss components | |||
""" | |||
x = self.extract_feat(img) | |||
losses = dict() | |||
loss_decode = self._decode_head_forward_train(x, img_metas, | |||
gt_semantic_seg, | |||
**kwargs) | |||
losses.update(loss_decode) | |||
if self.with_auxiliary_head: | |||
loss_aux = self._auxiliary_head_forward_train( | |||
x, img_metas, gt_semantic_seg) | |||
losses.update(loss_aux) | |||
return losses | |||
# TODO refactor | |||
def slide_inference(self, img, img_meta, rescale): | |||
"""Inference by sliding-window with overlap. | |||
If h_crop > h_img or w_crop > w_img, the small patch will be used to | |||
decode without padding. | |||
""" | |||
h_stride, w_stride = self.test_cfg.stride | |||
h_crop, w_crop = self.test_cfg.crop_size | |||
batch_size, _, h_img, w_img = img.size() | |||
num_classes = self.num_classes | |||
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 | |||
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 | |||
preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) | |||
count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) | |||
for h_idx in range(h_grids): | |||
for w_idx in range(w_grids): | |||
y1 = h_idx * h_stride | |||
x1 = w_idx * w_stride | |||
y2 = min(y1 + h_crop, h_img) | |||
x2 = min(x1 + w_crop, w_img) | |||
y1 = max(y2 - h_crop, 0) | |||
x1 = max(x2 - w_crop, 0) | |||
crop_img = img[:, :, y1:y2, x1:x2] | |||
crop_seg_logit = self.encode_decode(crop_img, img_meta) | |||
preds += F.pad(crop_seg_logit, | |||
(int(x1), int(preds.shape[3] - x2), int(y1), | |||
int(preds.shape[2] - y2))) | |||
count_mat[:, :, y1:y2, x1:x2] += 1 | |||
assert (count_mat == 0).sum() == 0 | |||
if torch.onnx.is_in_onnx_export(): | |||
# cast count_mat to constant while exporting to ONNX | |||
count_mat = torch.from_numpy( | |||
count_mat.cpu().detach().numpy()).to(device=img.device) | |||
preds = preds / count_mat | |||
def tensor_to_tuple(input_tensor): | |||
return tuple(input_tensor.cpu().numpy()) | |||
if rescale: | |||
preds = seg_resize( | |||
preds, | |||
size=tensor_to_tuple(img_meta[0]['ori_shape'])[:2] | |||
if isinstance(img_meta[0]['ori_shape'], torch.Tensor) else | |||
img_meta[0]['ori_shape'], | |||
mode='bilinear', | |||
align_corners=self.align_corners, | |||
warning=False) | |||
return preds | |||
def whole_inference(self, img, img_meta, rescale): | |||
"""Inference with full image.""" | |||
seg_logit = self.encode_decode(img, img_meta) | |||
if rescale: | |||
# support dynamic shape for onnx | |||
if torch.onnx.is_in_onnx_export(): | |||
size = img.shape[2:] | |||
else: | |||
size = img_meta[0]['ori_shape'][:2] | |||
seg_logit = seg_resize( | |||
seg_logit, | |||
size=size, | |||
mode='bilinear', | |||
align_corners=self.align_corners, | |||
warning=False) | |||
return seg_logit | |||
def inference(self, img, img_meta, rescale): | |||
"""Inference with slide/whole style. | |||
Args: | |||
img (Tensor): The input image of shape (N, 3, H, W). | |||
img_meta (dict): Image info dict where each dict has: 'img_shape', | |||
'scale_factor', 'flip', and may also contain | |||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
For details on the values of these keys see | |||
`mmseg/datasets/pipelines/formatting.py:Collect`. | |||
rescale (bool): Whether rescale back to original shape. | |||
Returns: | |||
Tensor: The output segmentation map. | |||
""" | |||
assert self.test_cfg.mode in ['slide', 'whole'] | |||
ori_shape = img_meta[0]['ori_shape'] | |||
def tensor_to_tuple(input_tensor): | |||
return tuple(input_tensor.cpu().numpy()) | |||
if isinstance(ori_shape, torch.Tensor): | |||
assert all( | |||
tensor_to_tuple(_['ori_shape']) == tensor_to_tuple(ori_shape) | |||
for _ in img_meta) | |||
else: | |||
assert all(_['ori_shape'] == ori_shape for _ in img_meta) | |||
if self.test_cfg.mode == 'slide': | |||
seg_logit = self.slide_inference(img, img_meta, rescale) | |||
else: | |||
seg_logit = self.whole_inference(img, img_meta, rescale) | |||
output = F.softmax(seg_logit, dim=1) | |||
flip = img_meta[0]['flip'] | |||
if flip: | |||
flip_direction = img_meta[0]['flip_direction'] | |||
assert flip_direction in ['horizontal', 'vertical'] | |||
if flip_direction == 'horizontal': | |||
output = output.flip(dims=(3, )) | |||
elif flip_direction == 'vertical': | |||
output = output.flip(dims=(2, )) | |||
return output | |||
def simple_test(self, img, img_meta, rescale=True): | |||
"""Simple test with single image.""" | |||
seg_logit = self.inference(img, img_meta, rescale) | |||
seg_pred = seg_logit.argmax(dim=1) | |||
if torch.onnx.is_in_onnx_export(): | |||
            # our inference backend only supports 4D output | |||
seg_pred = seg_pred.unsqueeze(0) | |||
return seg_pred | |||
seg_pred = seg_pred.cpu().numpy() | |||
# unravel batch dim | |||
seg_pred = list(seg_pred) | |||
return seg_pred | |||
def aug_test(self, imgs, img_metas, rescale=True): | |||
"""Test with augmentations. | |||
Only rescale=True is supported. | |||
""" | |||
# aug_test rescale all imgs back to ori_shape for now | |||
assert rescale | |||
        # to save memory, we accumulate the augmented seg logits in place | |||
seg_logit = self.inference(imgs[0], img_metas[0], rescale) | |||
for i in range(1, len(imgs)): | |||
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) | |||
seg_logit += cur_seg_logit | |||
seg_logit /= len(imgs) | |||
seg_pred = seg_logit.argmax(dim=1) | |||
seg_pred = seg_pred.cpu().numpy() | |||
# unravel batch dim | |||
seg_pred = list(seg_pred) | |||
return seg_pred |
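`slide_inference` above tiles the image with a crop/stride grid, sums the per-window logits into `preds`, and divides by the per-pixel visit count. A sketch that just enumerates the window coordinates for an assumed 512x512 crop and 341x341 stride on a 768x1024 image:

# Sketch of the sliding-window grid used by slide_inference; the crop, stride
# and image sizes here are assumed purely for illustration.
h_img, w_img = 768, 1024
h_crop, w_crop = 512, 512
h_stride, w_stride = 341, 341
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
for h_idx in range(h_grids):
    for w_idx in range(w_grids):
        y1, x1 = h_idx * h_stride, w_idx * w_stride
        y2, x2 = min(y1 + h_crop, h_img), min(x1 + w_crop, w_img)
        y1, x1 = max(y2 - h_crop, 0), max(x2 - w_crop, 0)
        print((y1, y2, x1, x2))  # overlapping windows; overlaps are averaged via count_mat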
@@ -0,0 +1,7 @@ | |||
from .builder import build_pixel_sampler | |||
from .data_process_func import ResizeToMultiple | |||
from .seg_func import add_prefix, seg_resize | |||
__all__ = [ | |||
'seg_resize', 'add_prefix', 'build_pixel_sampler', 'ResizeToMultiple' | |||
] |
@@ -0,0 +1,11 @@ | |||
# The implementation refers to ViT-Adapter, | |||
# available at | |||
# https://github.com/czczup/ViT-Adapter.git | |||
from mmcv.utils import Registry, build_from_cfg | |||
PIXEL_SAMPLERS = Registry('pixel sampler') | |||
def build_pixel_sampler(cfg, **default_args): | |||
"""Build pixel sampler for segmentation map.""" | |||
return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) |
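Because pixel samplers go through an mmcv `Registry`, a sampler only needs to register itself once and can then be built from a plain config dict. A hypothetical example (the `HardPixelSampler` class and its `thresh` argument are invented for illustration):

# Hypothetical registration/build example for PIXEL_SAMPLERS; the class name
# and its arguments are invented for illustration only.
@PIXEL_SAMPLERS.register_module()
class HardPixelSampler:

    def __init__(self, context=None, thresh=0.7):
        self.context = context
        self.thresh = thresh

    def sample(self, seg_logit, seg_label):
        # would return per-pixel weights; omitted in this sketch
        return None


sampler = build_pixel_sampler(dict(type='HardPixelSampler', thresh=0.6))
print(type(sampler).__name__, sampler.thresh)  # HardPixelSampler 0.6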
@@ -0,0 +1,60 @@ | |||
# Copyright (c) OpenMMLab. All rights reserved. | |||
import mmcv | |||
from mmdet.datasets.builder import PIPELINES | |||
@PIPELINES.register_module() | |||
class ResizeToMultiple(object): | |||
"""Resize images & seg to multiple of divisor. | |||
Args: | |||
        size_divisor (int): images and gt seg maps need to be resized to a | |||
            multiple of size_divisor. Default: 32. | |||
interpolation (str, optional): The interpolation mode of image resize. | |||
Default: None | |||
""" | |||
def __init__(self, size_divisor=32, interpolation=None): | |||
self.size_divisor = size_divisor | |||
self.interpolation = interpolation | |||
def __call__(self, results): | |||
"""Call function to resize images, semantic segmentation map to | |||
multiple of size divisor. | |||
Args: | |||
results (dict): Result dict from loading pipeline. | |||
Returns: | |||
dict: Resized results, 'img_shape', 'pad_shape' keys are updated. | |||
""" | |||
# Align image to multiple of size divisor. | |||
img = results['img'] | |||
img = mmcv.imresize_to_multiple( | |||
img, | |||
self.size_divisor, | |||
scale_factor=1, | |||
interpolation=self.interpolation | |||
if self.interpolation else 'bilinear') | |||
results['img'] = img | |||
results['img_shape'] = img.shape | |||
results['pad_shape'] = img.shape | |||
# Align segmentation map to multiple of size divisor. | |||
for key in results.get('seg_fields', []): | |||
gt_seg = results[key] | |||
gt_seg = mmcv.imresize_to_multiple( | |||
gt_seg, | |||
self.size_divisor, | |||
scale_factor=1, | |||
interpolation='nearest') | |||
results[key] = gt_seg | |||
return results | |||
def __repr__(self): | |||
repr_str = self.__class__.__name__ | |||
repr_str += (f'(size_divisor={self.size_divisor}, ' | |||
f'interpolation={self.interpolation})') | |||
return repr_str |
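With `scale_factor=1`, `mmcv.imresize_to_multiple` keeps the original scale and rounds each side up to the next multiple of `size_divisor`; the rounding itself is a plain ceiling division, sketched below with assumed sizes:

# The target size used by imresize_to_multiple(scale_factor=1) is each side
# rounded up to a multiple of size_divisor; the example sizes are assumed.
def round_up_to_multiple(height, width, divisor=32):
    return ((height + divisor - 1) // divisor * divisor,
            (width + divisor - 1) // divisor * divisor)

print(round_up_to_multiple(600, 800))    # (608, 800)
print(round_up_to_multiple(1024, 1024))  # (1024, 1024)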
@@ -0,0 +1,48 @@ | |||
# The implementation refers to ViT-Adapter, | |||
# available at | |||
# https://github.com/czczup/ViT-Adapter.git | |||
import warnings | |||
import torch.nn.functional as F | |||
def seg_resize(input, | |||
size=None, | |||
scale_factor=None, | |||
mode='nearest', | |||
align_corners=None, | |||
warning=True): | |||
if warning: | |||
if size is not None and align_corners: | |||
input_h, input_w = tuple(int(x) for x in input.shape[2:]) | |||
output_h, output_w = tuple(int(x) for x in size) | |||
if output_h > input_h or output_w > input_w: | |||
if ((output_h > 1 and output_w > 1 and input_h > 1 | |||
and input_w > 1) and (output_h - 1) % (input_h - 1) | |||
and (output_w - 1) % (input_w - 1)): | |||
warnings.warn( | |||
f'When align_corners={align_corners}, ' | |||
                        'the output would be more aligned if ' | |||
f'input size {(input_h, input_w)} is `x+1` and ' | |||
f'out size {(output_h, output_w)} is `nx+1`') | |||
return F.interpolate(input, size, scale_factor, mode, align_corners) | |||
def add_prefix(inputs, prefix): | |||
"""Add prefix for dict. | |||
Args: | |||
inputs (dict): The input dict with str keys. | |||
prefix (str): The prefix to add. | |||
Returns: | |||
dict: The dict with keys updated with ``prefix``. | |||
""" | |||
outputs = dict() | |||
for name, value in inputs.items(): | |||
outputs[f'{prefix}.{name}'] = value | |||
return outputs |
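`add_prefix` simply namespaces loss keys so decode-head and auxiliary-head losses do not collide when the dicts are merged, for example:

# Example of how add_prefix namespaces loss dicts before they are merged;
# the float values stand in for loss tensors.
decode_losses = add_prefix({'loss_cls': 0.7, 'loss_mask': 1.2}, 'decode')
aux_losses = add_prefix({'loss_ce': 0.4}, 'aux_0')
print({**decode_losses, **aux_losses})
# {'decode.loss_cls': 0.7, 'decode.loss_mask': 1.2, 'aux_0.loss_ce': 0.4}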
@@ -26,6 +26,7 @@ if TYPE_CHECKING: | |||
from .image_panoptic_segmentation_pipeline import ImagePanopticSegmentationPipeline | |||
from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline | |||
from .image_reid_person_pipeline import ImageReidPersonPipeline | |||
from .image_semantic_segmentation_pipeline import ImageSemanticSegmentationPipeline | |||
from .image_style_transfer_pipeline import ImageStyleTransferPipeline | |||
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | |||
from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline | |||
@@ -66,6 +67,8 @@ else: | |||
'image_portrait_enhancement_pipeline': | |||
['ImagePortraitEnhancementPipeline'], | |||
'image_reid_person_pipeline': ['ImageReidPersonPipeline'], | |||
'image_semantic_segmentation_pipeline': | |||
['ImageSemanticSegmentationPipeline'], | |||
'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'], | |||
'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], | |||
'image_to_image_translation_pipeline': | |||
@@ -0,0 +1,95 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict, Union | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Model, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.image_segmentation, | |||
module_name=Pipelines.image_semantic_segmentation) | |||
class ImageSemanticSegmentationPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
            use `model` to create an image semantic segmentation pipeline for prediction. | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__(model=model, **kwargs) | |||
logger.info('semantic segmentation model, pipeline init') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
from mmdet.datasets.pipelines import Compose | |||
from mmcv.parallel import collate, scatter | |||
from mmdet.datasets import replace_ImageToTensor | |||
cfg = self.model.cfg | |||
# build the data pipeline | |||
if isinstance(input, str): | |||
            # input is a str (file name); the pipeline's LoadImageFromFile | |||
            # step will read it from disk | |||
            # collect data | |||
data = dict(img_info=dict(filename=input), img_prefix=None) | |||
        elif isinstance(input, PIL.Image.Image):  # PIL is RGB; convert to BGR | |||
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' | |||
img = np.array(input)[:, :, ::-1] | |||
# collect data | |||
data = dict(img=img) | |||
elif isinstance(input, np.ndarray): | |||
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' | |||
if len(input.shape) == 2: | |||
img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) | |||
else: | |||
img = input | |||
# collect data | |||
data = dict(img=img) | |||
else: | |||
            raise TypeError(f'input should be either str, PIL.Image ' | |||
                            f'or np.ndarray, but got {type(input)}') | |||
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) | |||
test_pipeline = Compose(cfg.data.test.pipeline) | |||
data = test_pipeline(data) | |||
        # collate the single sample into a batch (mmdet-style data handling) | |||
data = collate([data], samples_per_gpu=1) | |||
data['img_metas'] = [ | |||
img_metas.data[0] for img_metas in data['img_metas'] | |||
] | |||
data['img'] = [img.data[0] for img in data['img']] | |||
if next(self.model.parameters()).is_cuda: | |||
# scatter to specified GPU | |||
data = scatter(data, [next(self.model.parameters()).device])[0] | |||
return data | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
results = self.model.inference(input) | |||
return results | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
results = self.model.postprocess(inputs) | |||
outputs = { | |||
OutputKeys.MASKS: results[OutputKeys.MASKS], | |||
OutputKeys.LABELS: results[OutputKeys.LABELS], | |||
OutputKeys.SCORES: results[OutputKeys.SCORES] | |||
} | |||
return outputs |
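For PIL input, `np.array(pil_image)` yields an RGB array, so the `[:, :, ::-1]` slice in `preprocess` converts it to the BGR layout the mmdet pipeline expects. A quick check of that conversion on a synthetic 2x2 image:

# Quick check of the RGB -> BGR conversion used for PIL inputs in preprocess;
# the 2x2 red image is synthetic.
import numpy as np
import PIL.Image

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255  # pure red in RGB
bgr = np.array(PIL.Image.fromarray(rgb))[:, :, ::-1]
print(bgr[0, 0])  # [  0   0 255]: red now sits in the last channel, i.e. BGR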
@@ -153,3 +153,16 @@ def panoptic_seg_masks_to_image(masks): | |||
draw_img[mask] = color_mask | |||
return draw_img | |||
def semantic_seg_masks_to_image(masks): | |||
from mmdet.core.visualization.palette import get_palette | |||
mask_palette = get_palette('coco', 133) | |||
draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3]) | |||
for i, mask in enumerate(masks): | |||
color_mask = mask_palette[i] | |||
mask = mask.astype(bool) | |||
draw_img[mask] = color_mask | |||
return draw_img |
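The helper above paints every binary mask with a fixed palette color; where masks overlap, later masks overwrite earlier ones. A tiny sketch with two synthetic masks and a two-color palette:

# Tiny sketch of painting stacked binary masks with a palette; the masks and
# colors are synthetic.
import numpy as np

masks = [np.eye(4, dtype=np.uint8), np.fliplr(np.eye(4, dtype=np.uint8))]
palette = [(255, 0, 0), (0, 255, 0)]
draw_img = np.zeros((4, 4, 3), dtype=np.uint8)
for mask, color in zip(masks, palette):
    draw_img[mask.astype(bool)] = color
print(draw_img[0, 0], draw_img[0, 3])  # first mask's color, then second mask's color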
@@ -0,0 +1,54 @@ | |||
import unittest | |||
import cv2 | |||
import PIL | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.cv.image_utils import semantic_seg_masks_to_image | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.test_utils import test_level | |||
class ImageSemanticSegmentationTest(unittest.TestCase): | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_image_semantic_segmentation_panmerge(self): | |||
input_location = 'data/test/images/image_semantic_segmentation.jpg' | |||
model_id = 'damo/cv_swinL_semantic-segmentation_cocopanmerge' | |||
segmenter = pipeline(Tasks.image_segmentation, model=model_id) | |||
result = segmenter(input_location) | |||
draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
cv2.imwrite('result.jpg', draw_img) | |||
print('test_image_semantic_segmentation_panmerge DONE') | |||
PIL_array = PIL.Image.open(input_location) | |||
result = segmenter(PIL_array) | |||
draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
cv2.imwrite('result.jpg', draw_img) | |||
print('test_image_semantic_segmentation_panmerge_from_PIL DONE') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_image_semantic_segmentation_vitadapter(self): | |||
input_location = 'data/test/images/image_semantic_segmentation.jpg' | |||
model_id = 'damo/cv_vitadapter_semantic-segmentation_cocostuff164k' | |||
segmenter = pipeline(Tasks.image_segmentation, model=model_id) | |||
result = segmenter(input_location) | |||
draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
cv2.imwrite('result.jpg', draw_img) | |||
print('test_image_semantic_segmentation_vitadapter DONE') | |||
PIL_array = PIL.Image.open(input_location) | |||
result = segmenter(PIL_array) | |||
draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
cv2.imwrite('result.jpg', draw_img) | |||
print('test_image_semantic_segmentation_vitadapter_from_PIL DONE') | |||
if __name__ == '__main__': | |||
unittest.main() |