
[to #42322933]support video-single-object-tracking

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9781254

    * support single object tracking
master · lanjinpeng.ljp, yingda.chen · 3 years ago
commit c6db966a99
27 changed files with 1510 additions and 1 deletion
  1. +1    -0   .gitattributes
  2. +3    -0   data/test/videos/dog.avi
  3. +1    -0   modelscope/metainfo.py
  4. +1    -1   modelscope/models/cv/__init__.py
  5. +0    -0   modelscope/models/cv/video_single_object_tracking/__init__.py
  6. +0    -0   modelscope/models/cv/video_single_object_tracking/config/__init__.py
  7. +39   -0   modelscope/models/cv/video_single_object_tracking/config/ostrack.py
  8. +0    -0   modelscope/models/cv/video_single_object_tracking/models/__init__.py
  9. +0    -0   modelscope/models/cv/video_single_object_tracking/models/layers/__init__.py
  10. +54   -0   modelscope/models/cv/video_single_object_tracking/models/layers/attn.py
  11. +129  -0   modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py
  12. +141  -0   modelscope/models/cv/video_single_object_tracking/models/layers/head.py
  13. +37   -0   modelscope/models/cv/video_single_object_tracking/models/layers/patch_embed.py
  14. +0    -0   modelscope/models/cv/video_single_object_tracking/models/ostrack/__init__.py
  15. +93   -0   modelscope/models/cv/video_single_object_tracking/models/ostrack/base_backbone.py
  16. +109  -0   modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py
  17. +24   -0   modelscope/models/cv/video_single_object_tracking/models/ostrack/utils.py
  18. +343  -0   modelscope/models/cv/video_single_object_tracking/models/ostrack/vit_ce.py
  19. +0    -0   modelscope/models/cv/video_single_object_tracking/tracker/__init__.py
  20. +139  -0   modelscope/models/cv/video_single_object_tracking/tracker/ostrack.py
  21. +0    -0   modelscope/models/cv/video_single_object_tracking/utils/__init__.py
  22. +261  -0   modelscope/models/cv/video_single_object_tracking/utils/utils.py
  23. +10   -0   modelscope/outputs.py
  24. +3    -0   modelscope/pipelines/builder.py
  25. +80   -0   modelscope/pipelines/cv/video_single_object_tracking_pipeline.py
  26. +3    -0   modelscope/utils/constant.py
  27. +39   -0   tests/pipelines/test_video_single_object_tracking.py

+1  -0   .gitattributes

@@ -4,3 +4,4 @@
*.wav filter=lfs diff=lfs merge=lfs -text
*.JPEG filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.avi filter=lfs diff=lfs merge=lfs -text

+3  -0   data/test/videos/dog.avi

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:469090fb217a34a2c096cfd42c251da69dca9fcd1a3c1faae7d29183c1816c14
size 12834294

+1  -0   modelscope/metainfo.py

@@ -111,6 +111,7 @@ class Pipelines(object):
skin_retouching = 'unet-skin-retouching'
tinynas_classification = 'tinynas-classification'
crowd_counting = 'hrnet-crowd-counting'
video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'

# nlp tasks
sentence_similarity = 'sentence-similarity'


+1  -1   modelscope/models/cv/__init__.py

@@ -6,4 +6,4 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
image_portrait_enhancement, image_to_image_generation,
image_to_image_translation, object_detection,
product_retrieval_embedding, salient_detection,
- super_resolution, virual_tryon)
+ super_resolution, video_single_object_tracking, virual_tryon)

+0  -0   modelscope/models/cv/video_single_object_tracking/__init__.py


+0  -0   modelscope/models/cv/video_single_object_tracking/config/__init__.py


+39  -0   modelscope/models/cv/video_single_object_tracking/config/ostrack.py

@@ -0,0 +1,39 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
from easydict import EasyDict as edict

cfg = edict()

# MODEL
cfg.MODEL = edict()

# MODEL.BACKBONE
cfg.MODEL.BACKBONE = edict()
cfg.MODEL.BACKBONE.TYPE = 'vit_base_patch16_224_ce'
cfg.MODEL.BACKBONE.STRIDE = 16
cfg.MODEL.BACKBONE.CAT_MODE = 'direct'
cfg.MODEL.BACKBONE.DROP_PATH_RATE = 0.1
cfg.MODEL.BACKBONE.CE_LOC = [3, 6, 9]
cfg.MODEL.BACKBONE.CE_KEEP_RATIO = [0.7, 0.7, 0.7]
cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE = 'CTR_POINT'

# MODEL.HEAD
cfg.MODEL.HEAD = edict()
cfg.MODEL.HEAD.TYPE = 'CENTER'
cfg.MODEL.HEAD.NUM_CHANNELS = 256

# DATA
cfg.DATA = edict()
cfg.DATA.MEAN = [0.485, 0.456, 0.406]
cfg.DATA.STD = [0.229, 0.224, 0.225]
cfg.DATA.SEARCH = edict()
cfg.DATA.SEARCH.SIZE = 384
cfg.DATA.TEMPLATE = edict()
cfg.DATA.TEMPLATE.SIZE = 192

# TEST
cfg.TEST = edict()
cfg.TEST.TEMPLATE_FACTOR = 2.0
cfg.TEST.TEMPLATE_SIZE = 192
cfg.TEST.SEARCH_FACTOR = 5.0
cfg.TEST.SEARCH_SIZE = 384
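The values above fix the crop geometry used throughout the tracker. A minimal sketch (not part of the commit) of the quantities they imply, assuming the 384/192/16 settings above:

import math

SEARCH_SIZE, TEMPLATE_SIZE, STRIDE = 384, 192, 16
SEARCH_FACTOR, TEMPLATE_FACTOR = 5.0, 2.0

# token grids fed to the ViT backbone
search_tokens = (SEARCH_SIZE // STRIDE) ** 2      # 24 * 24 = 576
template_tokens = (TEMPLATE_SIZE // STRIDE) ** 2  # 12 * 12 = 144

# side length of the square crop taken around a w x h target box,
# mirroring sample_target() in utils/utils.py, before resizing
def crop_side(w, h, factor):
    return math.ceil(math.sqrt(w * h) * factor)

print(search_tokens, template_tokens)        # 576 144
print(crop_side(100, 106, SEARCH_FACTOR))    # 515 px search crop for a 100x106 target
print(crop_side(100, 106, TEMPLATE_FACTOR))  # 206 px template crop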

+0  -0   modelscope/models/cv/video_single_object_tracking/models/__init__.py


+0  -0   modelscope/models/cv/video_single_object_tracking/models/layers/__init__.py


+54  -0   modelscope/models/cv/video_single_object_tracking/models/layers/attn.py

@@ -0,0 +1,54 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
import torch.nn as nn


class Attention(nn.Module):

def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.,
rpe=False,
z_size=7,
x_size=14):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x, mask=None, return_attention=False):
# x: B, N, C
# mask: [B, N, ] torch.bool
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(
0) # make torchscript happy (cannot use tensor as tuple)

attn = (q @ k.transpose(-2, -1)) * self.scale

if mask is not None:
attn = attn.masked_fill(
mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)

attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)

if return_attention:
return x, attn
else:
return x
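A hypothetical usage sketch (not part of the commit), assuming the module above is importable; it only illustrates the expected shapes and the boolean mask convention (True means the position is masked out):

import torch
from modelscope.models.cv.video_single_object_tracking.models.layers.attn import Attention

attn = Attention(dim=768, num_heads=12, qkv_bias=True)
x = torch.randn(2, 144 + 576, 768)                  # template + search tokens
mask = torch.zeros(2, 144 + 576, dtype=torch.bool)  # nothing masked
y, a = attn(x, mask=mask, return_attention=True)
print(y.shape)  # torch.Size([2, 720, 768])
print(a.shape)  # torch.Size([2, 12, 720, 720]) per-head attention weights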

+129  -0   modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py

@@ -0,0 +1,129 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
import math

import torch
import torch.nn as nn
from timm.models.layers import DropPath, Mlp

from .attn import Attention


def candidate_elimination(attn: torch.Tensor, tokens: torch.Tensor,
lens_t: int, keep_ratio: float,
global_index: torch.Tensor,
box_mask_z: torch.Tensor):
"""
Eliminate potential background candidates for computation reduction and noise cancellation.
Args:
attn (torch.Tensor): [B, num_heads, L_t + L_s, L_t + L_s], attention weights
tokens (torch.Tensor): [B, L_t + L_s, C], template and search region tokens
lens_t (int): length of template
keep_ratio (float): keep ratio of search region tokens (candidates)
global_index (torch.Tensor): global index of search region tokens
box_mask_z (torch.Tensor): template mask used to accumulate attention weights

Returns:
tokens_new (torch.Tensor): tokens after candidate elimination
keep_index (torch.Tensor): indices of kept search region tokens
removed_index (torch.Tensor): indices of removed search region tokens
"""
lens_s = attn.shape[-1] - lens_t
bs, hn, _, _ = attn.shape

lens_keep = math.ceil(keep_ratio * lens_s)
if lens_keep == lens_s:
return tokens, global_index, None

attn_t = attn[:, :, :lens_t, lens_t:]

if box_mask_z is not None:
box_mask_z = box_mask_z.unsqueeze(1).unsqueeze(-1).expand(
-1, attn_t.shape[1], -1, attn_t.shape[-1])
attn_t = attn_t[box_mask_z]
attn_t = attn_t.view(bs, hn, -1, lens_s)
attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s
else:
attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s

# use sort instead of topk, due to the speed issue
# https://github.com/pytorch/pytorch/issues/22812
sorted_attn, indices = torch.sort(attn_t, dim=1, descending=True)

_, topk_idx = sorted_attn[:, :lens_keep], indices[:, :lens_keep]
_, non_topk_idx = sorted_attn[:, lens_keep:], indices[:, lens_keep:]
keep_index = global_index.gather(dim=1, index=topk_idx)
removed_index = global_index.gather(dim=1, index=non_topk_idx)

# separate template and search tokens
tokens_t = tokens[:, :lens_t]
tokens_s = tokens[:, lens_t:]

# obtain the attentive and inattentive tokens
B, L, C = tokens_s.shape
attentive_tokens = tokens_s.gather(
dim=1, index=topk_idx.unsqueeze(-1).expand(B, -1, C))

# concatenate these tokens
tokens_new = torch.cat([tokens_t, attentive_tokens], dim=1)

return tokens_new, keep_index, removed_index


class CEBlock(nn.Module):

def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
keep_ratio_search=1.0,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)

self.keep_ratio_search = keep_ratio_search

def forward(self,
x,
global_index_template,
global_index_search,
mask=None,
ce_template_mask=None,
keep_ratio_search=None):
x_attn, attn = self.attn(self.norm1(x), mask, True)
x = x + self.drop_path(x_attn)
lens_t = global_index_template.shape[1]

removed_index_search = None
if self.keep_ratio_search < 1 and (keep_ratio_search is None
or keep_ratio_search < 1):
keep_ratio_search = self.keep_ratio_search if keep_ratio_search is None else keep_ratio_search
x, global_index_search, removed_index_search = candidate_elimination(
attn, x, lens_t, keep_ratio_search, global_index_search,
ce_template_mask)

x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, global_index_template, global_index_search, removed_index_search, attn
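To see what candidate_elimination does to the sequence length, here is a small sketch with random inputs (not from the commit; it assumes the PR module is importable). With 144 template and 576 search tokens and keep_ratio 0.7, ceil(0.7 * 576) = 404 search tokens survive:

import torch
from modelscope.models.cv.video_single_object_tracking.models.layers.attn_blocks import \
    candidate_elimination

B, heads, lens_t, lens_s, C = 2, 12, 144, 576, 768
attn = torch.rand(B, heads, lens_t + lens_s, lens_t + lens_s)
tokens = torch.randn(B, lens_t + lens_s, C)
global_index = torch.arange(lens_s, dtype=torch.float).repeat(B, 1)

tokens_new, keep_idx, removed_idx = candidate_elimination(
    attn, tokens, lens_t, 0.7, global_index, box_mask_z=None)
print(tokens_new.shape)                   # torch.Size([2, 548, 768])  144 + 404
print(keep_idx.shape, removed_idx.shape)  # torch.Size([2, 404]) torch.Size([2, 172])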

+141  -0   modelscope/models/cv/video_single_object_tracking/models/layers/head.py

@@ -0,0 +1,141 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
import torch
import torch.nn as nn


def conv(in_planes,
out_planes,
kernel_size=3,
stride=1,
padding=1,
dilation=1):
return nn.Sequential(
nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=True), nn.BatchNorm2d(out_planes), nn.ReLU(inplace=True))


class CenterPredictor(nn.Module):

def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16):
super(CenterPredictor, self).__init__()
self.feat_sz = feat_sz
self.stride = stride
self.img_sz = self.feat_sz * self.stride

# corner predict
self.conv1_ctr = conv(inplanes, channel)
self.conv2_ctr = conv(channel, channel // 2)
self.conv3_ctr = conv(channel // 2, channel // 4)
self.conv4_ctr = conv(channel // 4, channel // 8)
self.conv5_ctr = nn.Conv2d(channel // 8, 1, kernel_size=1)

# offset regress
self.conv1_offset = conv(inplanes, channel)
self.conv2_offset = conv(channel, channel // 2)
self.conv3_offset = conv(channel // 2, channel // 4)
self.conv4_offset = conv(channel // 4, channel // 8)
self.conv5_offset = nn.Conv2d(channel // 8, 2, kernel_size=1)

# size regress
self.conv1_size = conv(inplanes, channel)
self.conv2_size = conv(channel, channel // 2)
self.conv3_size = conv(channel // 2, channel // 4)
self.conv4_size = conv(channel // 4, channel // 8)
self.conv5_size = nn.Conv2d(channel // 8, 2, kernel_size=1)

for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)

def forward(self, x, gt_score_map=None):
""" Forward pass with input x. """
score_map_ctr, size_map, offset_map = self.get_score_map(x)

# assert gt_score_map is None
if gt_score_map is None:
bbox = self.cal_bbox(score_map_ctr, size_map, offset_map)
else:
bbox = self.cal_bbox(
gt_score_map.unsqueeze(1), size_map, offset_map)

return score_map_ctr, bbox, size_map, offset_map

def cal_bbox(self,
score_map_ctr,
size_map,
offset_map,
return_score=False):
max_score, idx = torch.max(
score_map_ctr.flatten(1), dim=1, keepdim=True)
idx_y = idx // self.feat_sz
idx_x = idx % self.feat_sz

idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1)
size = size_map.flatten(2).gather(dim=2, index=idx)
offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1)

# cx, cy, w, h
bbox = torch.cat(
[(idx_x.to(torch.float) + offset[:, :1]) / self.feat_sz,
(idx_y.to(torch.float) + offset[:, 1:]) / self.feat_sz,
size.squeeze(-1)],
dim=1)

if return_score:
return bbox, max_score
return bbox

def get_score_map(self, x):

def _sigmoid(x):
y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
return y

# ctr branch
x_ctr1 = self.conv1_ctr(x)
x_ctr2 = self.conv2_ctr(x_ctr1)
x_ctr3 = self.conv3_ctr(x_ctr2)
x_ctr4 = self.conv4_ctr(x_ctr3)
score_map_ctr = self.conv5_ctr(x_ctr4)

# offset branch
x_offset1 = self.conv1_offset(x)
x_offset2 = self.conv2_offset(x_offset1)
x_offset3 = self.conv3_offset(x_offset2)
x_offset4 = self.conv4_offset(x_offset3)
score_map_offset = self.conv5_offset(x_offset4)

# size branch
x_size1 = self.conv1_size(x)
x_size2 = self.conv2_size(x_size1)
x_size3 = self.conv3_size(x_size2)
x_size4 = self.conv4_size(x_size3)
score_map_size = self.conv5_size(x_size4)
return _sigmoid(score_map_ctr), _sigmoid(
score_map_size), score_map_offset


def build_box_head(cfg, hidden_dim):
stride = cfg.MODEL.BACKBONE.STRIDE

if cfg.MODEL.HEAD.TYPE == 'CENTER':
in_channel = hidden_dim
out_channel = cfg.MODEL.HEAD.NUM_CHANNELS
feat_sz = int(cfg.DATA.SEARCH.SIZE / stride)
center_head = CenterPredictor(
inplanes=in_channel,
channel=out_channel,
feat_sz=feat_sz,
stride=stride)
return center_head
else:
raise ValueError('HEAD TYPE %s is not supported.'
% cfg.MODEL.HEAD.TYPE)
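A hedged usage sketch (not part of the commit): with SEARCH.SIZE = 384 and STRIDE = 16 the head operates on a 24 x 24 feature map and cal_bbox decodes a (cx, cy, w, h) box normalized to the search crop:

import torch
from modelscope.models.cv.video_single_object_tracking.models.layers.head import \
    CenterPredictor

head = CenterPredictor(inplanes=768, channel=256, feat_sz=24, stride=16).eval()
feat = torch.randn(1, 768, 24, 24)          # (B, C, H, W) search-region features
score_map, bbox, size_map, offset_map = head(feat)
print(score_map.shape)  # torch.Size([1, 1, 24, 24]) center heatmap
print(bbox.shape)       # torch.Size([1, 4]) normalized (cx, cy, w, h)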

+37  -0   modelscope/models/cv/video_single_object_tracking/models/layers/patch_embed.py

@@ -0,0 +1,37 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
import torch.nn as nn
from timm.models.layers import to_2tuple


class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""

def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten

self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
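For reference, a small sketch (not in the commit) of how the 384 x 384 search crop becomes 576 patch tokens:

import torch
from modelscope.models.cv.video_single_object_tracking.models.layers.patch_embed import \
    PatchEmbed

embed = PatchEmbed(img_size=384, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(1, 3, 384, 384)
print(embed(x).shape)  # torch.Size([1, 576, 768])  ->  (384 / 16) ** 2 patches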

+0  -0   modelscope/models/cv/video_single_object_tracking/models/ostrack/__init__.py


+93  -0   modelscope/models/cv/video_single_object_tracking/models/ostrack/base_backbone.py

@@ -0,0 +1,93 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
import torch.nn as nn
from timm.models.layers import to_2tuple

from modelscope.models.cv.video_single_object_tracking.models.layers.patch_embed import \
PatchEmbed


class BaseBackbone(nn.Module):

def __init__(self):
super().__init__()

# for original ViT
self.pos_embed = None
self.img_size = [224, 224]
self.patch_size = 16
self.embed_dim = 384

self.cat_mode = 'direct'

self.pos_embed_z = None
self.pos_embed_x = None

self.template_segment_pos_embed = None
self.search_segment_pos_embed = None

self.return_stage = [2, 5, 8, 11]

def finetune_track(self, cfg, patch_start_index=1):

search_size = to_2tuple(cfg.DATA.SEARCH.SIZE)
template_size = to_2tuple(cfg.DATA.TEMPLATE.SIZE)
new_patch_size = cfg.MODEL.BACKBONE.STRIDE

self.cat_mode = cfg.MODEL.BACKBONE.CAT_MODE

# resize patch embedding
if new_patch_size != self.patch_size:
print(
'Inconsistent Patch Size With The Pretrained Weights, Interpolate The Weight!'
)
old_patch_embed = {}
for name, param in self.patch_embed.named_parameters():
if 'weight' in name:
param = nn.functional.interpolate(
param,
size=(new_patch_size, new_patch_size),
mode='bicubic',
align_corners=False)
param = nn.Parameter(param)
old_patch_embed[name] = param
self.patch_embed = PatchEmbed(
img_size=self.img_size,
patch_size=new_patch_size,
in_chans=3,
embed_dim=self.embed_dim)
self.patch_embed.proj.bias = old_patch_embed['proj.bias']
self.patch_embed.proj.weight = old_patch_embed['proj.weight']

# for patch embedding
patch_pos_embed = self.pos_embed[:, patch_start_index:, :]
patch_pos_embed = patch_pos_embed.transpose(1, 2)
B, E, Q = patch_pos_embed.shape
P_H, P_W = self.img_size[0] // self.patch_size, self.img_size[
1] // self.patch_size
patch_pos_embed = patch_pos_embed.view(B, E, P_H, P_W)

# for search region
H, W = search_size
new_P_H, new_P_W = H // new_patch_size, W // new_patch_size
search_patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_P_H, new_P_W),
mode='bicubic',
align_corners=False)
search_patch_pos_embed = search_patch_pos_embed.flatten(2).transpose(
1, 2)

# for template region
H, W = template_size
new_P_H, new_P_W = H // new_patch_size, W // new_patch_size
template_patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_P_H, new_P_W),
mode='bicubic',
align_corners=False)
template_patch_pos_embed = template_patch_pos_embed.flatten(
2).transpose(1, 2)

self.pos_embed_z = nn.Parameter(template_patch_pos_embed)
self.pos_embed_x = nn.Parameter(search_patch_pos_embed)
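finetune_track resizes the pretrained ViT-B/16 positional embedding (a 14 x 14 grid for 224 x 224 inputs) to the 24 x 24 search grid and the 12 x 12 template grid. A standalone sketch of that interpolation step (assumed numbers, not the commit's code):

import torch
import torch.nn.functional as F

pos = torch.randn(1, 14 * 14, 768)               # pretrained patch pos-embed, 224 / 16 = 14
grid = pos.transpose(1, 2).view(1, 768, 14, 14)
pos_x = F.interpolate(grid, size=(24, 24), mode='bicubic', align_corners=False)  # search 384/16
pos_z = F.interpolate(grid, size=(12, 12), mode='bicubic', align_corners=False)  # template 192/16
print(pos_x.flatten(2).transpose(1, 2).shape)    # torch.Size([1, 576, 768])
print(pos_z.flatten(2).transpose(1, 2).shape)    # torch.Size([1, 144, 768])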

+109  -0   modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py

@@ -0,0 +1,109 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
import torch
from torch import nn

from modelscope.models.cv.video_single_object_tracking.models.layers.head import \
build_box_head
from .vit_ce import vit_base_patch16_224_ce


class OSTrack(nn.Module):
""" This is the base class for OSTrack """

def __init__(self,
transformer,
box_head,
aux_loss=False,
head_type='CORNER'):
""" Initializes the model.
Parameters:
transformer: torch module of the transformer architecture.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.backbone = transformer
self.box_head = box_head

self.aux_loss = aux_loss
self.head_type = head_type
if head_type == 'CORNER' or head_type == 'CENTER':
self.feat_sz_s = int(box_head.feat_sz)
self.feat_len_s = int(box_head.feat_sz**2)

def forward(
self,
template: torch.Tensor,
search: torch.Tensor,
ce_template_mask=None,
ce_keep_rate=None,
):
x, aux_dict = self.backbone(
z=template,
x=search,
ce_template_mask=ce_template_mask,
ce_keep_rate=ce_keep_rate,
)

# Forward head
feat_last = x
if isinstance(x, list):
feat_last = x[-1]
out = self.forward_head(feat_last, None)

out.update(aux_dict)
out['backbone_feat'] = x
return out

def forward_head(self, cat_feature, gt_score_map=None):
"""
cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C)
"""
enc_opt = cat_feature[:, -self.
feat_len_s:] # encoder output for the search region (B, HW, C)
opt = (enc_opt.unsqueeze(-1)).permute((0, 3, 2, 1)).contiguous()
bs, Nq, C, HW = opt.size()
opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s)

if self.head_type == 'CENTER':
# run the center head
score_map_ctr, bbox, size_map, offset_map = self.box_head(
opt_feat, gt_score_map)
outputs_coord = bbox
outputs_coord_new = outputs_coord.view(bs, Nq, 4)
out = {
'pred_boxes': outputs_coord_new,
'score_map': score_map_ctr,
'size_map': size_map,
'offset_map': offset_map
}
return out
else:
raise NotImplementedError


def build_ostrack(cfg):
if cfg.MODEL.BACKBONE.TYPE == 'vit_base_patch16_224_ce':
backbone = vit_base_patch16_224_ce(
False,
drop_path_rate=cfg.MODEL.BACKBONE.DROP_PATH_RATE,
ce_loc=cfg.MODEL.BACKBONE.CE_LOC,
ce_keep_ratio=cfg.MODEL.BACKBONE.CE_KEEP_RATIO,
)
hidden_dim = backbone.embed_dim
patch_start_index = 1
else:
raise NotImplementedError

backbone.finetune_track(cfg=cfg, patch_start_index=patch_start_index)

box_head = build_box_head(cfg, hidden_dim)

model = OSTrack(
backbone,
box_head,
aux_loss=False,
head_type=cfg.MODEL.HEAD.TYPE,
)

return model
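A hedged end-to-end sketch (not part of the commit): build the untrained network from the config in this PR and run one dummy forward pass on CPU. It assumes torch, timm and easydict are installed and the PR modules are importable.

import torch
from modelscope.models.cv.video_single_object_tracking.config.ostrack import cfg
from modelscope.models.cv.video_single_object_tracking.models.ostrack.ostrack import \
    build_ostrack

model = build_ostrack(cfg).eval()
template = torch.randn(1, 3, cfg.DATA.TEMPLATE.SIZE, cfg.DATA.TEMPLATE.SIZE)  # 192 x 192
search = torch.randn(1, 3, cfg.DATA.SEARCH.SIZE, cfg.DATA.SEARCH.SIZE)        # 384 x 384
with torch.no_grad():
    out = model(template=template, search=search)
print(out['pred_boxes'].shape)  # torch.Size([1, 1, 4]) normalized (cx, cy, w, h)
print(out['score_map'].shape)   # torch.Size([1, 1, 24, 24])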

+24  -0   modelscope/models/cv/video_single_object_tracking/models/ostrack/utils.py

@@ -0,0 +1,24 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
import torch


def combine_tokens(template_tokens,
search_tokens,
mode='direct',
return_res=False):
if mode == 'direct':
merged_feature = torch.cat((template_tokens, search_tokens), dim=1)
else:
raise NotImplementedError

return merged_feature


def recover_tokens(merged_tokens, mode='direct'):
if mode == 'direct':
recovered_tokens = merged_tokens
else:
raise NotImplementedError

return recovered_tokens

+343  -0   modelscope/models/cv/video_single_object_tracking/models/ostrack/vit_ce.py

@@ -0,0 +1,343 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
from functools import partial

import torch
import torch.nn as nn
from timm.models.layers import DropPath, Mlp, to_2tuple

from modelscope.models.cv.video_single_object_tracking.models.layers.attn_blocks import \
CEBlock
from modelscope.models.cv.video_single_object_tracking.models.layers.patch_embed import \
PatchEmbed
from .base_backbone import BaseBackbone
from .utils import combine_tokens, recover_tokens


class Attention(nn.Module):

def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)


class Block(nn.Module):

def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)


class VisionTransformer(BaseBackbone):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
- https://arxiv.org/abs/2012.12877
"""

def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
distilled=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
distilled (bool): model includes a distillation token and head as in DeiT models
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
"""
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU

self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = None
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer) for i in range(depth)
])
self.norm = norm_layer(embed_dim)


class VisionTransformerCE(VisionTransformer):
""" Vision Transformer with candidate elimination (CE) module

A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929

Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
- https://arxiv.org/abs/2012.12877
"""

def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
distilled=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
ce_loc=None,
ce_keep_ratio=None):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
distilled (bool): model includes a distillation token and head as in DeiT models
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
"""
super().__init__()
if isinstance(img_size, tuple):
self.img_size = img_size
else:
self.img_size = to_2tuple(img_size)
self.patch_size = patch_size
self.in_chans = in_chans

self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU

self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(
1, 1, embed_dim)) if distilled else None
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
blocks = []
ce_index = 0
self.ce_loc = ce_loc
for i in range(depth):
ce_keep_ratio_i = 1.0
if ce_loc is not None and i in ce_loc:
ce_keep_ratio_i = ce_keep_ratio[ce_index]
ce_index += 1

blocks.append(
CEBlock(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
keep_ratio_search=ce_keep_ratio_i))

self.blocks = nn.Sequential(*blocks)
self.norm = norm_layer(embed_dim)

def forward_features(
self,
z,
x,
mask_x=None,
ce_template_mask=None,
ce_keep_rate=None,
):
B = x.shape[0]

x = self.patch_embed(x)
z = self.patch_embed(z)

z += self.pos_embed_z
x += self.pos_embed_x

x = combine_tokens(z, x, mode=self.cat_mode)

x = self.pos_drop(x)

lens_z = self.pos_embed_z.shape[1]
lens_x = self.pos_embed_x.shape[1]

global_index_t = torch.linspace(0, lens_z - 1, lens_z).to(x.device)
global_index_t = global_index_t.repeat(B, 1)

global_index_s = torch.linspace(0, lens_x - 1, lens_x).to(x.device)
global_index_s = global_index_s.repeat(B, 1)
removed_indexes_s = []
for i, blk in enumerate(self.blocks):
x, global_index_t, global_index_s, removed_index_s, attn = \
blk(x, global_index_t, global_index_s, mask_x, ce_template_mask, ce_keep_rate)

if self.ce_loc is not None and i in self.ce_loc:
removed_indexes_s.append(removed_index_s)

x = self.norm(x)
lens_x_new = global_index_s.shape[1]
lens_z_new = global_index_t.shape[1]

z = x[:, :lens_z_new]
x = x[:, lens_z_new:]

if removed_indexes_s and removed_indexes_s[0] is not None:
removed_indexes_cat = torch.cat(removed_indexes_s, dim=1)

pruned_lens_x = lens_x - lens_x_new
pad_x = torch.zeros([B, pruned_lens_x, x.shape[2]],
device=x.device)
x = torch.cat([x, pad_x], dim=1)
index_all = torch.cat([global_index_s, removed_indexes_cat], dim=1)
# recover original token order
C = x.shape[-1]
x = torch.zeros_like(x).scatter_(
dim=1,
index=index_all.unsqueeze(-1).expand(B, -1, C).to(torch.int64),
src=x)

x = recover_tokens(x, mode=self.cat_mode)

# re-concatenate with the template, which may be further used by other modules
x = torch.cat([z, x], dim=1)

aux_dict = {
'attn': attn,
'removed_indexes_s': removed_indexes_s, # used for visualization
}

return x, aux_dict

def forward(self, z, x, ce_template_mask=None, ce_keep_rate=None):

x, aux_dict = self.forward_features(
z,
x,
ce_template_mask=ce_template_mask,
ce_keep_rate=ce_keep_rate,
)

return x, aux_dict


def _create_vision_transformer(pretrained=False, **kwargs):
model = VisionTransformerCE(**kwargs)
return model


def vit_base_patch16_224_ce(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(pretrained=pretrained, **model_kwargs)
return model
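The only difference from a plain ViT-B/16 forward pass is the candidate elimination applied in blocks 3, 6 and 9 (cfg.MODEL.BACKBONE.CE_LOC) with keep ratio 0.7. A rough sketch of the arithmetic only (not the commit's code), showing how the search-token count shrinks while the 144 template tokens are never pruned:

import math

search_tokens = 576            # (384 / 16) ** 2
for block in range(12):
    if block in (3, 6, 9):     # cfg.MODEL.BACKBONE.CE_LOC, keep ratio 0.7 each time
        search_tokens = math.ceil(0.7 * search_tokens)
print(search_tokens)           # 199 search tokens remain after the last CE block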

+0  -0   modelscope/models/cv/video_single_object_tracking/tracker/__init__.py


+139  -0   modelscope/models/cv/video_single_object_tracking/tracker/ostrack.py

@@ -0,0 +1,139 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
import torch

from modelscope.models.cv.video_single_object_tracking.config.ostrack import \
cfg
from modelscope.models.cv.video_single_object_tracking.models.ostrack.ostrack import \
build_ostrack
from modelscope.models.cv.video_single_object_tracking.utils.utils import (
Preprocessor, clip_box, generate_mask_cond, hann2d, sample_target,
transform_image_to_crop)


class OSTrack():

def __init__(self, ckpt_path, device):
network = build_ostrack(cfg)
network.load_state_dict(
torch.load(ckpt_path, map_location='cpu')['net'], strict=True)
self.cfg = cfg
if device.type == 'cuda':
self.network = network.to(device)
else:
self.network = network
self.network.eval()
self.preprocessor = Preprocessor(device)
self.state = None

self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE
# motion constraint
if device.type == 'cuda':
self.output_window = hann2d(
torch.tensor([self.feat_sz, self.feat_sz]).long(),
centered=True).to(device)
else:
self.output_window = hann2d(
torch.tensor([self.feat_sz, self.feat_sz]).long(),
centered=True)
self.frame_id = 0
# for save boxes from all queries
self.z_dict1 = {}

def initialize(self, image, info: dict):
# forward the template once
z_patch_arr, resize_factor, z_amask_arr = sample_target(
image,
info['init_bbox'],
self.cfg.TEST.TEMPLATE_FACTOR,
output_sz=self.cfg.TEST.TEMPLATE_SIZE)
self.z_patch_arr = z_patch_arr
template = self.preprocessor.process(z_patch_arr, z_amask_arr)
with torch.no_grad():
self.z_dict1 = template

self.box_mask_z = None
if self.cfg.MODEL.BACKBONE.CE_LOC:
template_bbox = self.transform_bbox_to_crop(
info['init_bbox'], resize_factor,
template.tensors.device).squeeze(1)
self.box_mask_z = generate_mask_cond(self.cfg, 1,
template.tensors.device,
template_bbox)

# save states
self.state = info['init_bbox']
self.frame_id = 0

def track(self, image, info: dict = None):
H, W, _ = image.shape
self.frame_id += 1
x_patch_arr, resize_factor, x_amask_arr = sample_target(
image,
self.state,
self.cfg.TEST.SEARCH_FACTOR,
output_sz=self.cfg.TEST.SEARCH_SIZE) # (x1, y1, w, h)
search = self.preprocessor.process(x_patch_arr, x_amask_arr)

with torch.no_grad():
x_dict = search
# merge the template and the search
# run the transformer
out_dict = self.network.forward(
template=self.z_dict1.tensors,
search=x_dict.tensors,
ce_template_mask=self.box_mask_z)

# add hann windows
pred_score_map = out_dict['score_map']
response = self.output_window * pred_score_map
pred_boxes = self.network.box_head.cal_bbox(response,
out_dict['size_map'],
out_dict['offset_map'])
pred_boxes = pred_boxes.view(-1, 4)
# Baseline: Take the mean of all pred boxes as the final result
pred_box = (pred_boxes.mean(dim=0) * self.cfg.TEST.SEARCH_SIZE
/ resize_factor).tolist() # (cx, cy, w, h) [0,1]
# get the final box result
self.state = clip_box(
self.map_box_back(pred_box, resize_factor), H, W, margin=10)

x1, y1, w, h = self.state
x2 = x1 + w
y2 = y1 + h
return {'target_bbox': [x1, y1, x2, y2]}

def map_box_back(self, pred_box: list, resize_factor: float):
cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[
1] + 0.5 * self.state[3]
cx, cy, w, h = pred_box
half_side = 0.5 * self.cfg.TEST.SEARCH_SIZE / resize_factor
cx_real = cx + (cx_prev - half_side)
cy_real = cy + (cy_prev - half_side)
return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

def transform_bbox_to_crop(self,
box_in,
resize_factor,
device,
box_extract=None,
crop_type='template'):
if crop_type == 'template':
crop_sz = torch.Tensor(
[self.cfg.TEST.TEMPLATE_SIZE, self.cfg.TEST.TEMPLATE_SIZE])
elif crop_type == 'search':
crop_sz = torch.Tensor(
[self.cfg.TEST.SEARCH_SIZE, self.cfg.TEST.SEARCH_SIZE])
else:
raise NotImplementedError

box_in = torch.tensor(box_in)
if box_extract is None:
box_extract = box_in
else:
box_extract = torch.tensor(box_extract)
template_bbox = transform_image_to_crop(
box_in, box_extract, resize_factor, crop_sz, normalize=True)
template_bbox = template_bbox.view(1, 1, 4).to(device)

return template_bbox
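A hedged usage sketch of the tracker class on its own (not part of the commit). The checkpoint path and video path are placeholders you must supply yourself; as load_state_dict above assumes, the checkpoint is expected to hold its weights under the 'net' key. Note that initialize() takes (x, y, w, h) while track() returns (x1, y1, x2, y2):

import cv2
import torch
from modelscope.models.cv.video_single_object_tracking.tracker.ostrack import OSTrack

tracker = OSTrack(ckpt_path='pytorch_model.bin', device=torch.device('cpu'))  # placeholder path
cap = cv2.VideoCapture('data/test/videos/dog.avi')
ok, frame = cap.read()
tracker.initialize(frame, {'init_bbox': [414, 343, 100, 106]})  # (x, y, w, h)
while True:
    ok, frame = cap.read()
    if not ok:
        break
    print(tracker.track(frame)['target_bbox'])                  # (x1, y1, x2, y2)
cap.release()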

+0  -0   modelscope/models/cv/video_single_object_tracking/utils/__init__.py


+261  -0   modelscope/models/cv/video_single_object_tracking/utils/utils.py

@@ -0,0 +1,261 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
import math
from typing import Optional

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor


def hann1d(sz: int, centered=True) -> torch.Tensor:
"""1D cosine window."""
if centered:
return 0.5 * (1 - torch.cos(
(2 * math.pi / (sz + 1)) * torch.arange(1, sz + 1).float()))
w = 0.5 * (1 + torch.cos(
(2 * math.pi / (sz + 2)) * torch.arange(0, sz // 2 + 1).float()))
return torch.cat([w, w[1:sz - sz // 2].flip((0, ))])


def hann2d(sz: torch.Tensor, centered=True) -> torch.Tensor:
"""2D cosine window."""
return hann1d(sz[0].item(), centered).reshape(1, 1, -1, 1) * hann1d(
sz[1].item(), centered).reshape(1, 1, 1, -1)


class NestedTensor(object):

def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask


class Preprocessor(object):

def __init__(self, device: str):
self.device = device
self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1))
self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1))
if 'cuda' == self.device.type:
self.mean = self.mean.to(self.device)
self.std = self.std.to(self.device)

def process(self, img_arr: np.ndarray, amask_arr: np.ndarray):
# Deal with the image patch
if 'cuda' == self.device.type:
img_tensor = torch.tensor(img_arr).to(self.device).float().permute(
(2, 0, 1)).unsqueeze(dim=0)
else:
img_tensor = torch.tensor(img_arr).float().permute(
(2, 0, 1)).unsqueeze(dim=0)
img_tensor_norm = (
(img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W)

# Deal with the attention mask
if 'cuda' == self.device.type:
amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).to(
self.device).unsqueeze(dim=0) # (1,H,W)
else:
amask_tensor = torch.from_numpy(amask_arr).to(
torch.bool).unsqueeze(dim=0) # (1,H,W)
return NestedTensor(img_tensor_norm, amask_tensor)


def clip_box(box: list, H, W, margin=0):
x1, y1, w, h = box
x2, y2 = x1 + w, y1 + h
x1 = min(max(0, x1), W - margin)
x2 = min(max(margin, x2), W)
y1 = min(max(0, y1), H - margin)
y2 = min(max(margin, y2), H)
w = max(margin, x2 - x1)
h = max(margin, y2 - y1)
if isinstance(x1, torch.Tensor):
x1 = x1.item()
y1 = y1.item()
w = w.item()
h = h.item()
return [x1, y1, w, h]


def generate_mask_cond(cfg, bs, device, gt_bbox):
template_size = cfg.DATA.TEMPLATE.SIZE
stride = cfg.MODEL.BACKBONE.STRIDE
template_feat_size = template_size // stride

if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT':
if template_feat_size == 8:
index = slice(3, 4)
elif template_feat_size == 12:
index = slice(5, 6)
elif template_feat_size == 7:
index = slice(3, 4)
elif template_feat_size == 14:
index = slice(6, 7)
else:
raise NotImplementedError
box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size],
device=device)
box_mask_z[:, index, index] = 1
box_mask_z = box_mask_z.flatten(1).to(torch.bool)
else:
raise NotImplementedError

return box_mask_z


def sample_target(im,
target_bb,
search_area_factor,
output_sz=None,
mask=None):
""" Extracts a square crop centered at target_bb box, of area search_area_factor^2 times target_bb area

args:
im - cv image
target_bb - target box [x, y, w, h]
search_area_factor - Ratio of crop size to target size
output_sz - (float) Size to which the extracted crop is resized (always square). If None, no resizing is done.

returns:
cv image - extracted crop
float - the factor by which the crop has been resized to make the crop size equal output_size
"""
if not isinstance(target_bb, list):
x, y, w, h = target_bb.tolist()
else:
x, y, w, h = target_bb
# Crop image
crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)

if crop_sz < 1:
raise Exception('Too small bounding box.')

x1 = round(x + 0.5 * w - crop_sz * 0.5)
x2 = x1 + crop_sz

y1 = round(y + 0.5 * h - crop_sz * 0.5)
y2 = y1 + crop_sz

x1_pad = max(0, -x1)
x2_pad = max(x2 - im.shape[1] + 1, 0)

y1_pad = max(0, -y1)
y2_pad = max(y2 - im.shape[0] + 1, 0)

# Crop target
im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
if mask is not None:
mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

# Pad
im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad,
x2_pad, cv2.BORDER_CONSTANT)
# deal with attention mask
H, W, _ = im_crop_padded.shape
att_mask = np.ones((H, W))
end_x, end_y = -x2_pad, -y2_pad
if y2_pad == 0:
end_y = None
if x2_pad == 0:
end_x = None
att_mask[y1_pad:end_y, x1_pad:end_x] = 0
if mask is not None:
mask_crop_padded = F.pad(
mask_crop,
pad=(x1_pad, x2_pad, y1_pad, y2_pad),
mode='constant',
value=0)

if output_sz is not None:
resize_factor = output_sz / crop_sz
im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz))
att_mask = cv2.resize(att_mask,
(output_sz, output_sz)).astype(np.bool_)
if mask is None:
return im_crop_padded, resize_factor, att_mask
mask_crop_padded = \
F.interpolate(mask_crop_padded[None, None], (output_sz, output_sz),
mode='bilinear', align_corners=False)[0, 0]
return im_crop_padded, resize_factor, att_mask, mask_crop_padded

else:
if mask is None:
return im_crop_padded, att_mask.astype(np.bool_), 1.0
return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded


def transform_image_to_crop(box_in: torch.Tensor,
box_extract: torch.Tensor,
resize_factor: float,
crop_sz: torch.Tensor,
normalize=False) -> torch.Tensor:
""" Transform the box co-ordinates from the original image co-ordinates to the co-ordinates of the cropped image
args:
box_in - the box for which the co-ordinates are to be transformed
box_extract - the box about which the image crop has been extracted.
resize_factor - the ratio between the original image scale and the scale of the image crop
crop_sz - size of the cropped image

returns:
torch.Tensor - transformed co-ordinates of box_in
"""
box_extract_center = box_extract[0:2] + 0.5 * box_extract[2:4]

box_in_center = box_in[0:2] + 0.5 * box_in[2:4]

box_out_center = (crop_sz - 1) / 2 + (box_in_center
- box_extract_center) * resize_factor
box_out_wh = box_in[2:4] * resize_factor

box_out = torch.cat((box_out_center - 0.5 * box_out_wh, box_out_wh))
if normalize:
return box_out / crop_sz[0]
else:
return box_out


def check_box(box: list, image_height, image_width) -> bool:
""" To check whether the box is within the image range or not
args:
box - the bounding box in the form of [x1, y1, x2, y2]
image_height - the height of the image
image_width - the width of the image

returns:
bool - if box is valid, return True. Otherwise, return False
"""
assert len(box) == 4, 'box must be in the form of: [x1, y1, x2, y2]'
if box[0] < 0 or box[0] >= image_width:
return False
if box[2] < 0 or box[2] >= image_width:
return False
if box[1] < 0 or box[1] >= image_height:
return False
if box[3] < 0 or box[3] >= image_height:
return False
return True


def show_tracking_result(video_in_path, bboxes, video_save_path):
cap = cv2.VideoCapture(video_in_path)
for i in range(len(bboxes)):
box = bboxes[i]
success, frame = cap.read()
if success is False:
raise Exception(video_in_path,
' can not be correctly decoded by OpenCV.')
if i == 0:
size = (frame.shape[1], frame.shape[0])
fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
video_writer = cv2.VideoWriter(video_save_path, fourcc,
cap.get(cv2.CAP_PROP_FPS), size,
True)
cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0),
5)
video_writer.write(frame)
video_writer.release()
cap.release()
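A small sketch (not in the commit) exercising the cropping and preprocessing helpers above on a random frame; the numbers follow the TEST settings in config/ostrack.py:

import numpy as np
import torch
from modelscope.models.cv.video_single_object_tracking.utils.utils import (
    Preprocessor, hann2d, sample_target)

frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
crop, resize_factor, att_mask = sample_target(
    frame, [200, 150, 100, 106], search_area_factor=5.0, output_sz=384)
print(crop.shape, att_mask.shape)               # (384, 384, 3) (384, 384)

pre = Preprocessor(torch.device('cpu'))
nested = pre.process(crop, att_mask)
print(nested.tensors.shape, nested.mask.shape)  # torch.Size([1, 3, 384, 384]) torch.Size([1, 384, 384])

window = hann2d(torch.tensor([24, 24]).long())  # the Hann window applied to the score map
print(window.shape)                             # torch.Size([1, 1, 24, 24])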

+10  -0   modelscope/outputs.py

@@ -188,6 +188,16 @@ TASK_OUTPUTS = {
Tasks.body_2d_keypoints:
[OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES],

# video single object tracking result for single video
# {
# "boxes": [
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# ]
# }
Tasks.video_single_object_tracking: [OutputKeys.BOXES],

# live category recognition result for single video
# {
# "scores": [0.885272, 0.014790631, 0.014558001],


+3  -0   modelscope/pipelines/builder.py

@@ -130,6 +130,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_unet_skin-retouching'),
Tasks.crowd_counting: (Pipelines.crowd_counting,
'damo/cv_hrnet_crowd-counting_dcanet'),
Tasks.video_single_object_tracking:
(Pipelines.video_single_object_tracking,
'damo/cv_vitb_video-single-object-tracking_ostrack'),
}




+80  -0   modelscope/pipelines/cv/video_single_object_tracking_pipeline.py

@@ -0,0 +1,80 @@
import os.path as osp
from typing import Any, Dict

import cv2

from modelscope.metainfo import Pipelines
from modelscope.models.cv.video_single_object_tracking.config.ostrack import \
cfg
from modelscope.models.cv.video_single_object_tracking.tracker.ostrack import \
OSTrack
from modelscope.models.cv.video_single_object_tracking.utils.utils import \
check_box
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.video_single_object_tracking,
module_name=Pipelines.video_single_object_tracking)
class VideoSingleObjectTrackingPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` to create a single object tracking pipeline
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
self.cfg = cfg
ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE)
logger.info(f'loading model from {ckpt_path}')
self.tracker = OSTrack(ckpt_path, self.device)
logger.info('init tracker done')

def preprocess(self, input) -> Input:
self.video_path = input[0]
self.init_bbox = input[1]
return input

def forward(self, input: Input) -> Dict[str, Any]:
output_boxes = []
cap = cv2.VideoCapture(self.video_path)
success, frame = cap.read()
if success is False:
raise Exception(
'modelscope error: %s can not be decoded by OpenCV.' %
(self.video_path))

init_box = self.init_bbox
frame_h, frame_w = frame.shape[0:2]
if not check_box(init_box, frame_h, frame_w):
raise Exception('modelscope error: init_box out of image range ',
init_box)
output_boxes.append(init_box.copy())
init_box[2] = init_box[2] - init_box[0]
init_box[3] = init_box[3] - init_box[1]
self.tracker.initialize(frame, {'init_bbox': init_box})
logger.info('init bbox done')

while True:
ret, frame = cap.read()
if frame is None:
break
out = self.tracker.track(frame)
state = [int(s) for s in out['target_bbox']]
output_boxes.append(state)
cap.release()
logger.info('tracking process done')

return {
OutputKeys.BOXES: output_boxes,
}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

+3  -0   modelscope/utils/constant.py

@@ -62,6 +62,9 @@ class CVTasks(object):
virtual_try_on = 'virtual-try-on'
crowd_counting = 'crowd-counting'

# video related
video_single_object_tracking = 'video-single-object-tracking'


class NLPTasks(object):
# nlp tasks


+39  -0   tests/pipelines/test_video_single_object_tracking.py

@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.models.cv.video_single_object_tracking.utils.utils import \
show_tracking_result
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class SingleObjectTracking(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_vitb_video-single-object-tracking_ostrack'

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_end2end(self):
video_single_object_tracking = pipeline(
Tasks.video_single_object_tracking, model=self.model_id)
video_path = 'data/test/videos/dog.avi'
init_bbox = [414, 343, 514, 449] # [x1, y1, x2, y2]
result = video_single_object_tracking((video_path, init_bbox))
print('result is : ', result[OutputKeys.BOXES])
show_tracking_result(video_path, result[OutputKeys.BOXES],
'./tracking_result.avi')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_modelhub_default_model(self):
video_single_object_tracking = pipeline(
Tasks.video_single_object_tracking)
video_path = 'data/test/videos/dog.avi'
init_bbox = [414, 343, 514, 449] # [x1, y1, x2, y2]
result = video_single_object_tracking((video_path, init_bbox))
print('result is : ', result[OutputKeys.BOXES])


if __name__ == '__main__':
unittest.main()
