wendi.hwd yingda.chen 3 years ago
parent commit 674e625e7c
28 changed files with 2655 additions and 3 deletions
  1. +3 -0  data/test/images/image_detection.jpg
  2. +3 -0  modelscope/metainfo.py
  3. +2 -1  modelscope/models/cv/__init__.py
  4. +22 -0  modelscope/models/cv/object_detection/__init__.py
  5. +92 -0  modelscope/models/cv/object_detection/mmdet_model.py
  6. +4 -0  modelscope/models/cv/object_detection/mmdet_ms/__init__.py
  7. +3 -0  modelscope/models/cv/object_detection/mmdet_ms/backbones/__init__.py
  8. +626 -0  modelscope/models/cv/object_detection/mmdet_ms/backbones/vit.py
  9. +4 -0  modelscope/models/cv/object_detection/mmdet_ms/dense_heads/__init__.py
  10. +48 -0  modelscope/models/cv/object_detection/mmdet_ms/dense_heads/anchor_head.py
  11. +268 -0  modelscope/models/cv/object_detection/mmdet_ms/dense_heads/rpn_head.py
  12. +3 -0  modelscope/models/cv/object_detection/mmdet_ms/necks/__init__.py
  13. +207 -0  modelscope/models/cv/object_detection/mmdet_ms/necks/fpn.py
  14. +8 -0  modelscope/models/cv/object_detection/mmdet_ms/roi_heads/__init__.py
  15. +4 -0  modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/__init__.py
  16. +229 -0  modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/convfc_bbox_head.py
  17. +3 -0  modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/__init__.py
  18. +414 -0  modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/fcn_mask_head.py
  19. +4 -0  modelscope/models/cv/object_detection/mmdet_ms/utils/__init__.py
  20. +558 -0  modelscope/models/cv/object_detection/mmdet_ms/utils/checkpoint.py
  21. +30 -0  modelscope/models/cv/object_detection/mmdet_ms/utils/convModule_norm.py
  22. +5 -2  modelscope/pipelines/base.py
  23. +4 -0  modelscope/pipelines/builder.py
  24. +2 -0  modelscope/pipelines/cv/__init__.py
  25. +51 -0  modelscope/pipelines/cv/object_detection_pipeline.py
  26. +1 -0  modelscope/utils/constant.py
  27. +1 -0  requirements/cv.txt
  28. +56 -0  tests/pipelines/test_object_detection.py

+3 -0  data/test/images/image_detection.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0218020651b6cdcc0051563f75750c8200d34fc49bf34cc053cd59c1f13cad03
size 128624

+3 -0  modelscope/metainfo.py

@@ -10,6 +10,7 @@ class Models(object):
Model name should only contain model info but not task info.
"""
# vision models
+ detection = 'detection'
scrfd = 'scrfd'
classification_model = 'ClassificationModel'
nafnet = 'nafnet'
@@ -69,6 +70,8 @@ class Pipelines(object):
action_recognition = 'TAdaConv_action-recognition'
animal_recognation = 'resnet101-animal_recog'
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+ human_detection = 'resnet18-human-detection'
+ object_detection = 'vit-object-detection'
image_classification = 'image-classification'
face_detection = 'resnet-face-detection-scrfd10gkps'
live_category = 'live-category'


+2 -1  modelscope/models/cv/__init__.py

@@ -3,4 +3,5 @@ from . import (action_recognition, animal_recognition, cartoon,
cmdssl_video_embedding, face_detection, face_generation,
image_classification, image_color_enhance, image_colorization,
image_denoise, image_instance_segmentation,
- image_to_image_translation, super_resolution, virual_tryon)
+ image_to_image_translation, object_detection, super_resolution,
+ virual_tryon)

+22 -0  modelscope/models/cv/object_detection/__init__.py

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .mmdet_model import DetectionModel

else:
_import_structure = {
'mmdet_model': ['DetectionModel'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
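
With this lazy-import setup, mmdet_model (and its mmdet/mmcv dependencies) is only imported when DetectionModel is first accessed from the package; a minimal sketch:

    # resolves object_detection.mmdet_model lazily at attribute-access time
    from modelscope.models.cv.object_detection import DetectionModel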

+92 -0  modelscope/models/cv/object_detection/mmdet_model.py

@@ -0,0 +1,92 @@
import os.path as osp

import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from .mmdet_ms.backbones import ViT
from .mmdet_ms.dense_heads import RPNNHead
from .mmdet_ms.necks import FPNF
from .mmdet_ms.roi_heads import FCNMaskNHead, Shared4Conv1FCBBoxNHead


@MODELS.register_module(Tasks.human_detection, module_name=Models.detection)
@MODELS.register_module(Tasks.object_detection, module_name=Models.detection)
class DetectionModel(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
"""str -- model file root."""
super().__init__(model_dir, *args, **kwargs)

from mmcv.runner import load_checkpoint
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector

model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
config_path = osp.join(model_dir, 'mmcv_config.py')
config = Config.from_file(config_path)
config.model.pretrained = None
self.model = build_detector(config.model)

checkpoint = load_checkpoint(
self.model, model_path, map_location='cpu')
self.class_names = checkpoint['meta']['CLASSES']
config.test_pipeline[0].type = 'LoadImageFromWebcam'
self.test_pipeline = Compose(
replace_ImageToTensor(config.test_pipeline))
self.model.cfg = config
self.model.eval()
self.score_thr = config.score_thr

def inference(self, data):
"""data is dict,contain img and img_metas,follow with mmdet."""

with torch.no_grad():
results = self.model(return_loss=False, rescale=True, **data)
return results

def preprocess(self, image):
"""image is numpy return is dict contain img and img_metas,follow with mmdet."""

from mmcv.parallel import collate, scatter
data = dict(img=image)
data = self.test_pipeline(data)
data = collate([data], samples_per_gpu=1)
data['img_metas'] = [
img_metas.data[0] for img_metas in data['img_metas']
]
data['img'] = [img.data[0] for img in data['img']]

if next(self.model.parameters()).is_cuda:
data = scatter(data, [next(self.model.parameters()).device])[0]

return data

def postprocess(self, inputs):
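"""Convert raw mmdet results into (bboxes, scores, labels); returns (None, None, None) if no detection exceeds score_thr."""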

if isinstance(inputs[0], tuple):
bbox_result, _ = inputs[0]
else:
bbox_result, _ = inputs[0], None
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(bbox_result)
]
labels = np.concatenate(labels)

bbox_result = np.vstack(bbox_result)
scores = bbox_result[:, -1]
inds = scores > self.score_thr
if np.sum(np.array(inds).astype('int')) == 0:
return None, None, None
bboxes = bbox_result[inds, :]
labels = labels[inds]
scores = bboxes[:, 4]
bboxes = bboxes[:, 0:4]
labels = [self.class_names[i_label] for i_label in labels]
return bboxes, scores, labels
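
A minimal end-to-end sketch of the three methods above (paths are illustrative; the actual wiring lives in modelscope/pipelines/cv/object_detection_pipeline.py, added later in this commit):

    import cv2
    from modelscope.models.cv.object_detection import DetectionModel

    # model_dir is assumed to contain the torch weights file plus mmcv_config.py, as read in __init__ above
    model = DetectionModel('/path/to/model_dir')
    image = cv2.imread('data/test/images/image_detection.jpg')  # BGR ndarray, matching LoadImageFromWebcam
    data = model.preprocess(image)
    results = model.inference(data)
    bboxes, scores, labels = model.postprocess(results)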

+4 -0  modelscope/models/cv/object_detection/mmdet_ms/__init__.py

@@ -0,0 +1,4 @@
from .backbones import ViT
from .dense_heads import AnchorNHead, RPNNHead
from .necks import FPNF
from .utils import ConvModule_Norm, load_checkpoint

+3 -0  modelscope/models/cv/object_detection/mmdet_ms/backbones/__init__.py

@@ -0,0 +1,3 @@
from .vit import ViT

__all__ = ['ViT']

+626 -0  modelscope/models/cv/object_detection/mmdet_ms/backbones/vit.py

@@ -0,0 +1,626 @@
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------'
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from mmdet.models.builder import BACKBONES
from mmdet.utils import get_root_logger
from timm.models.layers import drop_path, to_2tuple, trunc_normal_

from ..utils import load_checkpoint


class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""

def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob

def forward(self, x):
return drop_path(x, self.drop_prob, self.training)

def extra_repr(self):
return 'p={}'.format(self.drop_prob)


class Mlp(nn.Module):

def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)

def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
x = self.drop(x)
return x


class Attention(nn.Module):

def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
window_size=None,
attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
self.window_size = window_size
q_size = window_size[0]
rel_sp_dim = 2 * q_size - 1
self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x, H, W, rel_pos_bias=None):
B, N, C = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = calc_rel_pos_spatial(attn, q, self.window_size,
self.window_size, self.rel_pos_h,
self.rel_pos_w)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x


def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
C)
windows = x.permute(0, 1, 3, 2, 4,
5).contiguous().view(-1, window_size, window_size, C)
return windows


def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x


def calc_rel_pos_spatial(
attn,
q,
q_shape,
k_shape,
rel_pos_h,
rel_pos_w,
):
"""
Spatial Relative Positional Embeddings.
"""
sp_idx = 0
q_h, q_w = q_shape
k_h, k_w = k_shape
# Scale up rel pos if shapes for q and k are different.
q_h_ratio = max(k_h / q_h, 1.0)
k_h_ratio = max(q_h / k_h, 1.0)
dist_h = (
torch.arange(q_h)[:, None] * q_h_ratio
- torch.arange(k_h)[None, :] * k_h_ratio)
dist_h += (k_h - 1) * k_h_ratio
q_w_ratio = max(k_w / q_w, 1.0)
k_w_ratio = max(q_w / k_w, 1.0)
dist_w = (
torch.arange(q_w)[:, None] * q_w_ratio
- torch.arange(k_w)[None, :] * k_w_ratio)
dist_w += (k_w - 1) * k_w_ratio
Rh = rel_pos_h[dist_h.long()]
Rw = rel_pos_w[dist_w.long()]
B, n_head, q_N, dim = q.shape
r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
rel_h = torch.einsum('byhwc,hkc->byhwk', r_q, Rh)
rel_w = torch.einsum('byhwc,wkc->byhwk', r_q, Rw)
attn[:, :, sp_idx:, sp_idx:] = (
attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
+ rel_h[:, :, :, :, :, None] + rel_w[:, :, :, :, None, :]).view(
B, -1, q_h * q_w, k_h * k_w)

return attn
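
A quick sanity check of the index arithmetic above, assuming equal q/k sizes so both ratios are 1.0: for q_h = k_h = 3,

    dist_h = arange(3)[:, None] - arange(3)[None, :] + (3 - 1)
           = [[2, 1, 0],
              [3, 2, 1],
              [4, 3, 2]]

i.e. integer offsets in [0, 2*3 - 2], which exactly index the 2*q_size - 1 rows of rel_pos_h.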


class WindowAttention(nn.Module):
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""

def __init__(self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
attn_head_dim=None):

super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5

q_size = window_size[0]
rel_sp_dim = 2 * q_size - 1
self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

self.softmax = nn.Softmax(dim=-1)

def forward(self, x, H, W):
""" Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
x = x.reshape(B_, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size[1]
- W % self.window_size[1]) % self.window_size[1]
pad_b = (self.window_size[0]
- H % self.window_size[0]) % self.window_size[0]

x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape

x = window_partition(
x, self.window_size[0]) # nW*B, window_size, window_size, C
x = x.view(-1, self.window_size[1] * self.window_size[0],
C) # nW*B, window_size*window_size, C
B_w = x.shape[0]
N_w = x.shape[1]
qkv = self.qkv(x).reshape(B_w, N_w, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)

q = q * self.scale
attn = (q @ k.transpose(-2, -1))

attn = calc_rel_pos_spatial(attn, q, self.window_size,
self.window_size, self.rel_pos_h,
self.rel_pos_w)

attn = self.softmax(attn)

attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B_w, N_w, C)
x = self.proj(x)
x = self.proj_drop(x)

x = x.view(-1, self.window_size[1], self.window_size[0], C)
x = window_reverse(x, self.window_size[0], Hp, Wp) # B H' W' C

if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()

x = x.view(B_, H * W, C)

return x


class Block(nn.Module):

def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
init_values=None,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
window_size=None,
attn_head_dim=None,
window=False):
super().__init__()
self.norm1 = norm_layer(dim)
if not window:
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
window_size=window_size,
attn_head_dim=attn_head_dim)
else:
self.attn = WindowAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
window_size=window_size,
attn_head_dim=attn_head_dim)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)

if init_values is not None:
self.gamma_1 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None

def forward(self, x, H, W):
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(
self.gamma_1 * self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x


class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""

def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (
img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches

self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1], \
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
Hp, Wp = x.shape[2], x.shape[3]

x = x.flatten(2).transpose(1, 2)
return x, (Hp, Wp)


class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""

def __init__(self,
backbone,
img_size=224,
feature_size=None,
in_chans=3,
embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
# map for all networks, the feature metadata has reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(
torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
feature_dim = self.backbone.feature_info.channels()[-1]
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Linear(feature_dim, embed_dim)

def forward(self, x):
x = self.backbone(x)[-1]
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x


class Norm2d(nn.Module):

def __init__(self, embed_dim):
super().__init__()
self.ln = nn.LayerNorm(embed_dim, eps=1e-6)

def forward(self, x):
x = x.permute(0, 2, 3, 1)
x = self.ln(x)
x = x.permute(0, 3, 1, 2).contiguous()
return x


@BACKBONES.register_module()
class ViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""

def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=80,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
hybrid_backbone=None,
norm_layer=None,
init_values=None,
use_checkpoint=False,
use_abs_pos_emb=False,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
out_indices=[11],
interval=3,
pretrained=None):
super().__init__()
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models

if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone,
img_size=img_size,
in_chans=in_chans,
embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim)

num_patches = self.patch_embed.num_patches

self.out_indices = out_indices

if use_abs_pos_emb:
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches, embed_dim))
else:
self.pos_embed = None

self.pos_drop = nn.Dropout(p=drop_rate)

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.use_checkpoint = use_checkpoint
self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
window_size=(14, 14) if
((i + 1) % interval != 0) else self.patch_embed.patch_shape,
window=((i + 1) % interval != 0)) for i in range(depth)
])

if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)

self.norm = norm_layer(embed_dim)

self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
Norm2d(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)

self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2,
stride=2), )

self.fpn3 = nn.Identity()

self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)

self.apply(self._init_weights)
self.fix_init_weight()
self.pretrained = pretrained

def fix_init_weight(self):

def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))

for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)

def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
pretrained = pretrained or self.pretrained

def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

if isinstance(pretrained, str):
self.apply(_init_weights)
logger = get_root_logger()
print(f'load from {pretrained}')
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
self.apply(_init_weights)
else:
raise TypeError('pretrained must be a str or None')

def get_num_layers(self):
return len(self.blocks)

@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}

def forward_features(self, x):
B, C, H, W = x.shape
x, (Hp, Wp) = self.patch_embed(x)
batch_size, seq_len, _ = x.size()

if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)

features = []
for i, blk in enumerate(self.blocks):
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x, Hp, Wp)

x = self.norm(x)
xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp)

ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for i in range(len(ops)):
features.append(ops[i](xp))

return tuple(features)

def forward(self, x):

x = self.forward_features(x)

return x
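
For reference, an illustrative mmdet-style backbone entry built only from the constructor arguments above (values are examples, not the shipped mmcv_config.py):

    backbone = dict(
        type='ViT',              # registered via @BACKBONES.register_module()
        img_size=1024,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        use_abs_pos_emb=True,
        out_indices=[11],
        interval=3,
        pretrained=None)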

+4 -0  modelscope/models/cv/object_detection/mmdet_ms/dense_heads/__init__.py

@@ -0,0 +1,4 @@
from .anchor_head import AnchorNHead
from .rpn_head import RPNNHead

__all__ = ['AnchorNHead', 'RPNNHead']

+48 -0  modelscope/models/cv/object_detection/mmdet_ms/dense_heads/anchor_head.py

@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Implementation in this file is modified from source code available via https://github.com/ViTAE-Transformer/ViTDet
from mmdet.models.builder import HEADS
from mmdet.models.dense_heads import AnchorHead


@HEADS.register_module()
class AnchorNHead(AnchorHead):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).

Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of hidden channels. Used in child classes.
anchor_generator (dict): Config dict for anchor generator
bbox_coder (dict): Config of bounding box coder.
reg_decoded_bbox (bool): If true, the regression loss would be
applied directly on decoded bounding boxes, converting both
the predicted boxes and regression targets to absolute
coordinates format. Default False. It should be `True` when
using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
loss_cls (dict): Config of classification loss.
loss_bbox (dict): Config of localization loss.
train_cfg (dict): Training config of anchor head.
test_cfg (dict): Testing config of anchor head.
init_cfg (dict or list[dict], optional): Initialization config dict.
""" # noqa: W605

def __init__(self,
num_classes,
in_channels,
feat_channels,
anchor_generator=None,
bbox_coder=None,
reg_decoded_bbox=False,
loss_cls=None,
loss_bbox=None,
train_cfg=None,
test_cfg=None,
norm_cfg=None,
init_cfg=None):
self.norm_cfg = norm_cfg
super(AnchorNHead,
self).__init__(num_classes, in_channels, feat_channels,
anchor_generator, bbox_coder, reg_decoded_bbox,
loss_cls, loss_bbox, train_cfg, test_cfg,
init_cfg)

+268 -0  modelscope/models/cv/object_detection/mmdet_ms/dense_heads/rpn_head.py

@@ -0,0 +1,268 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Implementation in this file is modified from source code available via https://github.com/ViTAE-Transformer/ViTDet
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import batched_nms
from mmdet.models.builder import HEADS

from ..utils import ConvModule_Norm
from .anchor_head import AnchorNHead


@HEADS.register_module()
class RPNNHead(AnchorNHead):
"""RPN head.

Args:
in_channels (int): Number of channels in the input feature map.
init_cfg (dict or list[dict], optional): Initialization config dict.
num_convs (int): Number of convolution layers in the head. Default 1.
""" # noqa: W605

def __init__(self,
in_channels,
init_cfg=dict(type='Normal', layer='Conv2d', std=0.01),
num_convs=1,
**kwargs):
self.num_convs = num_convs
super(RPNNHead, self).__init__(
1, in_channels, init_cfg=init_cfg, **kwargs)

def _init_layers(self):
"""Initialize layers of the head."""
if self.num_convs > 1:
rpn_convs = []
for i in range(self.num_convs):
if i == 0:
in_channels = self.in_channels
else:
in_channels = self.feat_channels
# use ``inplace=False`` to avoid error: one of the variables
# needed for gradient computation has been modified by an
# inplace operation.
rpn_convs.append(
ConvModule_Norm(
in_channels,
self.feat_channels,
3,
padding=1,
norm_cfg=self.norm_cfg,
inplace=False))
self.rpn_conv = nn.Sequential(*rpn_convs)
else:
self.rpn_conv = nn.Conv2d(
self.in_channels, self.feat_channels, 3, padding=1)
self.rpn_cls = nn.Conv2d(self.feat_channels,
self.num_base_priors * self.cls_out_channels,
1)
self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_base_priors * 4,
1)

def forward_single(self, x):
"""Forward feature map of a single scale level."""
x = self.rpn_conv(x)
x = F.relu(x, inplace=True)
rpn_cls_score = self.rpn_cls(x)
rpn_bbox_pred = self.rpn_reg(x)
return rpn_cls_score, rpn_bbox_pred

def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
img_metas,
gt_bboxes_ignore=None):
"""Compute losses of the head.

Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W)
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (None | list[Tensor]): specify which bounding
boxes can be ignored when computing the loss.

Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
losses = super(RPNNHead, self).loss(
cls_scores,
bbox_preds,
gt_bboxes,
None,
img_metas,
gt_bboxes_ignore=gt_bboxes_ignore)
return dict(
loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])

def _get_bboxes_single(self,
cls_score_list,
bbox_pred_list,
score_factor_list,
mlvl_anchors,
img_meta,
cfg,
rescale=False,
with_nms=True,
**kwargs):
"""Transform outputs of a single image into bbox predictions.

Args:
cls_score_list (list[Tensor]): Box scores from all scale
levels of a single image, each item has shape
(num_anchors * num_classes, H, W).
bbox_pred_list (list[Tensor]): Box energies / deltas from
all scale levels of a single image, each item has
shape (num_anchors * 4, H, W).
score_factor_list (list[Tensor]): Score factor from all scale
levels of a single image. RPN head does not need this value.
mlvl_anchors (list[Tensor]): Anchors of all scale level
each item has shape (num_anchors, 4).
img_meta (dict): Image meta info.
cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.

Returns:
Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
are bounding box positions (tl_x, tl_y, br_x, br_y) and the
5-th column is a score between 0 and 1.
"""
cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)
img_shape = img_meta['img_shape']

# bboxes from different level should be independent during NMS,
# level_ids are used as labels for batched NMS to separate them
level_ids = []
mlvl_scores = []
mlvl_bbox_preds = []
mlvl_valid_anchors = []
nms_pre = cfg.get('nms_pre', -1)
for level_idx in range(len(cls_score_list)):
rpn_cls_score = cls_score_list[level_idx]
rpn_bbox_pred = bbox_pred_list[level_idx]
assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.reshape(-1)
scores = rpn_cls_score.sigmoid()
else:
rpn_cls_score = rpn_cls_score.reshape(-1, 2)
# We set FG labels to [0, num_class-1] and BG label to
# num_class in RPN head since mmdet v2.5, which is unified to
# be consistent with other head since mmdet v2.0. In mmdet v2.0
# to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
scores = rpn_cls_score.softmax(dim=1)[:, 0]
rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)

anchors = mlvl_anchors[level_idx]
if 0 < nms_pre < scores.shape[0]:
# sort is faster than topk
# _, topk_inds = scores.topk(cfg.nms_pre)
ranked_scores, rank_inds = scores.sort(descending=True)
topk_inds = rank_inds[:nms_pre]
scores = ranked_scores[:nms_pre]
rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
anchors = anchors[topk_inds, :]

mlvl_scores.append(scores)
mlvl_bbox_preds.append(rpn_bbox_pred)
mlvl_valid_anchors.append(anchors)
level_ids.append(
scores.new_full((scores.size(0), ),
level_idx,
dtype=torch.long))

return self._bbox_post_process(mlvl_scores, mlvl_bbox_preds,
mlvl_valid_anchors, level_ids, cfg,
img_shape)

def _bbox_post_process(self, mlvl_scores, mlvl_bboxes, mlvl_valid_anchors,
level_ids, cfg, img_shape, **kwargs):
"""bbox post-processing method.

The boxes are rescaled to the original image scale and NMS is
applied. Usually with_nms=False is used for aug test.

Args:
mlvl_scores (list[Tensor]): Box scores from all scale
levels of a single image, each item has shape
(num_bboxes, num_class).
mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale
levels of a single image, each item has shape (num_bboxes, 4).
mlvl_valid_anchors (list[Tensor]): Anchors of all scale level
each item has shape (num_bboxes, 4).
level_ids (list[Tensor]): Indexes from all scale levels of a
single image, each item has shape (num_bboxes, ).
cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
img_shape (tuple(int)): Shape of current image.

Returns:
Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
are bounding box positions (tl_x, tl_y, br_x, br_y) and the
5-th column is a score between 0 and 1.
"""
scores = torch.cat(mlvl_scores)
anchors = torch.cat(mlvl_valid_anchors)
rpn_bbox_pred = torch.cat(mlvl_bboxes)
proposals = self.bbox_coder.decode(
anchors, rpn_bbox_pred, max_shape=img_shape)
ids = torch.cat(level_ids)

if cfg.min_bbox_size >= 0:
w = proposals[:, 2] - proposals[:, 0]
h = proposals[:, 3] - proposals[:, 1]
valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
if not valid_mask.all():
proposals = proposals[valid_mask]
scores = scores[valid_mask]
ids = ids[valid_mask]

if proposals.numel() > 0:
dets, _ = batched_nms(proposals, scores, ids, cfg.nms)
else:
return proposals.new_zeros(0, 5)

return dets[:cfg.max_per_img]

def onnx_export(self, x, img_metas):
"""Test without augmentation.

Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
img_metas (list[dict]): Meta info of each image.
Returns:
Tensor: dets of shape [N, num_det, 5].
"""
cls_scores, bbox_preds = self(x)

assert len(cls_scores) == len(bbox_preds)

batch_bboxes, batch_scores = super(RPNNHead, self).onnx_export(
cls_scores, bbox_preds, img_metas=img_metas, with_nms=False)
# Use ONNX::NonMaxSuppression in deployment
from mmdet.core.export import add_dummy_nms_for_onnx
cfg = copy.deepcopy(self.test_cfg)
score_threshold = cfg.nms.get('score_thr', 0.0)
nms_pre = cfg.get('deploy_nms_pre', -1)
# Different from the normal forward doing NMS level by level,
# we do NMS across all levels when exporting ONNX.
dets, _ = add_dummy_nms_for_onnx(batch_bboxes, batch_scores,
cfg.max_per_img,
cfg.nms.iou_threshold,
score_threshold, nms_pre,
cfg.max_per_img)
return dets
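
An illustrative rpn_head entry for this class (values are examples; anchor_generator, bbox_coder and the losses are the standard mmdet AnchorHead configs, and the norm_cfg value is an assumption about what ConvModule_Norm accepts):

    rpn_head = dict(
        type='RPNNHead',
        in_channels=256,
        feat_channels=256,
        num_convs=2,                  # >1 enables the ConvModule_Norm stack in _init_layers
        norm_cfg=dict(type='LN'),     # assumption
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0., 0., 0., 0.],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0))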

+3 -0  modelscope/models/cv/object_detection/mmdet_ms/necks/__init__.py

@@ -0,0 +1,3 @@
from .fpn import FPNF

__all__ = ['FPNF']

+207 -0  modelscope/models/cv/object_detection/mmdet_ms/necks/fpn.py

@@ -0,0 +1,207 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Implementation in this file is modified from source code available via https://github.com/ViTAE-Transformer/ViTDet
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule, auto_fp16
from mmdet.models.builder import NECKS

from ..utils import ConvModule_Norm


@NECKS.register_module()
class FPNF(BaseModule):
r"""Feature Pyramid Network.

This is an implementation of paper `Feature Pyramid Networks for Object
Detection <https://arxiv.org/abs/1612.03144>`_.

Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale)
num_outs (int): Number of output scales.
start_level (int): Index of the start input backbone level used to
build the feature pyramid. Default: 0.
end_level (int): Index of the end input backbone level (exclusive) to
build the feature pyramid. Default: -1, which means the last level.
add_extra_convs (bool | str): If bool, it decides whether to add conv
layers on top of the original feature maps. Default to False.
If True, it is equivalent to `add_extra_convs='on_input'`.
If str, it specifies the source feature map of the extra convs.
Only the following options are allowed

- 'on_input': Last feat map of neck inputs (i.e. backbone feature).
- 'on_lateral': Last feature map after lateral convs.
- 'on_output': The last output feature map after fpn convs.
relu_before_extra_convs (bool): Whether to apply relu before the extra
conv. Default: False.
no_norm_on_lateral (bool): Whether to apply norm on lateral.
Default: False.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (str): Config dict for activation layer in ConvModule.
Default: None.
upsample_cfg (dict): Config dict for interpolate layer.
Default: `dict(mode='nearest')`
init_cfg (dict or list[dict], optional): Initialization config dict.

Example:
>>> import torch
>>> in_channels = [2, 3, 5, 7]
>>> scales = [340, 170, 84, 43]
>>> inputs = [torch.rand(1, c, s, s)
... for c, s in zip(in_channels, scales)]
>>> self = FPNF(in_channels, 11, len(in_channels)).eval()
>>> outputs = self.forward(inputs)
>>> for i in range(len(outputs)):
... print(f'outputs[{i}].shape = {outputs[i].shape}')
outputs[0].shape = torch.Size([1, 11, 340, 340])
outputs[1].shape = torch.Size([1, 11, 170, 170])
outputs[2].shape = torch.Size([1, 11, 84, 84])
outputs[3].shape = torch.Size([1, 11, 43, 43])
"""

def __init__(self,
in_channels,
out_channels,
num_outs,
start_level=0,
end_level=-1,
add_extra_convs=False,
relu_before_extra_convs=False,
no_norm_on_lateral=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
use_residual=True,
upsample_cfg=dict(mode='nearest'),
init_cfg=dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
super(FPNF, self).__init__(init_cfg)
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.num_outs = num_outs
self.relu_before_extra_convs = relu_before_extra_convs
self.no_norm_on_lateral = no_norm_on_lateral
self.fp16_enabled = False
self.upsample_cfg = upsample_cfg.copy()
self.use_residual = use_residual

if end_level == -1:
self.backbone_end_level = self.num_ins
assert num_outs >= self.num_ins - start_level
else:
# if end_level < inputs, no extra level is allowed
self.backbone_end_level = end_level
assert end_level <= len(in_channels)
assert num_outs == end_level - start_level
self.start_level = start_level
self.end_level = end_level
self.add_extra_convs = add_extra_convs
assert isinstance(add_extra_convs, (str, bool))
if isinstance(add_extra_convs, str):
# Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
elif add_extra_convs: # True
self.add_extra_convs = 'on_input'

self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()

for i in range(self.start_level, self.backbone_end_level):
l_conv = ConvModule_Norm(
in_channels[i],
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
act_cfg=act_cfg,
inplace=False)
fpn_conv = ConvModule_Norm(
out_channels,
out_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)

self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
# add extra conv layers (e.g., RetinaNet)
extra_levels = num_outs - self.backbone_end_level + self.start_level
if self.add_extra_convs and extra_levels >= 1:
for i in range(extra_levels):
if i == 0 and self.add_extra_convs == 'on_input':
in_channels = self.in_channels[self.backbone_end_level - 1]
else:
in_channels = out_channels
extra_fpn_conv = ConvModule_Norm(
in_channels,
out_channels,
3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)
self.fpn_convs.append(extra_fpn_conv)

@auto_fp16()
def forward(self, inputs):
"""Forward function."""
assert len(inputs) == len(self.in_channels)

# build laterals
laterals = [
lateral_conv(inputs[i + self.start_level])
for i, lateral_conv in enumerate(self.lateral_convs)
]

# build top-down path
used_backbone_levels = len(laterals)
if self.use_residual:
for i in range(used_backbone_levels - 1, 0, -1):
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
# it cannot co-exist with `size` in `F.interpolate`.
if 'scale_factor' in self.upsample_cfg:
laterals[i - 1] += F.interpolate(laterals[i],
**self.upsample_cfg)
else:
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += F.interpolate(
laterals[i], size=prev_shape, **self.upsample_cfg)

# build outputs
# part 1: from original levels
outs = [
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
]
# part 2: add extra levels
if self.num_outs > len(outs):
# use max pool to get more levels on top of outputs
# (e.g., Faster R-CNN, Mask R-CNN)
if not self.add_extra_convs:
for i in range(self.num_outs - used_backbone_levels):
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
# add conv layers on top of original feature maps (RetinaNet)
else:
if self.add_extra_convs == 'on_input':
extra_source = inputs[self.backbone_end_level - 1]
elif self.add_extra_convs == 'on_lateral':
extra_source = laterals[-1]
elif self.add_extra_convs == 'on_output':
extra_source = outs[-1]
else:
raise NotImplementedError
outs.append(self.fpn_convs[used_backbone_levels](extra_source))
for i in range(used_backbone_levels + 1, self.num_outs):
if self.relu_before_extra_convs:
outs.append(self.fpn_convs[i](F.relu(outs[-1])))
else:
outs.append(self.fpn_convs[i](outs[-1]))

return tuple(outs)
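
The ViT backbone above emits four maps that all carry embed_dim channels, so an illustrative neck entry (for embed_dim=768; norm_cfg again an assumption) could look like:

    neck = dict(
        type='FPNF',
        in_channels=[768, 768, 768, 768],
        out_channels=256,
        num_outs=5,
        use_residual=True,
        norm_cfg=dict(type='LN'))   # assumption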

+8 -0  modelscope/models/cv/object_detection/mmdet_ms/roi_heads/__init__.py

@@ -0,0 +1,8 @@
from .bbox_heads import (ConvFCBBoxNHead, Shared2FCBBoxNHead,
Shared4Conv1FCBBoxNHead)
from .mask_heads import FCNMaskNHead

__all__ = [
'ConvFCBBoxNHead', 'Shared2FCBBoxNHead', 'Shared4Conv1FCBBoxNHead',
'FCNMaskNHead'
]

+4 -0  modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/__init__.py

@@ -0,0 +1,4 @@
from .convfc_bbox_head import (ConvFCBBoxNHead, Shared2FCBBoxNHead,
Shared4Conv1FCBBoxNHead)

__all__ = ['ConvFCBBoxNHead', 'Shared2FCBBoxNHead', 'Shared4Conv1FCBBoxNHead']

+229 -0  modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/convfc_bbox_head.py

@@ -0,0 +1,229 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Implementation in this file is modified from source code available via https://github.com/ViTAE-Transformer/ViTDet
import torch.nn as nn
from mmdet.models.builder import HEADS
from mmdet.models.roi_heads.bbox_heads.bbox_head import BBoxHead
from mmdet.models.utils import build_linear_layer

from ...utils import ConvModule_Norm


@HEADS.register_module()
class ConvFCBBoxNHead(BBoxHead):
r"""More general bbox head, with shared conv and fc layers and two optional
separated branches.

.. code-block:: none

/-> cls convs -> cls fcs -> cls
shared convs -> shared fcs
\-> reg convs -> reg fcs -> reg
""" # noqa: W605

def __init__(self,
num_shared_convs=0,
num_shared_fcs=0,
num_cls_convs=0,
num_cls_fcs=0,
num_reg_convs=0,
num_reg_fcs=0,
conv_out_channels=256,
fc_out_channels=1024,
conv_cfg=None,
norm_cfg=None,
init_cfg=None,
*args,
**kwargs):
super(ConvFCBBoxNHead, self).__init__(
*args, init_cfg=init_cfg, **kwargs)
assert (num_shared_convs + num_shared_fcs + num_cls_convs + num_cls_fcs
+ num_reg_convs + num_reg_fcs > 0)
if num_cls_convs > 0 or num_reg_convs > 0:
assert num_shared_fcs == 0
if not self.with_cls:
assert num_cls_convs == 0 and num_cls_fcs == 0
if not self.with_reg:
assert num_reg_convs == 0 and num_reg_fcs == 0
self.num_shared_convs = num_shared_convs
self.num_shared_fcs = num_shared_fcs
self.num_cls_convs = num_cls_convs
self.num_cls_fcs = num_cls_fcs
self.num_reg_convs = num_reg_convs
self.num_reg_fcs = num_reg_fcs
self.conv_out_channels = conv_out_channels
self.fc_out_channels = fc_out_channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg

# add shared convs and fcs
self.shared_convs, self.shared_fcs, last_layer_dim = \
self._add_conv_fc_branch(
self.num_shared_convs, self.num_shared_fcs, self.in_channels,
True)
self.shared_out_channels = last_layer_dim

# add cls specific branch
self.cls_convs, self.cls_fcs, self.cls_last_dim = \
self._add_conv_fc_branch(
self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels)

# add reg specific branch
self.reg_convs, self.reg_fcs, self.reg_last_dim = \
self._add_conv_fc_branch(
self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels)

if self.num_shared_fcs == 0 and not self.with_avg_pool:
if self.num_cls_fcs == 0:
self.cls_last_dim *= self.roi_feat_area
if self.num_reg_fcs == 0:
self.reg_last_dim *= self.roi_feat_area

self.relu = nn.ReLU(inplace=True)
# reconstruct fc_cls and fc_reg since input channels are changed
if self.with_cls:
if self.custom_cls_channels:
cls_channels = self.loss_cls.get_cls_channels(self.num_classes)
else:
cls_channels = self.num_classes + 1
self.fc_cls = build_linear_layer(
self.cls_predictor_cfg,
in_features=self.cls_last_dim,
out_features=cls_channels)
if self.with_reg:
out_dim_reg = (4 if self.reg_class_agnostic else 4
* self.num_classes)
self.fc_reg = build_linear_layer(
self.reg_predictor_cfg,
in_features=self.reg_last_dim,
out_features=out_dim_reg)

if init_cfg is None:
# when init_cfg is None,
# It has been set to
# [[dict(type='Normal', std=0.01, override=dict(name='fc_cls'))],
# [dict(type='Normal', std=0.001, override=dict(name='fc_reg'))]
# after `super(ConvFCBBoxHead, self).__init__()`
# we only need to append additional configuration
# for `shared_fcs`, `cls_fcs` and `reg_fcs`
self.init_cfg += [
dict(
type='Xavier',
override=[
dict(name='shared_fcs'),
dict(name='cls_fcs'),
dict(name='reg_fcs')
])
]

def _add_conv_fc_branch(self,
num_branch_convs,
num_branch_fcs,
in_channels,
is_shared=False):
"""Add shared or separable branch.

convs -> avg pool (optional) -> fcs
"""
last_layer_dim = in_channels
# add branch specific conv layers
branch_convs = nn.ModuleList()
if num_branch_convs > 0:
for i in range(num_branch_convs):
conv_in_channels = (
last_layer_dim if i == 0 else self.conv_out_channels)
branch_convs.append(
ConvModule_Norm(
conv_in_channels,
self.conv_out_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
last_layer_dim = self.conv_out_channels
# add branch specific fc layers
branch_fcs = nn.ModuleList()
if num_branch_fcs > 0:
# for shared branch, only consider self.with_avg_pool
# for separated branches, also consider self.num_shared_fcs
if (is_shared
or self.num_shared_fcs == 0) and not self.with_avg_pool:
last_layer_dim *= self.roi_feat_area
for i in range(num_branch_fcs):
fc_in_channels = (
last_layer_dim if i == 0 else self.fc_out_channels)
branch_fcs.append(
nn.Linear(fc_in_channels, self.fc_out_channels))
last_layer_dim = self.fc_out_channels
return branch_convs, branch_fcs, last_layer_dim

def forward(self, x):
# shared part
if self.num_shared_convs > 0:
for conv in self.shared_convs:
x = conv(x)

if self.num_shared_fcs > 0:
if self.with_avg_pool:
x = self.avg_pool(x)

x = x.flatten(1)

for fc in self.shared_fcs:
x = self.relu(fc(x))
# separate branches
x_cls = x
x_reg = x

for conv in self.cls_convs:
x_cls = conv(x_cls)
if x_cls.dim() > 2:
if self.with_avg_pool:
x_cls = self.avg_pool(x_cls)
x_cls = x_cls.flatten(1)
for fc in self.cls_fcs:
x_cls = self.relu(fc(x_cls))

for conv in self.reg_convs:
x_reg = conv(x_reg)
if x_reg.dim() > 2:
if self.with_avg_pool:
x_reg = self.avg_pool(x_reg)
x_reg = x_reg.flatten(1)
for fc in self.reg_fcs:
x_reg = self.relu(fc(x_reg))

cls_score = self.fc_cls(x_cls) if self.with_cls else None
bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
return cls_score, bbox_pred


@HEADS.register_module()
class Shared2FCBBoxNHead(ConvFCBBoxNHead):

def __init__(self, fc_out_channels=1024, *args, **kwargs):
super(Shared2FCBBoxNHead, self).__init__(
num_shared_convs=0,
num_shared_fcs=2,
num_cls_convs=0,
num_cls_fcs=0,
num_reg_convs=0,
num_reg_fcs=0,
fc_out_channels=fc_out_channels,
*args,
**kwargs)


@HEADS.register_module()
class Shared4Conv1FCBBoxNHead(ConvFCBBoxNHead):

def __init__(self, fc_out_channels=1024, *args, **kwargs):
super(Shared4Conv1FCBBoxNHead, self).__init__(
num_shared_convs=4,
num_shared_fcs=1,
num_cls_convs=0,
num_cls_fcs=0,
num_reg_convs=0,
num_reg_fcs=0,
fc_out_channels=fc_out_channels,
*args,
**kwargs)
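
An illustrative bbox_head entry for the 4conv+1fc variant (remaining arguments are the usual mmdet BBoxHead ones; norm_cfg again an assumption):

    bbox_head = dict(
        type='Shared4Conv1FCBBoxNHead',
        in_channels=256,
        conv_out_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=80,
        norm_cfg=dict(type='LN'),   # assumption
        reg_class_agnostic=False,
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0., 0., 0., 0.],
            target_stds=[0.1, 0.1, 0.2, 0.2]),
        loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0))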

+3 -0  modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/__init__.py

@@ -0,0 +1,3 @@
from .fcn_mask_head import FCNMaskNHead

__all__ = ['FCNMaskNHead']

+414 -0  modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/fcn_mask_head.py

@@ -0,0 +1,414 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Implementation in this file is modified from source code available via https://github.com/ViTAE-Transformer/ViTDet
from warnings import warn

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_conv_layer, build_upsample_layer
from mmcv.ops.carafe import CARAFEPack
from mmcv.runner import BaseModule, ModuleList, auto_fp16, force_fp32
from mmdet.core import mask_target
from mmdet.models.builder import HEADS, build_loss
from torch.nn.modules.utils import _pair

from ...utils import ConvModule_Norm

BYTES_PER_FLOAT = 4
# TODO: This memory limit may be too much or too little. It would be better to
# determine it based on available resources.
GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit


@HEADS.register_module()
class FCNMaskNHead(BaseModule):

def __init__(self,
num_convs=4,
roi_feat_size=14,
in_channels=256,
conv_kernel_size=3,
conv_out_channels=256,
num_classes=80,
class_agnostic=False,
upsample_cfg=dict(type='deconv', scale_factor=2),
conv_cfg=None,
norm_cfg=None,
predictor_cfg=dict(type='Conv'),
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
init_cfg=None):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(FCNMaskNHead, self).__init__(init_cfg)
self.upsample_cfg = upsample_cfg.copy()
if self.upsample_cfg['type'] not in [
None, 'deconv', 'nearest', 'bilinear', 'carafe'
]:
raise ValueError(
f'Invalid upsample method {self.upsample_cfg["type"]}, '
'accepted methods are "deconv", "nearest", "bilinear", '
'"carafe"')
self.num_convs = num_convs
# WARN: roi_feat_size is reserved and not used
self.roi_feat_size = _pair(roi_feat_size)
self.in_channels = in_channels
self.conv_kernel_size = conv_kernel_size
self.conv_out_channels = conv_out_channels
self.upsample_method = self.upsample_cfg.get('type')
self.scale_factor = self.upsample_cfg.pop('scale_factor', None)
self.num_classes = num_classes
self.class_agnostic = class_agnostic
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.predictor_cfg = predictor_cfg
self.fp16_enabled = False
self.loss_mask = build_loss(loss_mask)

self.convs = ModuleList()
for i in range(self.num_convs):
in_channels = (
self.in_channels if i == 0 else self.conv_out_channels)
padding = (self.conv_kernel_size - 1) // 2
self.convs.append(
ConvModule_Norm(
in_channels,
self.conv_out_channels,
self.conv_kernel_size,
padding=padding,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg))
upsample_in_channels = (
self.conv_out_channels if self.num_convs > 0 else in_channels)
upsample_cfg_ = self.upsample_cfg.copy()
if self.upsample_method is None:
self.upsample = None
elif self.upsample_method == 'deconv':
upsample_cfg_.update(
in_channels=upsample_in_channels,
out_channels=self.conv_out_channels,
kernel_size=self.scale_factor,
stride=self.scale_factor)
self.upsample = build_upsample_layer(upsample_cfg_)
elif self.upsample_method == 'carafe':
upsample_cfg_.update(
channels=upsample_in_channels, scale_factor=self.scale_factor)
self.upsample = build_upsample_layer(upsample_cfg_)
else:
# suppress warnings
align_corners = (None
if self.upsample_method == 'nearest' else False)
upsample_cfg_.update(
scale_factor=self.scale_factor,
mode=self.upsample_method,
align_corners=align_corners)
self.upsample = build_upsample_layer(upsample_cfg_)

out_channels = 1 if self.class_agnostic else self.num_classes
logits_in_channel = (
self.conv_out_channels
if self.upsample_method == 'deconv' else upsample_in_channels)
self.conv_logits = build_conv_layer(self.predictor_cfg,
logits_in_channel, out_channels, 1)
self.relu = nn.ReLU(inplace=True)
self.debug_imgs = None

def init_weights(self):
super(FCNMaskNHead, self).init_weights()
for m in [self.upsample, self.conv_logits]:
if m is None:
continue
elif isinstance(m, CARAFEPack):
m.init_weights()
elif hasattr(m, 'weight') and hasattr(m, 'bias'):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(m.bias, 0)

@auto_fp16()
def forward(self, x):
for conv in self.convs:
x = conv(x)
if self.upsample is not None:
x = self.upsample(x)
if self.upsample_method == 'deconv':
x = self.relu(x)
mask_pred = self.conv_logits(x)
return mask_pred

def get_targets(self, sampling_results, gt_masks, rcnn_train_cfg):
pos_proposals = [res.pos_bboxes for res in sampling_results]
pos_assigned_gt_inds = [
res.pos_assigned_gt_inds for res in sampling_results
]
mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
gt_masks, rcnn_train_cfg)
return mask_targets

@force_fp32(apply_to=('mask_pred', ))
def loss(self, mask_pred, mask_targets, labels):
"""
Example:
>>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA
>>> N = 7 # N = number of extracted ROIs
>>> C, H, W = 11, 32, 32
>>> # Create example instance of FCN Mask Head.
>>> # There are lots of variations depending on the configuration
>>> self = FCNMaskHead(num_classes=C, num_convs=1)
>>> inputs = torch.rand(N, self.in_channels, H, W)
>>> mask_pred = self.forward(inputs)
>>> sf = self.scale_factor
>>> labels = torch.randint(0, C, size=(N,))
>>> # With the default properties the mask targets should indicate
>>> # a (potentially soft) single-class label
>>> mask_targets = torch.rand(N, H * sf, W * sf)
>>> loss = self.loss(mask_pred, mask_targets, labels)
>>> print('loss = {!r}'.format(loss))
"""
loss = dict()
if mask_pred.size(0) == 0:
loss_mask = mask_pred.sum()
else:
if self.class_agnostic:
loss_mask = self.loss_mask(mask_pred, mask_targets,
torch.zeros_like(labels))
else:
loss_mask = self.loss_mask(mask_pred, mask_targets, labels)
loss['loss_mask'] = loss_mask
return loss

def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
ori_shape, scale_factor, rescale):
"""Get segmentation masks from mask_pred and bboxes.

Args:
mask_pred (Tensor or ndarray): shape (n, #class, h, w).
For single-scale testing, mask_pred is the direct output of
model, whose type is Tensor, while for multi-scale testing,
it will be converted to numpy array outside of this method.
det_bboxes (Tensor): shape (n, 4/5)
det_labels (Tensor): shape (n, )
rcnn_test_cfg (dict): rcnn testing config
ori_shape (Tuple): original image height and width, shape (2,)
scale_factor(ndarray | Tensor): If ``rescale is True``, box
coordinates are divided by this scale factor to fit
``ori_shape``.
rescale (bool): If True, the resulting masks will be rescaled to
``ori_shape``.

Returns:
list[list]: encoded masks. The c-th item in the outer list
corresponds to the c-th class. Given the c-th outer list, the
i-th item in that inner list is the mask for the i-th box with
class label c.

Example:
>>> import mmcv
>>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA
>>> N = 7 # N = number of extracted ROIs
>>> C, H, W = 11, 32, 32
>>> # Create example instance of FCN Mask Head.
>>> self = FCNMaskHead(num_classes=C, num_convs=0)
>>> inputs = torch.rand(N, self.in_channels, H, W)
>>> mask_pred = self.forward(inputs)
>>> # Each input is associated with some bounding box
>>> det_bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N)
>>> det_labels = torch.randint(0, C, size=(N,))
>>> rcnn_test_cfg = mmcv.Config({'mask_thr_binary': 0, })
>>> ori_shape = (H * 4, W * 4)
>>> scale_factor = torch.FloatTensor((1, 1))
>>> rescale = False
>>> # Encoded masks are a list for each category.
>>> encoded_masks = self.get_seg_masks(
>>> mask_pred, det_bboxes, det_labels, rcnn_test_cfg, ori_shape,
>>> scale_factor, rescale
>>> )
>>> assert len(encoded_masks) == C
>>> assert sum(list(map(len, encoded_masks))) == N
"""
if isinstance(mask_pred, torch.Tensor):
mask_pred = mask_pred.sigmoid()
else:
# In AugTest, has been activated before
mask_pred = det_bboxes.new_tensor(mask_pred)

device = mask_pred.device
cls_segms = [[] for _ in range(self.num_classes)
] # BG is not included in num_classes
bboxes = det_bboxes[:, :4]
labels = det_labels

# In most cases, scale_factor should have been
# converted to Tensor when rescale the bbox
if not isinstance(scale_factor, torch.Tensor):
if isinstance(scale_factor, float):
scale_factor = np.array([scale_factor] * 4)
warn('Scale_factor should be a Tensor or ndarray '
'with shape (4,), float would be deprecated. ')
assert isinstance(scale_factor, np.ndarray)
scale_factor = torch.Tensor(scale_factor)

if rescale:
img_h, img_w = ori_shape[:2]
bboxes = bboxes / scale_factor.to(bboxes)
else:
w_scale, h_scale = scale_factor[0], scale_factor[1]
img_h = np.round(ori_shape[0] * h_scale.item()).astype(np.int32)
img_w = np.round(ori_shape[1] * w_scale.item()).astype(np.int32)

N = len(mask_pred)
# The actual implementation split the input into chunks,
# and paste them chunk by chunk.
if device.type == 'cpu':
            # CPU is most efficient when masks are pasted one by one with
            # skip_empty=True, so that it performs the minimal number of
            # operations.
num_chunks = N
else:
            # GPU benefits from parallelism for larger chunks,
            # but may run into memory issues.
            # img_w and img_h are np.int32; when the image resolution is
            # large, the calculation of num_chunks would overflow,
            # so we cast img_w and img_h to int.
            # See https://github.com/open-mmlab/mmdetection/pull/5191
num_chunks = int(
np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT
/ GPU_MEM_LIMIT))
            assert num_chunks <= N, 'Default GPU_MEM_LIMIT is too small; try increasing it'
chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

threshold = rcnn_test_cfg.mask_thr_binary
im_mask = torch.zeros(
N,
img_h,
img_w,
device=device,
dtype=torch.bool if threshold >= 0 else torch.uint8)

if not self.class_agnostic:
mask_pred = mask_pred[range(N), labels][:, None]

for inds in chunks:
masks_chunk, spatial_inds = _do_paste_mask(
mask_pred[inds],
bboxes[inds],
img_h,
img_w,
skip_empty=device.type == 'cpu')

if threshold >= 0:
masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
else:
# for visualization and debugging
masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

im_mask[(inds, ) + spatial_inds] = masks_chunk

for i in range(N):
cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())
return cls_segms

def onnx_export(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
ori_shape, **kwargs):
"""Get segmentation masks from mask_pred and bboxes.

Args:
mask_pred (Tensor): shape (n, #class, h, w).
det_bboxes (Tensor): shape (n, 4/5)
det_labels (Tensor): shape (n, )
rcnn_test_cfg (dict): rcnn testing config
ori_shape (Tuple): original image height and width, shape (2,)

Returns:
Tensor: a mask of shape (N, img_h, img_w).
"""

mask_pred = mask_pred.sigmoid()
bboxes = det_bboxes[:, :4]
labels = det_labels
# No need to consider rescale and scale_factor while exporting to ONNX
img_h, img_w = ori_shape[:2]
threshold = rcnn_test_cfg.mask_thr_binary
if not self.class_agnostic:
box_inds = torch.arange(mask_pred.shape[0])
mask_pred = mask_pred[box_inds, labels][:, None]
masks, _ = _do_paste_mask(
mask_pred, bboxes, img_h, img_w, skip_empty=False)
if threshold >= 0:
# should convert to float to avoid problems in TRT
masks = (masks >= threshold).to(dtype=torch.float)
return masks


def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
"""Paste instance masks according to boxes.

This implementation is modified from
https://github.com/facebookresearch/detectron2/

Args:
masks (Tensor): N, 1, H, W
boxes (Tensor): N, 4
img_h (int): Height of the image to be pasted.
img_w (int): Width of the image to be pasted.
        skip_empty (bool): Only paste masks within the region that
            tightly bounds all boxes, and return results for this region only.
            An important optimization for CPU.

Returns:
tuple: (Tensor, tuple). The first item is mask tensor, the second one
is the slice object.
If skip_empty == False, the whole image will be pasted. It will
return a mask of shape (N, img_h, img_w) and an empty tuple.
If skip_empty == True, only area around the mask will be pasted.
A mask of shape (N, h', w') and its start and end coordinates
in the original image will be returned.
"""
# On GPU, paste all masks together (up to chunk size)
# by using the entire image to sample the masks
# Compared to pasting them one by one,
    # this has more operations but is faster on COCO-scale datasets.
device = masks.device
if skip_empty:
x0_int, y0_int = torch.clamp(
boxes.min(dim=0).values.floor()[:2] - 1,
min=0).to(dtype=torch.int32)
x1_int = torch.clamp(
boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
y1_int = torch.clamp(
boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
else:
x0_int, y0_int = 0, 0
x1_int, y1_int = img_w, img_h
x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1

N = masks.shape[0]

img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5
img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5
img_y = (img_y - y0) / (y1 - y0) * 2 - 1
img_x = (img_x - x0) / (x1 - x0) * 2 - 1
# img_x, img_y have shapes (N, w), (N, h)
# IsInf op is not supported with ONNX<=1.7.0
if not torch.onnx.is_in_onnx_export():
if torch.isinf(img_x).any():
inds = torch.where(torch.isinf(img_x))
img_x[inds] = 0
if torch.isinf(img_y).any():
inds = torch.where(torch.isinf(img_y))
img_y[inds] = 0

gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
grid = torch.stack([gx, gy], dim=3)

img_masks = F.grid_sample(
masks.to(dtype=torch.float32), grid, align_corners=False)

if skip_empty:
return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
else:
return img_masks[:, 0], ()
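A minimal standalone sketch of the chunking rule used in get_seg_masks above (not part of the diff): BYTES_PER_FLOAT and GPU_MEM_LIMIT mirror module-level constants defined earlier in this file, and the values assumed here are for illustration only.

import numpy as np

BYTES_PER_FLOAT = 4  # sizeof(float32), assumed to match the module constant
GPU_MEM_LIMIT = 1024**3  # assumed 1 GB paste buffer, as in the module constant


def estimate_num_chunks(num_masks, img_h, img_w, device_type='cuda'):
    # CPU pastes masks one by one (skip_empty=True), so one chunk per mask.
    if device_type == 'cpu':
        return num_masks
    # GPU groups masks so a single paste buffer stays under GPU_MEM_LIMIT;
    # img_h and img_w are cast to int to avoid np.int32 overflow.
    return int(np.ceil(num_masks * int(img_h) * int(img_w) * BYTES_PER_FLOAT
                       / GPU_MEM_LIMIT))


# e.g. 100 detections pasted onto a 1333x800 image fit into a single GPU chunk
assert estimate_num_chunks(100, 800, 1333) == 1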

+ 4
- 0
modelscope/models/cv/object_detection/mmdet_ms/utils/__init__.py View File

@@ -0,0 +1,4 @@
from .checkpoint import load_checkpoint
from .convModule_norm import ConvModule_Norm

__all__ = ['load_checkpoint', 'ConvModule_Norm']

+ 558
- 0
modelscope/models/cv/object_detection/mmdet_ms/utils/checkpoint.py View File

@@ -0,0 +1,558 @@
# Copyright (c) Open-MMLab. All rights reserved.
# Implementation adopted from ViTAE-Transformer, source code available via https://github.com/ViTAE-Transformer/ViTDet
import io
import os
import os.path as osp
import pkgutil
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory

import mmcv
import torch
import torchvision
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.runner import get_dist_info
from torch.nn import functional as F
from torch.optim import Optimizer
from torch.utils import model_zoo


def load_state_dict(module, state_dict, strict=False, logger=None):
"""Load state_dict to a module.

This method is modified from :meth:`torch.nn.Module.load_state_dict`.
Default value for ``strict`` is set to ``False`` and the message for
param mismatch will be shown even if strict is False.

Args:
module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
all_missing_keys = []
err_msg = []

metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=''):
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_module_wrapper(module):
module = module.module
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
all_missing_keys, unexpected_keys,
err_msg)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')

load(module)
load = None # break load->load reference cycle
missing_keys = [
key for key in all_missing_keys if 'num_batches_tracked' not in key
]

if unexpected_keys:
err_msg.append('unexpected key in source '
f'state_dict: {", ".join(unexpected_keys)}\n')
if missing_keys:
err_msg.append(
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

rank, _ = get_dist_info()
if len(err_msg) > 0 and rank == 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg)
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warning(err_msg)
else:
print(err_msg)
print('finish load')


def load_url_dist(url, model_dir=None):
"""In distributed setting, this function only download checkpoint at local
rank 0."""
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
return checkpoint


def load_pavimodel_dist(model_path, map_location=None):
"""In distributed setting, this function only download checkpoint at local
rank 0."""
try:
from pavi import modelcloud
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(
downloaded_file, map_location=map_location)
return checkpoint


def load_fileclient_dist(filename, backend, map_location):
"""In distributed setting, this function only download checkpoint at local
rank 0."""
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
allowed_backends = ['ceph']
if backend not in allowed_backends:
raise ValueError(f'Load from Backend {backend} is not supported.')
if rank == 0:
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
return checkpoint


def get_torchvision_models():
model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
return model_urls


def get_external_models():
mmcv_home = _get_mmcv_home()
default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
default_urls = load_file(default_json_path)
assert isinstance(default_urls, dict)
external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
if osp.exists(external_json_path):
external_urls = load_file(external_json_path)
assert isinstance(external_urls, dict)
default_urls.update(external_urls)

return default_urls


def get_mmcls_models():
mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
mmcls_urls = load_file(mmcls_json_path)
return mmcls_urls


def get_deprecated_model_names():
deprecate_json_path = osp.join(mmcv.__path__[0],
'model_zoo/deprecated.json')
deprecate_urls = load_file(deprecate_json_path)
assert isinstance(deprecate_urls, dict)
return deprecate_urls


def _process_mmcls_checkpoint(checkpoint):
state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('backbone.'):
new_state_dict[k[9:]] = v
new_checkpoint = dict(state_dict=new_state_dict)
return new_checkpoint


def _load_checkpoint(filename, map_location=None):
"""Load checkpoint from somewhere (modelzoo, file, url).

Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str | None): Same as :func:`torch.load`. Default: None.

Returns:
dict | OrderedDict: The loaded checkpoint. It can be either an
OrderedDict storing model weights or a dict containing other
information, which depends on the checkpoint.
"""
if filename.startswith('modelzoo://'):
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead')
model_urls = get_torchvision_models()
model_name = filename[11:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('torchvision://'):
model_urls = get_torchvision_models()
model_name = filename[14:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('open-mmlab://'):
model_urls = get_external_models()
model_name = filename[13:]
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
f'of open-mmlab://{deprecated_urls[model_name]}')
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url
if model_url.startswith(('http://', 'https://')):
checkpoint = load_url_dist(model_url)
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location)
elif filename.startswith('mmcls://'):
model_urls = get_mmcls_models()
model_name = filename[8:]
checkpoint = load_url_dist(model_urls[model_name])
checkpoint = _process_mmcls_checkpoint(checkpoint)
elif filename.startswith(('http://', 'https://')):
checkpoint = load_url_dist(filename)
elif filename.startswith('pavi://'):
model_path = filename[7:]
checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
elif filename.startswith('s3://'):
checkpoint = load_fileclient_dist(
filename, backend='ceph', map_location=map_location)
else:
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint


def load_checkpoint(model,
filename,
map_location='cpu',
strict=False,
logger=None,
load_ema=True):
"""Load checkpoint from a file or URI.

Args:
model (Module): Module to load checkpoint.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
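        load_ema (bool): Whether to prefer the ``state_dict_ema`` weights in
            the checkpoint when they are present. Default: True.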

Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = _load_checkpoint(filename, map_location)
# OrderedDict is a subclass of dict
if not isinstance(checkpoint, dict):
raise RuntimeError(
f'No state_dict found in checkpoint file {filename}')
# get state_dict from checkpoint
if load_ema and 'state_dict_ema' in checkpoint:
state_dict = checkpoint['state_dict_ema']
        logger.info('loading from state_dict_ema')
elif 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
        logger.info('loading from state_dict')
elif 'model' in checkpoint:
state_dict = checkpoint['model']
        logger.info('loading from model')
else:
state_dict = checkpoint
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}

# for MoBY, load model of online branch
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
state_dict = {
k.replace('encoder.', ''): v
for k, v in state_dict.items() if k.startswith('encoder.')
}

# reshape absolute position embedding
if state_dict.get('absolute_pos_embed') is not None:
absolute_pos_embed = state_dict['absolute_pos_embed']
N1, L, C1 = absolute_pos_embed.size()
N2, C2, H, W = model.absolute_pos_embed.size()
if N1 != N2 or C1 != C2 or L != H * W:
logger.warning('Error in loading absolute_pos_embed, pass')
else:
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
N2, H, W, C2).permute(0, 3, 1, 2)

all_keys = list(state_dict.keys())
for key in all_keys:
if 'relative_position_index' in key:
state_dict.pop(key)

if 'relative_position_bias_table' in key:
state_dict.pop(key)

if '.q_bias' in key:
q_bias = state_dict[key]
v_bias = state_dict[key.replace('q_bias', 'v_bias')]
qkv_bias = torch.cat([q_bias, torch.zeros_like(q_bias), v_bias], 0)
state_dict[key.replace('q_bias', 'qkv.bias')] = qkv_bias

if '.v.bias' in key:
continue

all_keys = list(state_dict.keys())
new_state_dict = {}
for key in all_keys:
if 'qkv.bias' in key:
value = state_dict[key]
dim = value.shape[0]
selected_dim = (dim * 2) // 3
new_state_dict[key.replace(
'qkv.bias', 'pos_bias')] = state_dict[key][:selected_dim]

# interpolate position bias table if needed
relative_position_bias_table_keys = [
k for k in state_dict.keys() if 'relative_position_bias_table' in k
]
for table_key in relative_position_bias_table_keys:
table_pretrained = state_dict[table_key]
if table_key not in model.state_dict().keys():
logger.warning(
                'relative_position_bias_table exists in pretrained model but not in current one, pass'
)
continue
table_current = model.state_dict()[table_key]
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
logger.warning(f'Error in loading {table_key}, pass')
else:
if L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
table_pretrained_resized = F.interpolate(
table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
size=(S2, S2),
mode='bicubic')
state_dict[table_key] = table_pretrained_resized.view(
nH2, L2).permute(1, 0)
rank, _ = get_dist_info()
if 'pos_embed' in state_dict:
pos_embed_checkpoint = state_dict['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
H, W = model.patch_embed.patch_shape
num_patches = model.patch_embed.num_patches
num_extra_tokens = 1
# height (== width) for the checkpoint position embedding
orig_size = int(
(pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
if rank == 0:
print('Position interpolate from %dx%d to %dx%d' %
(orig_size, orig_size, H, W))
# extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
embedding_size).permute(
0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(H, W), mode='bicubic', align_corners=False)
new_pos_embed = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
# new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict['pos_embed'] = new_pos_embed

# load state_dict
load_state_dict(model, state_dict, strict, logger)
return checkpoint


def weights_to_cpu(state_dict):
"""Copy a model state_dict to cpu.

Args:
state_dict (OrderedDict): Model weights on GPU.

Returns:
        OrderedDict: Model weights on CPU.
"""
state_dict_cpu = OrderedDict()
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
return state_dict_cpu


def _save_to_state_dict(module, destination, prefix, keep_vars):
"""Saves module state to `destination` dictionary.

This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.

Args:
module (nn.Module): The module to generate state_dict.
destination (dict): A dict where state will be stored.
prefix (str): The prefix for parameters and buffers used in this
module.
"""
for name, param in module._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in module._buffers.items():
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
if buf is not None:
destination[prefix + name] = buf if keep_vars else buf.detach()


def get_state_dict(module, destination=None, prefix='', keep_vars=False):
"""Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.

This method is modified from :meth:`torch.nn.Module.state_dict` to
recursively check parallel module in case that the model has a complicated
structure, e.g., nn.Module(nn.Module(DDP)).

Args:
module (nn.Module): The module to generate state_dict.
destination (OrderedDict): Returned dict for the state of the
module.
prefix (str): Prefix of the key.
keep_vars (bool): Whether to keep the variable property of the
parameters. Default: False.

Returns:
dict: A dictionary containing a whole state of the module.
"""
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_module_wrapper(module):
module = module.module

# below is the same as torch.nn.Module.state_dict()
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(
version=module._version)
_save_to_state_dict(module, destination, prefix, keep_vars)
for name, child in module._modules.items():
if child is not None:
get_state_dict(
child, destination, prefix + name + '.', keep_vars=keep_vars)
for hook in module._state_dict_hooks.values():
hook_result = hook(module, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination


def save_checkpoint(model, filename, optimizer=None, meta=None):
"""Save checkpoint to file.

The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
``optimizer``. By default ``meta`` will contain version and time info.

Args:
model (Module): Module whose params are to be saved.
filename (str): Checkpoint filename.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
meta (dict, optional): Metadata to be saved in checkpoint.
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

if is_module_wrapper(model):
model = model.module

if hasattr(model, 'CLASSES') and model.CLASSES is not None:
# save class name to the meta
meta.update(CLASSES=model.CLASSES)

checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(get_state_dict(model))
}
# save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer):
checkpoint['optimizer'] = optimizer.state_dict()
elif isinstance(optimizer, dict):
checkpoint['optimizer'] = {}
for name, optim in optimizer.items():
checkpoint['optimizer'][name] = optim.state_dict()

if filename.startswith('pavi://'):
try:
from pavi import modelcloud
from pavi.exception import NodeNotFoundError
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
model_path = filename[7:]
root = modelcloud.Folder()
model_dir, model_name = osp.split(model_path)
try:
model = modelcloud.get(model_dir)
except NodeNotFoundError:
model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)
with open(checkpoint_file, 'wb') as f:
torch.save(checkpoint, f)
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
mmcv.mkdir_or_exist(osp.dirname(filename))
# immediately flush buffer
with open(filename, 'wb') as f:
torch.save(checkpoint, f)
f.flush()
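A minimal usage sketch for the loader above (not part of the diff): a toy nn.Linear and a temporary checkpoint stand in for the ViT backbone and pretrained weights, and a real logger is passed because the helper calls logger.info directly.

import logging
import tempfile

import torch
from torch import nn

from modelscope.models.cv.object_detection.mmdet_ms.utils import load_checkpoint

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('mmdet_ms')

# Save a toy checkpoint whose state_dict keys carry a 'module.' prefix; the
# loader strips the prefix and logs which branch it loaded from.
model = nn.Linear(4, 2)
ckpt = {'state_dict': {'module.' + k: v for k, v in model.state_dict().items()}}
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_file = tmp_dir + '/toy.pth'
    torch.save(ckpt, ckpt_file)
    load_checkpoint(model, ckpt_file, map_location='cpu', strict=False,
                    logger=logger, load_ema=True)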

+ 30
- 0
modelscope/models/cv/object_detection/mmdet_ms/utils/convModule_norm.py View File

@@ -0,0 +1,30 @@
# Implementation adopted from ViTAE-Transformer, source code available via https://github.com/ViTAE-Transformer/ViTDet

from mmcv.cnn import ConvModule


class ConvModule_Norm(ConvModule):

def __init__(self, in_channels, out_channels, kernel, **kwargs):
super().__init__(in_channels, out_channels, kernel, **kwargs)

self.normType = kwargs.get('norm_cfg', {'type': ''})
if self.normType is not None:
self.normType = self.normType['type']

def forward(self, x, activate=True, norm=True):
for layer in self.order:
if layer == 'conv':
if self.with_explicit_padding:
x = self.padding_layer(x)
x = self.conv(x)
elif layer == 'norm' and norm and self.with_norm:
if 'LN' in self.normType:
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2).contiguous()
else:
x = self.norm(x)
elif layer == 'act' and activate and self.with_activation:
x = self.activate(x)
return x
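A minimal sketch of how the wrapper above is meant to be used (not part of the diff): with an 'LN' norm config the NCHW feature map is permuted to NHWC for the LayerNorm and then permuted back, so layer norm can stand in for BN inside the detection neck and heads; the channel sizes below are arbitrary.

import torch

from modelscope.models.cv.object_detection.mmdet_ms.utils import ConvModule_Norm

conv = ConvModule_Norm(256, 256, 3, padding=1, norm_cfg=dict(type='LN'))
out = conv(torch.rand(2, 256, 32, 32))  # LayerNorm applied over the channel dim
assert out.shape == (2, 256, 32, 32)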

+ 5
- 2
modelscope/pipelines/base.py View File

@@ -62,6 +62,7 @@ class Pipeline(ABC):
model: Union[InputModel, List[InputModel]] = None,
preprocessor: Union[Preprocessor, List[Preprocessor]] = None,
device: str = 'gpu',
auto_collate=True,
**kwargs):
""" Base class for pipeline.

@@ -74,6 +75,7 @@ class Pipeline(ABC):
model: (list of) Model name or model object
preprocessor: (list of) Preprocessor object
device (str): gpu device or cpu device to use
            auto_collate (bool): Whether to automatically convert data to tensor.
"""
if config_file is not None:
self.cfg = Config.from_file(config_file)
@@ -98,6 +100,7 @@ class Pipeline(ABC):
self.device = create_device(self.device_name == 'cpu')
self._model_prepare = False
self._model_prepare_lock = Lock()
self._auto_collate = auto_collate

def prepare_model(self):
self._model_prepare_lock.acquire(timeout=600)
@@ -252,7 +255,7 @@ class Pipeline(ABC):
return self._collate_fn(torch.from_numpy(data))
elif isinstance(data, torch.Tensor):
return data.to(self.device)
elif isinstance(data, (str, int, float, bool)):
elif isinstance(data, (str, int, float, bool, type(None))):
return data
elif isinstance(data, InputFeatures):
return data
@@ -270,7 +273,7 @@ class Pipeline(ABC):

out = self.preprocess(input, **preprocess_params)
with self.place_device():
if self.framework == Frameworks.torch:
if self.framework == Frameworks.torch and self._auto_collate:
with torch.no_grad():
out = self._collate_fn(out)
out = self.forward(out, **forward_params)
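A minimal sketch of the new auto_collate switch (not part of the diff): the subclass name is hypothetical, but the object detection pipeline added below opts out of collation in exactly this way so its model receives the numpy data from preprocess() untouched.

from modelscope.pipelines.base import Pipeline


class RawNumpyPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        # auto_collate=False skips the torch.from_numpy collation step in
        # __call__, so forward() sees exactly what preprocess() returned.
        super().__init__(model=model, auto_collate=False, **kwargs)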


+ 4
- 0
modelscope/pipelines/builder.py View File

@@ -35,6 +35,10 @@ DEFAULT_MODEL_FOR_PIPELINE = {
), # TODO: revise back after passing the pr
Tasks.image_matting: (Pipelines.image_matting,
'damo/cv_unet_image-matting'),
Tasks.human_detection: (Pipelines.human_detection,
'damo/cv_resnet18_human-detection'),
Tasks.object_detection: (Pipelines.object_detection,
'damo/cv_vit_object-detection_coco'),
Tasks.image_denoise: (Pipelines.image_denoise,
'damo/cv_nafnet_image-denoise_sidd'),
Tasks.text_classification: (Pipelines.sentiment_analysis,


+ 2
- 0
modelscope/pipelines/cv/__init__.py View File

@@ -7,6 +7,7 @@ if TYPE_CHECKING:
from .action_recognition_pipeline import ActionRecognitionPipeline
from .animal_recognition_pipeline import AnimalRecognitionPipeline
from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
from .object_detection_pipeline import ObjectDetectionPipeline
from .face_detection_pipeline import FaceDetectionPipeline
from .face_recognition_pipeline import FaceRecognitionPipeline
from .face_image_generation_pipeline import FaceImageGenerationPipeline
@@ -30,6 +31,7 @@ else:
'action_recognition_pipeline': ['ActionRecognitionPipeline'],
'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
'object_detection_pipeline': ['ObjectDetectionPipeline'],
'face_detection_pipeline': ['FaceDetectionPipeline'],
'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
'face_recognition_pipeline': ['FaceRecognitionPipeline'],


+ 51
- 0
modelscope/pipelines/cv/object_detection_pipeline.py View File

@@ -0,0 +1,51 @@
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger


@PIPELINES.register_module(
Tasks.human_detection, module_name=Pipelines.human_detection)
@PIPELINES.register_module(
Tasks.object_detection, module_name=Pipelines.object_detection)
class ObjectDetectionPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
model: model id on modelscope hub.
"""
super().__init__(model=model, auto_collate=False, **kwargs)

def preprocess(self, input: Input) -> Dict[str, Any]:

img = LoadImage.convert_to_ndarray(input)
        img = img.astype(float)  # np.float alias is deprecated; builtin float keeps float64 behaviour
img = self.model.preprocess(img)
result = {'img': img}
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

outputs = self.model.inference(input['img'])
result = {'data': outputs}
return result

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

bboxes, scores, labels = self.model.postprocess(inputs['data'])
if bboxes is None:
return None
outputs = {
OutputKeys.SCORES: scores,
OutputKeys.LABELS: labels,
OutputKeys.BOXES: bboxes
}

return outputs
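An end-to-end usage sketch (not part of the diff): the pipeline returns a dict keyed by OutputKeys, or None when postprocessing yields no boxes; the model id and test image path come from elsewhere in this commit.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detector = pipeline(Tasks.object_detection,
                    model='damo/cv_vit_object-detection_coco')
result = detector('data/test/images/image_detection.jpg')
if result is not None:
    print(result[OutputKeys.BOXES])   # detected bounding boxes
    print(result[OutputKeys.SCORES])  # per-box confidence scores
    print(result[OutputKeys.LABELS])  # per-box class labels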

+ 1
- 0
modelscope/utils/constant.py View File

@@ -20,6 +20,7 @@ class CVTasks(object):
image_classification = 'image-classification'
image_tagging = 'image-tagging'
object_detection = 'object-detection'
human_detection = 'human-detection'
image_segmentation = 'image-segmentation'
image_editing = 'image-editing'
image_generation = 'image-generation'


+ 1
- 0
requirements/cv.txt View File

@@ -1,4 +1,5 @@
decord>=0.6.0
easydict
tf_slim
timm
torchvision

+ 56
- 0
tests/pipelines/test_object_detection.py View File

@@ -0,0 +1,56 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level


class ObjectDetectionTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_object_detection(self):
input_location = 'data/test/images/image_detection.jpg'
model_id = 'damo/cv_vit_object-detection_coco'
object_detect = pipeline(Tasks.object_detection, model=model_id)
result = object_detect(input_location)
if result:
print(result)
else:
raise ValueError('process error')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_object_detection_with_default_task(self):
input_location = 'data/test/images/image_detection.jpg'
object_detect = pipeline(Tasks.object_detection)
result = object_detect(input_location)
if result:
print(result)
else:
raise ValueError('process error')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_human_detection(self):
input_location = 'data/test/images/image_detection.jpg'
model_id = 'damo/cv_resnet18_human-detection'
human_detect = pipeline(Tasks.human_detection, model=model_id)
result = human_detect(input_location)
if result:
print(result)
else:
raise ValueError('process error')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_human_detection_with_default_task(self):
input_location = 'data/test/images/image_detection.jpg'
human_detect = pipeline(Tasks.human_detection)
result = human_detect(input_location)
if result:
print(result)
else:
raise ValueError('process error')


if __name__ == '__main__':
unittest.main()
