Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9781254 * support single object tracking (target branch: master)
@@ -4,3 +4,4 @@
*.wav filter=lfs diff=lfs merge=lfs -text
*.JPEG filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.avi filter=lfs diff=lfs merge=lfs -text

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:469090fb217a34a2c096cfd42c251da69dca9fcd1a3c1faae7d29183c1816c14
size 12834294
@@ -111,6 +111,7 @@ class Pipelines(object):
    skin_retouching = 'unet-skin-retouching'
    tinynas_classification = 'tinynas-classification'
    crowd_counting = 'hrnet-crowd-counting'
    video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'

    # nlp tasks
    sentence_similarity = 'sentence-similarity'
@@ -6,4 +6,4 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                image_portrait_enhancement, image_to_image_generation,
                image_to_image_translation, object_detection,
                product_retrieval_embedding, salient_detection,
-               super_resolution, virual_tryon)
+               super_resolution, video_single_object_tracking, virual_tryon)
@@ -0,0 +1,39 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
from easydict import EasyDict as edict

cfg = edict()

# MODEL
cfg.MODEL = edict()

# MODEL.BACKBONE
cfg.MODEL.BACKBONE = edict()
cfg.MODEL.BACKBONE.TYPE = 'vit_base_patch16_224_ce'
cfg.MODEL.BACKBONE.STRIDE = 16
cfg.MODEL.BACKBONE.CAT_MODE = 'direct'
cfg.MODEL.BACKBONE.DROP_PATH_RATE = 0.1
cfg.MODEL.BACKBONE.CE_LOC = [3, 6, 9]
cfg.MODEL.BACKBONE.CE_KEEP_RATIO = [0.7, 0.7, 0.7]
cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE = 'CTR_POINT'

# MODEL.HEAD
cfg.MODEL.HEAD = edict()
cfg.MODEL.HEAD.TYPE = 'CENTER'
cfg.MODEL.HEAD.NUM_CHANNELS = 256

# DATA
cfg.DATA = edict()
cfg.DATA.MEAN = [0.485, 0.456, 0.406]
cfg.DATA.STD = [0.229, 0.224, 0.225]
cfg.DATA.SEARCH = edict()
cfg.DATA.SEARCH.SIZE = 384
cfg.DATA.TEMPLATE = edict()
cfg.DATA.TEMPLATE.SIZE = 192

# TEST
cfg.TEST = edict()
cfg.TEST.TEMPLATE_FACTOR = 2.0
cfg.TEST.TEMPLATE_SIZE = 192
cfg.TEST.SEARCH_FACTOR = 5.0
cfg.TEST.SEARCH_SIZE = 384
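Reviewer note: the crop sizes and the backbone stride above are coupled. The ViT downsamples by STRIDE, so the 384 search crop and the 192 template crop become 24x24 and 12x12 token grids, which is what the candidate-elimination modules and the CENTER head operate on. A quick standalone sketch of that arithmetic (not part of this change; `cfg_demo` just mirrors the fields from the config above):

from easydict import EasyDict as edict

# mirror of the relevant fields from the config above
cfg_demo = edict(DATA=edict(SEARCH=edict(SIZE=384), TEMPLATE=edict(SIZE=192)),
                 MODEL=edict(BACKBONE=edict(STRIDE=16)))

search_feat_sz = cfg_demo.DATA.SEARCH.SIZE // cfg_demo.MODEL.BACKBONE.STRIDE      # 24
template_feat_sz = cfg_demo.DATA.TEMPLATE.SIZE // cfg_demo.MODEL.BACKBONE.STRIDE  # 12
print(search_feat_sz ** 2, template_feat_sz ** 2)  # 576 search tokens, 144 template tokens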
| @@ -0,0 +1,54 @@ | |||||
| # The implementation is also open-sourced by the authors as OSTrack, and is available publicly on | |||||
| # https://github.com/botaoye/OSTrack/ | |||||
| import torch.nn as nn | |||||
| class Attention(nn.Module): | |||||
| def __init__(self, | |||||
| dim, | |||||
| num_heads=8, | |||||
| qkv_bias=False, | |||||
| attn_drop=0., | |||||
| proj_drop=0., | |||||
| rpe=False, | |||||
| z_size=7, | |||||
| x_size=14): | |||||
| super().__init__() | |||||
| self.num_heads = num_heads | |||||
| head_dim = dim // num_heads | |||||
| self.scale = head_dim**-0.5 | |||||
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |||||
| self.attn_drop = nn.Dropout(attn_drop) | |||||
| self.proj = nn.Linear(dim, dim) | |||||
| self.proj_drop = nn.Dropout(proj_drop) | |||||
| def forward(self, x, mask=None, return_attention=False): | |||||
| # x: B, N, C | |||||
| # mask: [B, N, ] torch.bool | |||||
| B, N, C = x.shape | |||||
| qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, | |||||
| C // self.num_heads).permute(2, 0, 3, 1, 4) | |||||
| q, k, v = qkv.unbind( | |||||
| 0) # make torchscript happy (cannot use tensor as tuple) | |||||
| attn = (q @ k.transpose(-2, -1)) * self.scale | |||||
| if mask is not None: | |||||
| attn = attn.masked_fill( | |||||
| mask.unsqueeze(1).unsqueeze(2), | |||||
| float('-inf'), | |||||
| ) | |||||
| attn = attn.softmax(dim=-1) | |||||
| attn = self.attn_drop(attn) | |||||
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) | |||||
| x = self.proj(x) | |||||
| x = self.proj_drop(x) | |||||
| if return_attention: | |||||
| return x, attn | |||||
| else: | |||||
| return x | |||||
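A minimal smoke test for this module, assuming the ViT-B sizes used elsewhere in this change (dim 768, 12 heads, 144 template + 576 search tokens). The attention map returned with return_attention=True is what CEBlock later passes into candidate_elimination:

import torch

attn = Attention(dim=768, num_heads=12, qkv_bias=True)
x = torch.randn(2, 720, 768)  # batch of 2, 144 template + 576 search tokens
y, weights = attn(x, mask=None, return_attention=True)
print(y.shape, weights.shape)  # torch.Size([2, 720, 768]) torch.Size([2, 12, 720, 720])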
| @@ -0,0 +1,129 @@ | |||||
| # The implementation is also open-sourced by the authors as OSTrack, and is available publicly on | |||||
| # https://github.com/botaoye/OSTrack/ | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from timm.models.layers import DropPath, Mlp | |||||
| from .attn import Attention | |||||
| def candidate_elimination(attn: torch.Tensor, tokens: torch.Tensor, | |||||
| lens_t: int, keep_ratio: float, | |||||
| global_index: torch.Tensor, | |||||
| box_mask_z: torch.Tensor): | |||||
| """ | |||||
| Eliminate potential background candidates for computation reduction and noise cancellation. | |||||
| Args: | |||||
| attn (torch.Tensor): [B, num_heads, L_t + L_s, L_t + L_s], attention weights | |||||
| tokens (torch.Tensor): [B, L_t + L_s, C], template and search region tokens | |||||
| lens_t (int): length of template | |||||
| keep_ratio (float): keep ratio of search region tokens (candidates) | |||||
| global_index (torch.Tensor): global index of search region tokens | |||||
| box_mask_z (torch.Tensor): template mask used to accumulate attention weights | |||||
| Returns: | |||||
| tokens_new (torch.Tensor): tokens after candidate elimination | |||||
| keep_index (torch.Tensor): indices of kept search region tokens | |||||
| removed_index (torch.Tensor): indices of removed search region tokens | |||||
| """ | |||||
| lens_s = attn.shape[-1] - lens_t | |||||
| bs, hn, _, _ = attn.shape | |||||
| lens_keep = math.ceil(keep_ratio * lens_s) | |||||
| if lens_keep == lens_s: | |||||
| return tokens, global_index, None | |||||
| attn_t = attn[:, :, :lens_t, lens_t:] | |||||
| if box_mask_z is not None: | |||||
| box_mask_z = box_mask_z.unsqueeze(1).unsqueeze(-1).expand( | |||||
| -1, attn_t.shape[1], -1, attn_t.shape[-1]) | |||||
| attn_t = attn_t[box_mask_z] | |||||
| attn_t = attn_t.view(bs, hn, -1, lens_s) | |||||
| attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s | |||||
| else: | |||||
| attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s | |||||
| # use sort instead of topk, due to the speed issue | |||||
| # https://github.com/pytorch/pytorch/issues/22812 | |||||
| sorted_attn, indices = torch.sort(attn_t, dim=1, descending=True) | |||||
| _, topk_idx = sorted_attn[:, :lens_keep], indices[:, :lens_keep] | |||||
| _, non_topk_idx = sorted_attn[:, lens_keep:], indices[:, lens_keep:] | |||||
| keep_index = global_index.gather(dim=1, index=topk_idx) | |||||
| removed_index = global_index.gather(dim=1, index=non_topk_idx) | |||||
| # separate template and search tokens | |||||
| tokens_t = tokens[:, :lens_t] | |||||
| tokens_s = tokens[:, lens_t:] | |||||
| # obtain the attentive and inattentive tokens | |||||
| B, L, C = tokens_s.shape | |||||
| attentive_tokens = tokens_s.gather( | |||||
| dim=1, index=topk_idx.unsqueeze(-1).expand(B, -1, C)) | |||||
| # concatenate these tokens | |||||
| tokens_new = torch.cat([tokens_t, attentive_tokens], dim=1) | |||||
| return tokens_new, keep_index, removed_index | |||||
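To make the tensor bookkeeping concrete, here is a hypothetical run of candidate_elimination on random tensors with the shapes this backbone produces (144 template tokens, 576 search tokens, keep ratio 0.7). The function keeps the ceil(0.7 * 576) = 404 search tokens with the highest template-to-search attention and reports the rest as removed:

import torch

B, H, L_t, L_s, C = 2, 12, 144, 576, 768
attn_demo = torch.rand(B, H, L_t + L_s, L_t + L_s)
tokens = torch.randn(B, L_t + L_s, C)
global_index = torch.arange(L_s, dtype=torch.float).repeat(B, 1)

new_tokens, kept, removed = candidate_elimination(
    attn_demo, tokens, L_t, 0.7, global_index, box_mask_z=None)
print(new_tokens.shape)           # torch.Size([2, 548, 768]): 144 template + 404 kept search tokens
print(kept.shape, removed.shape)  # torch.Size([2, 404]) torch.Size([2, 172])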
| class CEBlock(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| dim, | |||||
| num_heads, | |||||
| mlp_ratio=4., | |||||
| qkv_bias=False, | |||||
| drop=0., | |||||
| attn_drop=0., | |||||
| drop_path=0., | |||||
| act_layer=nn.GELU, | |||||
| norm_layer=nn.LayerNorm, | |||||
| keep_ratio_search=1.0, | |||||
| ): | |||||
| super().__init__() | |||||
| self.norm1 = norm_layer(dim) | |||||
| self.attn = Attention( | |||||
| dim, | |||||
| num_heads=num_heads, | |||||
| qkv_bias=qkv_bias, | |||||
| attn_drop=attn_drop, | |||||
| proj_drop=drop) | |||||
| # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here | |||||
| self.drop_path = DropPath( | |||||
| drop_path) if drop_path > 0. else nn.Identity() | |||||
| self.norm2 = norm_layer(dim) | |||||
| mlp_hidden_dim = int(dim * mlp_ratio) | |||||
| self.mlp = Mlp( | |||||
| in_features=dim, | |||||
| hidden_features=mlp_hidden_dim, | |||||
| act_layer=act_layer, | |||||
| drop=drop) | |||||
| self.keep_ratio_search = keep_ratio_search | |||||
| def forward(self, | |||||
| x, | |||||
| global_index_template, | |||||
| global_index_search, | |||||
| mask=None, | |||||
| ce_template_mask=None, | |||||
| keep_ratio_search=None): | |||||
| x_attn, attn = self.attn(self.norm1(x), mask, True) | |||||
| x = x + self.drop_path(x_attn) | |||||
| lens_t = global_index_template.shape[1] | |||||
| removed_index_search = None | |||||
| if self.keep_ratio_search < 1 and (keep_ratio_search is None | |||||
| or keep_ratio_search < 1): | |||||
| keep_ratio_search = self.keep_ratio_search if keep_ratio_search is None else keep_ratio_search | |||||
| x, global_index_search, removed_index_search = candidate_elimination( | |||||
| attn, x, lens_t, keep_ratio_search, global_index_search, | |||||
| ce_template_mask) | |||||
| x = x + self.drop_path(self.mlp(self.norm2(x))) | |||||
| return x, global_index_template, global_index_search, removed_index_search, attn | |||||
| @@ -0,0 +1,141 @@ | |||||
| # The implementation is also open-sourced by the authors as OSTrack, and is available publicly on | |||||
| # https://github.com/botaoye/OSTrack/ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| def conv(in_planes, | |||||
| out_planes, | |||||
| kernel_size=3, | |||||
| stride=1, | |||||
| padding=1, | |||||
| dilation=1): | |||||
| return nn.Sequential( | |||||
| nn.Conv2d( | |||||
| in_planes, | |||||
| out_planes, | |||||
| kernel_size=kernel_size, | |||||
| stride=stride, | |||||
| padding=padding, | |||||
| dilation=dilation, | |||||
| bias=True), nn.BatchNorm2d(out_planes), nn.ReLU(inplace=True)) | |||||
class CenterPredictor(nn.Module):
| def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16): | |||||
| super(CenterPredictor, self).__init__() | |||||
| self.feat_sz = feat_sz | |||||
| self.stride = stride | |||||
| self.img_sz = self.feat_sz * self.stride | |||||
| # corner predict | |||||
| self.conv1_ctr = conv(inplanes, channel) | |||||
| self.conv2_ctr = conv(channel, channel // 2) | |||||
| self.conv3_ctr = conv(channel // 2, channel // 4) | |||||
| self.conv4_ctr = conv(channel // 4, channel // 8) | |||||
| self.conv5_ctr = nn.Conv2d(channel // 8, 1, kernel_size=1) | |||||
| # offset regress | |||||
| self.conv1_offset = conv(inplanes, channel) | |||||
| self.conv2_offset = conv(channel, channel // 2) | |||||
| self.conv3_offset = conv(channel // 2, channel // 4) | |||||
| self.conv4_offset = conv(channel // 4, channel // 8) | |||||
| self.conv5_offset = nn.Conv2d(channel // 8, 2, kernel_size=1) | |||||
| # size regress | |||||
| self.conv1_size = conv(inplanes, channel) | |||||
| self.conv2_size = conv(channel, channel // 2) | |||||
| self.conv3_size = conv(channel // 2, channel // 4) | |||||
| self.conv4_size = conv(channel // 4, channel // 8) | |||||
| self.conv5_size = nn.Conv2d(channel // 8, 2, kernel_size=1) | |||||
| for p in self.parameters(): | |||||
| if p.dim() > 1: | |||||
| nn.init.xavier_uniform_(p) | |||||
| def forward(self, x, gt_score_map=None): | |||||
| """ Forward pass with input x. """ | |||||
| score_map_ctr, size_map, offset_map = self.get_score_map(x) | |||||
| # assert gt_score_map is None | |||||
| if gt_score_map is None: | |||||
| bbox = self.cal_bbox(score_map_ctr, size_map, offset_map) | |||||
| else: | |||||
| bbox = self.cal_bbox( | |||||
| gt_score_map.unsqueeze(1), size_map, offset_map) | |||||
| return score_map_ctr, bbox, size_map, offset_map | |||||
| def cal_bbox(self, | |||||
| score_map_ctr, | |||||
| size_map, | |||||
| offset_map, | |||||
| return_score=False): | |||||
| max_score, idx = torch.max( | |||||
| score_map_ctr.flatten(1), dim=1, keepdim=True) | |||||
| idx_y = idx // self.feat_sz | |||||
| idx_x = idx % self.feat_sz | |||||
| idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1) | |||||
| size = size_map.flatten(2).gather(dim=2, index=idx) | |||||
| offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) | |||||
| # cx, cy, w, h | |||||
| bbox = torch.cat( | |||||
| [(idx_x.to(torch.float) + offset[:, :1]) / self.feat_sz, | |||||
| (idx_y.to(torch.float) + offset[:, 1:]) / self.feat_sz, | |||||
| size.squeeze(-1)], | |||||
| dim=1) | |||||
| if return_score: | |||||
| return bbox, max_score | |||||
| return bbox | |||||
| def get_score_map(self, x): | |||||
| def _sigmoid(x): | |||||
| y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4) | |||||
| return y | |||||
| # ctr branch | |||||
| x_ctr1 = self.conv1_ctr(x) | |||||
| x_ctr2 = self.conv2_ctr(x_ctr1) | |||||
| x_ctr3 = self.conv3_ctr(x_ctr2) | |||||
| x_ctr4 = self.conv4_ctr(x_ctr3) | |||||
| score_map_ctr = self.conv5_ctr(x_ctr4) | |||||
| # offset branch | |||||
| x_offset1 = self.conv1_offset(x) | |||||
| x_offset2 = self.conv2_offset(x_offset1) | |||||
| x_offset3 = self.conv3_offset(x_offset2) | |||||
| x_offset4 = self.conv4_offset(x_offset3) | |||||
| score_map_offset = self.conv5_offset(x_offset4) | |||||
| # size branch | |||||
| x_size1 = self.conv1_size(x) | |||||
| x_size2 = self.conv2_size(x_size1) | |||||
| x_size3 = self.conv3_size(x_size2) | |||||
| x_size4 = self.conv4_size(x_size3) | |||||
| score_map_size = self.conv5_size(x_size4) | |||||
| return _sigmoid(score_map_ctr), _sigmoid( | |||||
| score_map_size), score_map_offset | |||||
| def build_box_head(cfg, hidden_dim): | |||||
| stride = cfg.MODEL.BACKBONE.STRIDE | |||||
| if cfg.MODEL.HEAD.TYPE == 'CENTER': | |||||
| in_channel = hidden_dim | |||||
| out_channel = cfg.MODEL.HEAD.NUM_CHANNELS | |||||
| feat_sz = int(cfg.DATA.SEARCH.SIZE / stride) | |||||
| center_head = CenterPredictor( | |||||
| inplanes=in_channel, | |||||
| channel=out_channel, | |||||
| feat_sz=feat_sz, | |||||
| stride=stride) | |||||
| return center_head | |||||
| else: | |||||
        raise ValueError('HEAD TYPE %s is not supported.' % cfg.MODEL.HEAD.TYPE)
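A hypothetical shape check for the CENTER head as it is built here (inplanes = ViT-B hidden dim 768, feat_sz = 384 / 16 = 24). The head returns a sigmoid center score map, a (cx, cy, w, h) box normalized to [0, 1], and the size / offset maps that the tracker later re-queries after applying the Hann window:

import torch

head = CenterPredictor(inplanes=768, channel=256, feat_sz=24, stride=16)
feat = torch.randn(1, 768, 24, 24)
score_map_ctr, bbox, size_map, offset_map = head(feat)
print(score_map_ctr.shape, bbox.shape)   # torch.Size([1, 1, 24, 24]) torch.Size([1, 4])
print(size_map.shape, offset_map.shape)  # torch.Size([1, 2, 24, 24]) torch.Size([1, 2, 24, 24])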
@@ -0,0 +1,37 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
import torch.nn as nn
from timm.models.layers import to_2tuple


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 norm_layer=None,
                 flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
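For reference, a tiny check of PatchEmbed with the template crop used in this change: a 192x192 image with 16x16 patches becomes 144 tokens of dimension 768 (the 384 search crop gives 576 tokens the same way):

import torch

embed = PatchEmbed(img_size=192, patch_size=16, in_chans=3, embed_dim=768)
z = torch.randn(1, 3, 192, 192)
print(embed(z).shape)  # torch.Size([1, 144, 768])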
| @@ -0,0 +1,93 @@ | |||||
| # The implementation is also open-sourced by the authors as OSTrack, and is available publicly on | |||||
| # https://github.com/botaoye/OSTrack/ | |||||
| import torch.nn as nn | |||||
| from timm.models.layers import to_2tuple | |||||
| from modelscope.models.cv.video_single_object_tracking.models.layers.patch_embed import \ | |||||
| PatchEmbed | |||||
| class BaseBackbone(nn.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| # for original ViT | |||||
| self.pos_embed = None | |||||
| self.img_size = [224, 224] | |||||
| self.patch_size = 16 | |||||
| self.embed_dim = 384 | |||||
| self.cat_mode = 'direct' | |||||
| self.pos_embed_z = None | |||||
| self.pos_embed_x = None | |||||
| self.template_segment_pos_embed = None | |||||
| self.search_segment_pos_embed = None | |||||
| self.return_stage = [2, 5, 8, 11] | |||||
| def finetune_track(self, cfg, patch_start_index=1): | |||||
| search_size = to_2tuple(cfg.DATA.SEARCH.SIZE) | |||||
| template_size = to_2tuple(cfg.DATA.TEMPLATE.SIZE) | |||||
| new_patch_size = cfg.MODEL.BACKBONE.STRIDE | |||||
| self.cat_mode = cfg.MODEL.BACKBONE.CAT_MODE | |||||
| # resize patch embedding | |||||
| if new_patch_size != self.patch_size: | |||||
| print( | |||||
| 'Inconsistent Patch Size With The Pretrained Weights, Interpolate The Weight!' | |||||
| ) | |||||
| old_patch_embed = {} | |||||
| for name, param in self.patch_embed.named_parameters(): | |||||
| if 'weight' in name: | |||||
| param = nn.functional.interpolate( | |||||
| param, | |||||
| size=(new_patch_size, new_patch_size), | |||||
| mode='bicubic', | |||||
| align_corners=False) | |||||
| param = nn.Parameter(param) | |||||
| old_patch_embed[name] = param | |||||
| self.patch_embed = PatchEmbed( | |||||
| img_size=self.img_size, | |||||
| patch_size=new_patch_size, | |||||
| in_chans=3, | |||||
| embed_dim=self.embed_dim) | |||||
| self.patch_embed.proj.bias = old_patch_embed['proj.bias'] | |||||
| self.patch_embed.proj.weight = old_patch_embed['proj.weight'] | |||||
| # for patch embedding | |||||
| patch_pos_embed = self.pos_embed[:, patch_start_index:, :] | |||||
| patch_pos_embed = patch_pos_embed.transpose(1, 2) | |||||
| B, E, Q = patch_pos_embed.shape | |||||
| P_H, P_W = self.img_size[0] // self.patch_size, self.img_size[ | |||||
| 1] // self.patch_size | |||||
| patch_pos_embed = patch_pos_embed.view(B, E, P_H, P_W) | |||||
| # for search region | |||||
| H, W = search_size | |||||
| new_P_H, new_P_W = H // new_patch_size, W // new_patch_size | |||||
| search_patch_pos_embed = nn.functional.interpolate( | |||||
| patch_pos_embed, | |||||
| size=(new_P_H, new_P_W), | |||||
| mode='bicubic', | |||||
| align_corners=False) | |||||
| search_patch_pos_embed = search_patch_pos_embed.flatten(2).transpose( | |||||
| 1, 2) | |||||
| # for template region | |||||
| H, W = template_size | |||||
| new_P_H, new_P_W = H // new_patch_size, W // new_patch_size | |||||
| template_patch_pos_embed = nn.functional.interpolate( | |||||
| patch_pos_embed, | |||||
| size=(new_P_H, new_P_W), | |||||
| mode='bicubic', | |||||
| align_corners=False) | |||||
| template_patch_pos_embed = template_patch_pos_embed.flatten( | |||||
| 2).transpose(1, 2) | |||||
| self.pos_embed_z = nn.Parameter(template_patch_pos_embed) | |||||
| self.pos_embed_x = nn.Parameter(search_patch_pos_embed) | |||||
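finetune_track above re-purposes the pretrained patch position embeddings for the tracking crops by bicubic interpolation. A standalone sketch of just that resizing step, assuming a ViT-B/16 pretrained at 224x224 (one cls token plus a 14x14 grid of patch embeddings), resized to the 24x24 search grid and the 12x12 template grid:

import torch
import torch.nn as nn

pos_embed = torch.zeros(1, 1 + 14 * 14, 768)  # [cls] + 196 patch position embeddings
patch_pos = pos_embed[:, 1:, :].transpose(1, 2).view(1, 768, 14, 14)
search_pos = nn.functional.interpolate(
    patch_pos, size=(24, 24), mode='bicubic', align_corners=False)
template_pos = nn.functional.interpolate(
    patch_pos, size=(12, 12), mode='bicubic', align_corners=False)
print(search_pos.flatten(2).transpose(1, 2).shape)    # torch.Size([1, 576, 768]) -> pos_embed_x
print(template_pos.flatten(2).transpose(1, 2).shape)  # torch.Size([1, 144, 768]) -> pos_embed_z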
| @@ -0,0 +1,109 @@ | |||||
| # The implementation is also open-sourced by the authors as OSTrack, and is available publicly on | |||||
| # https://github.com/botaoye/OSTrack/ | |||||
| import torch | |||||
| from torch import nn | |||||
| from modelscope.models.cv.video_single_object_tracking.models.layers.head import \ | |||||
| build_box_head | |||||
| from .vit_ce import vit_base_patch16_224_ce | |||||
| class OSTrack(nn.Module): | |||||
| """ This is the base class for OSTrack """ | |||||
| def __init__(self, | |||||
| transformer, | |||||
| box_head, | |||||
| aux_loss=False, | |||||
| head_type='CORNER'): | |||||
| """ Initializes the model. | |||||
| Parameters: | |||||
| transformer: torch module of the transformer architecture. | |||||
| aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. | |||||
| """ | |||||
| super().__init__() | |||||
| self.backbone = transformer | |||||
| self.box_head = box_head | |||||
| self.aux_loss = aux_loss | |||||
| self.head_type = head_type | |||||
| if head_type == 'CORNER' or head_type == 'CENTER': | |||||
| self.feat_sz_s = int(box_head.feat_sz) | |||||
| self.feat_len_s = int(box_head.feat_sz**2) | |||||
| def forward( | |||||
| self, | |||||
| template: torch.Tensor, | |||||
| search: torch.Tensor, | |||||
| ce_template_mask=None, | |||||
| ce_keep_rate=None, | |||||
| ): | |||||
| x, aux_dict = self.backbone( | |||||
| z=template, | |||||
| x=search, | |||||
| ce_template_mask=ce_template_mask, | |||||
| ce_keep_rate=ce_keep_rate, | |||||
| ) | |||||
| # Forward head | |||||
| feat_last = x | |||||
| if isinstance(x, list): | |||||
| feat_last = x[-1] | |||||
| out = self.forward_head(feat_last, None) | |||||
| out.update(aux_dict) | |||||
| out['backbone_feat'] = x | |||||
| return out | |||||
| def forward_head(self, cat_feature, gt_score_map=None): | |||||
| """ | |||||
| cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C) | |||||
| """ | |||||
| enc_opt = cat_feature[:, -self. | |||||
| feat_len_s:] # encoder output for the search region (B, HW, C) | |||||
| opt = (enc_opt.unsqueeze(-1)).permute((0, 3, 2, 1)).contiguous() | |||||
| bs, Nq, C, HW = opt.size() | |||||
| opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s) | |||||
| if self.head_type == 'CENTER': | |||||
| # run the center head | |||||
| score_map_ctr, bbox, size_map, offset_map = self.box_head( | |||||
| opt_feat, gt_score_map) | |||||
| outputs_coord = bbox | |||||
| outputs_coord_new = outputs_coord.view(bs, Nq, 4) | |||||
| out = { | |||||
| 'pred_boxes': outputs_coord_new, | |||||
| 'score_map': score_map_ctr, | |||||
| 'size_map': size_map, | |||||
| 'offset_map': offset_map | |||||
| } | |||||
| return out | |||||
| else: | |||||
| raise NotImplementedError | |||||
| def build_ostrack(cfg): | |||||
| if cfg.MODEL.BACKBONE.TYPE == 'vit_base_patch16_224_ce': | |||||
| backbone = vit_base_patch16_224_ce( | |||||
| False, | |||||
| drop_path_rate=cfg.MODEL.BACKBONE.DROP_PATH_RATE, | |||||
| ce_loc=cfg.MODEL.BACKBONE.CE_LOC, | |||||
| ce_keep_ratio=cfg.MODEL.BACKBONE.CE_KEEP_RATIO, | |||||
| ) | |||||
| hidden_dim = backbone.embed_dim | |||||
| patch_start_index = 1 | |||||
| else: | |||||
| raise NotImplementedError | |||||
| backbone.finetune_track(cfg=cfg, patch_start_index=patch_start_index) | |||||
| box_head = build_box_head(cfg, hidden_dim) | |||||
| model = OSTrack( | |||||
| backbone, | |||||
| box_head, | |||||
| aux_loss=False, | |||||
| head_type=cfg.MODEL.HEAD.TYPE, | |||||
| ) | |||||
| return model | |||||
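A hypothetical end-to-end shape check of build_ostrack with the config from this change (random weights, CPU). out['pred_boxes'] is the normalized (cx, cy, w, h) box and out['score_map'] is the 24x24 center heat map that the tracker wrapper consumes:

import torch
from modelscope.models.cv.video_single_object_tracking.config.ostrack import cfg

model = build_ostrack(cfg).eval()
template = torch.randn(1, 3, 192, 192)
search = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    out = model(template=template, search=search)
print(out['pred_boxes'].shape, out['score_map'].shape)
# torch.Size([1, 1, 4]) torch.Size([1, 1, 24, 24])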
@@ -0,0 +1,24 @@
# The implementation is also open-sourced by the authors as OSTrack, and is available publicly on
# https://github.com/botaoye/OSTrack/
import torch


def combine_tokens(template_tokens,
                   search_tokens,
                   mode='direct',
                   return_res=False):
    if mode == 'direct':
        merged_feature = torch.cat((template_tokens, search_tokens), dim=1)
    else:
        raise NotImplementedError
    return merged_feature


def recover_tokens(merged_tokens, mode='direct'):
    if mode == 'direct':
        recovered_tokens = merged_tokens
    else:
        raise NotImplementedError
    return recovered_tokens
| @@ -0,0 +1,343 @@ | |||||
| # The implementation is also open-sourced by the authors as OSTrack, and is available publicly on | |||||
| # https://github.com/botaoye/OSTrack/ | |||||
| from functools import partial | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from timm.models.layers import DropPath, Mlp, to_2tuple | |||||
| from modelscope.models.cv.video_single_object_tracking.models.layers.attn_blocks import \ | |||||
| CEBlock | |||||
| from modelscope.models.cv.video_single_object_tracking.models.layers.patch_embed import \ | |||||
| PatchEmbed | |||||
| from .base_backbone import BaseBackbone | |||||
| from .utils import combine_tokens, recover_tokens | |||||
| class Attention(nn.Module): | |||||
| def __init__(self, | |||||
| dim, | |||||
| num_heads=8, | |||||
| qkv_bias=False, | |||||
| attn_drop=0., | |||||
| proj_drop=0.): | |||||
| super().__init__() | |||||
| self.num_heads = num_heads | |||||
| head_dim = dim // num_heads | |||||
| self.scale = head_dim**-0.5 | |||||
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |||||
| self.attn_drop = nn.Dropout(attn_drop) | |||||
| self.proj = nn.Linear(dim, dim) | |||||
| self.proj_drop = nn.Dropout(proj_drop) | |||||
| class Block(nn.Module): | |||||
| def __init__(self, | |||||
| dim, | |||||
| num_heads, | |||||
| mlp_ratio=4., | |||||
| qkv_bias=False, | |||||
| drop=0., | |||||
| attn_drop=0., | |||||
| drop_path=0., | |||||
| act_layer=nn.GELU, | |||||
| norm_layer=nn.LayerNorm): | |||||
| super().__init__() | |||||
| self.norm1 = norm_layer(dim) | |||||
| self.attn = Attention( | |||||
| dim, | |||||
| num_heads=num_heads, | |||||
| qkv_bias=qkv_bias, | |||||
| attn_drop=attn_drop, | |||||
| proj_drop=drop) | |||||
| # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here | |||||
| self.drop_path = DropPath( | |||||
| drop_path) if drop_path > 0. else nn.Identity() | |||||
| self.norm2 = norm_layer(dim) | |||||
| mlp_hidden_dim = int(dim * mlp_ratio) | |||||
| self.mlp = Mlp( | |||||
| in_features=dim, | |||||
| hidden_features=mlp_hidden_dim, | |||||
| act_layer=act_layer, | |||||
| drop=drop) | |||||
| class VisionTransformer(BaseBackbone): | |||||
| """ Vision Transformer | |||||
| A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` | |||||
| - https://arxiv.org/abs/2010.11929 | |||||
| Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` | |||||
| - https://arxiv.org/abs/2012.12877 | |||||
| """ | |||||
| def __init__(self, | |||||
| img_size=224, | |||||
| patch_size=16, | |||||
| in_chans=3, | |||||
| num_classes=1000, | |||||
| embed_dim=768, | |||||
| depth=12, | |||||
| num_heads=12, | |||||
| mlp_ratio=4., | |||||
| qkv_bias=True, | |||||
| distilled=False, | |||||
| drop_rate=0., | |||||
| attn_drop_rate=0., | |||||
| drop_path_rate=0., | |||||
| embed_layer=PatchEmbed, | |||||
| norm_layer=None, | |||||
| act_layer=None): | |||||
| """ | |||||
| Args: | |||||
| img_size (int, tuple): input image size | |||||
| patch_size (int, tuple): patch size | |||||
| in_chans (int): number of input channels | |||||
| num_classes (int): number of classes for classification head | |||||
| embed_dim (int): embedding dimension | |||||
| depth (int): depth of transformer | |||||
| num_heads (int): number of attention heads | |||||
| mlp_ratio (int): ratio of mlp hidden dim to embedding dim | |||||
| qkv_bias (bool): enable bias for qkv if True | |||||
| distilled (bool): model includes a distillation token and head as in DeiT models | |||||
| drop_rate (float): dropout rate | |||||
| attn_drop_rate (float): attention dropout rate | |||||
| drop_path_rate (float): stochastic depth rate | |||||
| embed_layer (nn.Module): patch embedding layer | |||||
| norm_layer: (nn.Module): normalization layer | |||||
| """ | |||||
| super().__init__() | |||||
| self.num_classes = num_classes | |||||
| self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models | |||||
| self.num_tokens = 2 if distilled else 1 | |||||
| norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) | |||||
| act_layer = act_layer or nn.GELU | |||||
| self.patch_embed = embed_layer( | |||||
| img_size=img_size, | |||||
| patch_size=patch_size, | |||||
| in_chans=in_chans, | |||||
| embed_dim=embed_dim) | |||||
| num_patches = self.patch_embed.num_patches | |||||
| self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) | |||||
| self.dist_token = None | |||||
| self.pos_embed = nn.Parameter( | |||||
| torch.zeros(1, num_patches + self.num_tokens, embed_dim)) | |||||
| self.pos_drop = nn.Dropout(p=drop_rate) | |||||
| dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) | |||||
| ] # stochastic depth decay rule | |||||
| self.blocks = nn.Sequential(*[ | |||||
| Block( | |||||
| dim=embed_dim, | |||||
| num_heads=num_heads, | |||||
| mlp_ratio=mlp_ratio, | |||||
| qkv_bias=qkv_bias, | |||||
| drop=drop_rate, | |||||
| attn_drop=attn_drop_rate, | |||||
| drop_path=dpr[i], | |||||
| norm_layer=norm_layer, | |||||
| act_layer=act_layer) for i in range(depth) | |||||
| ]) | |||||
| self.norm = norm_layer(embed_dim) | |||||
| class VisionTransformerCE(VisionTransformer): | |||||
| """ Vision Transformer with candidate elimination (CE) module | |||||
| A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` | |||||
| - https://arxiv.org/abs/2010.11929 | |||||
| Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` | |||||
| - https://arxiv.org/abs/2012.12877 | |||||
| """ | |||||
| def __init__(self, | |||||
| img_size=224, | |||||
| patch_size=16, | |||||
| in_chans=3, | |||||
| num_classes=1000, | |||||
| embed_dim=768, | |||||
| depth=12, | |||||
| num_heads=12, | |||||
| mlp_ratio=4., | |||||
| qkv_bias=True, | |||||
| distilled=False, | |||||
| drop_rate=0., | |||||
| attn_drop_rate=0., | |||||
| drop_path_rate=0., | |||||
| embed_layer=PatchEmbed, | |||||
| norm_layer=None, | |||||
| act_layer=None, | |||||
| ce_loc=None, | |||||
| ce_keep_ratio=None): | |||||
| """ | |||||
| Args: | |||||
| img_size (int, tuple): input image size | |||||
| patch_size (int, tuple): patch size | |||||
| in_chans (int): number of input channels | |||||
| num_classes (int): number of classes for classification head | |||||
| embed_dim (int): embedding dimension | |||||
| depth (int): depth of transformer | |||||
| num_heads (int): number of attention heads | |||||
| mlp_ratio (int): ratio of mlp hidden dim to embedding dim | |||||
| qkv_bias (bool): enable bias for qkv if True | |||||
| distilled (bool): model includes a distillation token and head as in DeiT models | |||||
| drop_rate (float): dropout rate | |||||
| attn_drop_rate (float): attention dropout rate | |||||
| drop_path_rate (float): stochastic depth rate | |||||
| embed_layer (nn.Module): patch embedding layer | |||||
| norm_layer: (nn.Module): normalization layer | |||||
| """ | |||||
| super().__init__() | |||||
| if isinstance(img_size, tuple): | |||||
| self.img_size = img_size | |||||
| else: | |||||
| self.img_size = to_2tuple(img_size) | |||||
| self.patch_size = patch_size | |||||
| self.in_chans = in_chans | |||||
| self.num_classes = num_classes | |||||
| self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models | |||||
| self.num_tokens = 2 if distilled else 1 | |||||
| norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) | |||||
| act_layer = act_layer or nn.GELU | |||||
| self.patch_embed = embed_layer( | |||||
| img_size=img_size, | |||||
| patch_size=patch_size, | |||||
| in_chans=in_chans, | |||||
| embed_dim=embed_dim) | |||||
| num_patches = self.patch_embed.num_patches | |||||
| self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) | |||||
| self.dist_token = nn.Parameter(torch.zeros( | |||||
| 1, 1, embed_dim)) if distilled else None | |||||
| self.pos_embed = nn.Parameter( | |||||
| torch.zeros(1, num_patches + self.num_tokens, embed_dim)) | |||||
| self.pos_drop = nn.Dropout(p=drop_rate) | |||||
| dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) | |||||
| ] # stochastic depth decay rule | |||||
| blocks = [] | |||||
| ce_index = 0 | |||||
| self.ce_loc = ce_loc | |||||
| for i in range(depth): | |||||
| ce_keep_ratio_i = 1.0 | |||||
| if ce_loc is not None and i in ce_loc: | |||||
| ce_keep_ratio_i = ce_keep_ratio[ce_index] | |||||
| ce_index += 1 | |||||
| blocks.append( | |||||
| CEBlock( | |||||
| dim=embed_dim, | |||||
| num_heads=num_heads, | |||||
| mlp_ratio=mlp_ratio, | |||||
| qkv_bias=qkv_bias, | |||||
| drop=drop_rate, | |||||
| attn_drop=attn_drop_rate, | |||||
| drop_path=dpr[i], | |||||
| norm_layer=norm_layer, | |||||
| act_layer=act_layer, | |||||
| keep_ratio_search=ce_keep_ratio_i)) | |||||
| self.blocks = nn.Sequential(*blocks) | |||||
| self.norm = norm_layer(embed_dim) | |||||
| def forward_features( | |||||
| self, | |||||
| z, | |||||
| x, | |||||
| mask_x=None, | |||||
| ce_template_mask=None, | |||||
| ce_keep_rate=None, | |||||
| ): | |||||
| B = x.shape[0] | |||||
| x = self.patch_embed(x) | |||||
| z = self.patch_embed(z) | |||||
| z += self.pos_embed_z | |||||
| x += self.pos_embed_x | |||||
| x = combine_tokens(z, x, mode=self.cat_mode) | |||||
| x = self.pos_drop(x) | |||||
| lens_z = self.pos_embed_z.shape[1] | |||||
| lens_x = self.pos_embed_x.shape[1] | |||||
| global_index_t = torch.linspace(0, lens_z - 1, lens_z).to(x.device) | |||||
| global_index_t = global_index_t.repeat(B, 1) | |||||
| global_index_s = torch.linspace(0, lens_x - 1, lens_x).to(x.device) | |||||
| global_index_s = global_index_s.repeat(B, 1) | |||||
| removed_indexes_s = [] | |||||
| for i, blk in enumerate(self.blocks): | |||||
| x, global_index_t, global_index_s, removed_index_s, attn = \ | |||||
| blk(x, global_index_t, global_index_s, mask_x, ce_template_mask, ce_keep_rate) | |||||
| if self.ce_loc is not None and i in self.ce_loc: | |||||
| removed_indexes_s.append(removed_index_s) | |||||
| x = self.norm(x) | |||||
| lens_x_new = global_index_s.shape[1] | |||||
| lens_z_new = global_index_t.shape[1] | |||||
| z = x[:, :lens_z_new] | |||||
| x = x[:, lens_z_new:] | |||||
| if removed_indexes_s and removed_indexes_s[0] is not None: | |||||
| removed_indexes_cat = torch.cat(removed_indexes_s, dim=1) | |||||
| pruned_lens_x = lens_x - lens_x_new | |||||
| pad_x = torch.zeros([B, pruned_lens_x, x.shape[2]], | |||||
| device=x.device) | |||||
| x = torch.cat([x, pad_x], dim=1) | |||||
| index_all = torch.cat([global_index_s, removed_indexes_cat], dim=1) | |||||
| # recover original token order | |||||
| C = x.shape[-1] | |||||
| x = torch.zeros_like(x).scatter_( | |||||
| dim=1, | |||||
| index=index_all.unsqueeze(-1).expand(B, -1, C).to(torch.int64), | |||||
| src=x) | |||||
| x = recover_tokens(x, mode=self.cat_mode) | |||||
| # re-concatenate with the template, which may be further used by other modules | |||||
| x = torch.cat([z, x], dim=1) | |||||
| aux_dict = { | |||||
| 'attn': attn, | |||||
| 'removed_indexes_s': removed_indexes_s, # used for visualization | |||||
| } | |||||
| return x, aux_dict | |||||
| def forward(self, z, x, ce_template_mask=None, ce_keep_rate=None): | |||||
| x, aux_dict = self.forward_features( | |||||
| z, | |||||
| x, | |||||
| ce_template_mask=ce_template_mask, | |||||
| ce_keep_rate=ce_keep_rate, | |||||
| ) | |||||
| return x, aux_dict | |||||
| def _create_vision_transformer(pretrained=False, **kwargs): | |||||
| model = VisionTransformerCE(**kwargs) | |||||
| return model | |||||
| def vit_base_patch16_224_ce(pretrained=False, **kwargs): | |||||
| """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). | |||||
| """ | |||||
| model_kwargs = dict( | |||||
| patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) | |||||
| model = _create_vision_transformer(pretrained=pretrained, **model_kwargs) | |||||
| return model | |||||
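A quick check of the CE backbone in isolation (random weights). Layers 3, 6 and 9 each drop 30% of the surviving search tokens, but forward_features scatters the pruned positions back as zero padding, so the returned sequence keeps the full 144 + 576 token layout; the per-stage removals are reported in aux['removed_indexes_s']. Expected counts below follow from repeatedly applying ceil(0.7 * n) to 576 tokens:

import torch
from modelscope.models.cv.video_single_object_tracking.config.ostrack import cfg

backbone = vit_base_patch16_224_ce(
    pretrained=False, ce_loc=[3, 6, 9], ce_keep_ratio=[0.7, 0.7, 0.7])
backbone.finetune_track(cfg)  # resize position embeddings for the 192/384 crops
z = torch.randn(1, 3, 192, 192)  # template
x = torch.randn(1, 3, 384, 384)  # search region
feat, aux = backbone(z, x)
print(feat.shape)                                      # torch.Size([1, 720, 768])
print([r.shape[1] for r in aux['removed_indexes_s']])  # [172, 121, 84]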
| @@ -0,0 +1,139 @@ | |||||
| # The implementation is also open-sourced by the authors as OSTrack, and is available publicly on | |||||
| # https://github.com/botaoye/OSTrack/ | |||||
| import torch | |||||
| from modelscope.models.cv.video_single_object_tracking.config.ostrack import \ | |||||
| cfg | |||||
| from modelscope.models.cv.video_single_object_tracking.models.ostrack.ostrack import \ | |||||
| build_ostrack | |||||
| from modelscope.models.cv.video_single_object_tracking.utils.utils import ( | |||||
| Preprocessor, clip_box, generate_mask_cond, hann2d, sample_target, | |||||
| transform_image_to_crop) | |||||
| class OSTrack(): | |||||
| def __init__(self, ckpt_path, device): | |||||
| network = build_ostrack(cfg) | |||||
| network.load_state_dict( | |||||
| torch.load(ckpt_path, map_location='cpu')['net'], strict=True) | |||||
| self.cfg = cfg | |||||
| if device.type == 'cuda': | |||||
| self.network = network.to(device) | |||||
| else: | |||||
| self.network = network | |||||
| self.network.eval() | |||||
| self.preprocessor = Preprocessor(device) | |||||
| self.state = None | |||||
| self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE | |||||
| # motion constrain | |||||
| if device.type == 'cuda': | |||||
| self.output_window = hann2d( | |||||
| torch.tensor([self.feat_sz, self.feat_sz]).long(), | |||||
| centered=True).to(device) | |||||
| else: | |||||
| self.output_window = hann2d( | |||||
| torch.tensor([self.feat_sz, self.feat_sz]).long(), | |||||
| centered=True) | |||||
| self.frame_id = 0 | |||||
| # for save boxes from all queries | |||||
| self.z_dict1 = {} | |||||
| def initialize(self, image, info: dict): | |||||
| # forward the template once | |||||
| z_patch_arr, resize_factor, z_amask_arr = sample_target( | |||||
| image, | |||||
| info['init_bbox'], | |||||
| self.cfg.TEST.TEMPLATE_FACTOR, | |||||
| output_sz=self.cfg.TEST.TEMPLATE_SIZE) | |||||
| self.z_patch_arr = z_patch_arr | |||||
| template = self.preprocessor.process(z_patch_arr, z_amask_arr) | |||||
| with torch.no_grad(): | |||||
| self.z_dict1 = template | |||||
| self.box_mask_z = None | |||||
| if self.cfg.MODEL.BACKBONE.CE_LOC: | |||||
| template_bbox = self.transform_bbox_to_crop( | |||||
| info['init_bbox'], resize_factor, | |||||
| template.tensors.device).squeeze(1) | |||||
| self.box_mask_z = generate_mask_cond(self.cfg, 1, | |||||
| template.tensors.device, | |||||
| template_bbox) | |||||
| # save states | |||||
| self.state = info['init_bbox'] | |||||
| self.frame_id = 0 | |||||
| def track(self, image, info: dict = None): | |||||
| H, W, _ = image.shape | |||||
| self.frame_id += 1 | |||||
| x_patch_arr, resize_factor, x_amask_arr = sample_target( | |||||
| image, | |||||
| self.state, | |||||
| self.cfg.TEST.SEARCH_FACTOR, | |||||
| output_sz=self.cfg.TEST.SEARCH_SIZE) # (x1, y1, w, h) | |||||
| search = self.preprocessor.process(x_patch_arr, x_amask_arr) | |||||
| with torch.no_grad(): | |||||
| x_dict = search | |||||
| # merge the template and the search | |||||
| # run the transformer | |||||
| out_dict = self.network.forward( | |||||
| template=self.z_dict1.tensors, | |||||
| search=x_dict.tensors, | |||||
| ce_template_mask=self.box_mask_z) | |||||
| # add hann windows | |||||
| pred_score_map = out_dict['score_map'] | |||||
| response = self.output_window * pred_score_map | |||||
| pred_boxes = self.network.box_head.cal_bbox(response, | |||||
| out_dict['size_map'], | |||||
| out_dict['offset_map']) | |||||
| pred_boxes = pred_boxes.view(-1, 4) | |||||
| # Baseline: Take the mean of all pred boxes as the final result | |||||
| pred_box = (pred_boxes.mean(dim=0) * self.cfg.TEST.SEARCH_SIZE | |||||
| / resize_factor).tolist() # (cx, cy, w, h) [0,1] | |||||
| # get the final box result | |||||
| self.state = clip_box( | |||||
| self.map_box_back(pred_box, resize_factor), H, W, margin=10) | |||||
| x1, y1, w, h = self.state | |||||
| x2 = x1 + w | |||||
| y2 = y1 + h | |||||
| return {'target_bbox': [x1, y1, x2, y2]} | |||||
| def map_box_back(self, pred_box: list, resize_factor: float): | |||||
| cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[ | |||||
| 1] + 0.5 * self.state[3] | |||||
| cx, cy, w, h = pred_box | |||||
| half_side = 0.5 * self.cfg.TEST.SEARCH_SIZE / resize_factor | |||||
| cx_real = cx + (cx_prev - half_side) | |||||
| cy_real = cy + (cy_prev - half_side) | |||||
| return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h] | |||||
| def transform_bbox_to_crop(self, | |||||
| box_in, | |||||
| resize_factor, | |||||
| device, | |||||
| box_extract=None, | |||||
| crop_type='template'): | |||||
| if crop_type == 'template': | |||||
| crop_sz = torch.Tensor( | |||||
| [self.cfg.TEST.TEMPLATE_SIZE, self.cfg.TEST.TEMPLATE_SIZE]) | |||||
| elif crop_type == 'search': | |||||
| crop_sz = torch.Tensor( | |||||
| [self.cfg.TEST.SEARCH_SIZE, self.cfg.TEST.SEARCH_SIZE]) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| box_in = torch.tensor(box_in) | |||||
| if box_extract is None: | |||||
| box_extract = box_in | |||||
| else: | |||||
| box_extract = torch.tensor(box_extract) | |||||
| template_bbox = transform_image_to_crop( | |||||
| box_in, box_extract, resize_factor, crop_sz, normalize=True) | |||||
| template_bbox = template_bbox.view(1, 1, 4).to(device) | |||||
| return template_bbox | |||||
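Hypothetical usage of this wrapper outside the pipeline (checkpoint and video paths are placeholders): initialize() takes the first frame plus an [x, y, w, h] box, and each subsequent track() call returns the box as [x1, y1, x2, y2]:

import cv2
import torch

tracker = OSTrack('/path/to/pytorch_model.bin', torch.device('cpu'))  # placeholder path
cap = cv2.VideoCapture('/path/to/video.avi')                          # placeholder path
ok, frame = cap.read()
tracker.initialize(frame, {'init_bbox': [414, 343, 100, 106]})  # x, y, w, h
while True:
    ok, frame = cap.read()
    if not ok:
        break
    print(tracker.track(frame)['target_bbox'])  # [x1, y1, x2, y2]
cap.release()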
| @@ -0,0 +1,261 @@ | |||||
| # The implementation is also open-sourced by the authors as OSTrack, and is available publicly on | |||||
| # https://github.com/botaoye/OSTrack/ | |||||
| import math | |||||
| from typing import Optional | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from torch import Tensor | |||||
| def hann1d(sz: int, centered=True) -> torch.Tensor: | |||||
| """1D cosine window.""" | |||||
| if centered: | |||||
| return 0.5 * (1 - torch.cos( | |||||
| (2 * math.pi / (sz + 1)) * torch.arange(1, sz + 1).float())) | |||||
| w = 0.5 * (1 + torch.cos( | |||||
| (2 * math.pi / (sz + 2)) * torch.arange(0, sz // 2 + 1).float())) | |||||
| return torch.cat([w, w[1:sz - sz // 2].flip((0, ))]) | |||||
| def hann2d(sz: torch.Tensor, centered=True) -> torch.Tensor: | |||||
| """2D cosine window.""" | |||||
| return hann1d(sz[0].item(), centered).reshape(1, 1, -1, 1) * hann1d( | |||||
| sz[1].item(), centered).reshape(1, 1, 1, -1) | |||||
| class NestedTensor(object): | |||||
| def __init__(self, tensors, mask: Optional[Tensor]): | |||||
| self.tensors = tensors | |||||
| self.mask = mask | |||||
| class Preprocessor(object): | |||||
| def __init__(self, device: str): | |||||
| self.device = device | |||||
| self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)) | |||||
| self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)) | |||||
| if 'cuda' == self.device.type: | |||||
| self.mean = self.mean.to(self.device) | |||||
| self.std = self.std.to(self.device) | |||||
| def process(self, img_arr: np.ndarray, amask_arr: np.ndarray): | |||||
| # Deal with the image patch | |||||
| if 'cuda' == self.device.type: | |||||
| img_tensor = torch.tensor(img_arr).to(self.device).float().permute( | |||||
| (2, 0, 1)).unsqueeze(dim=0) | |||||
| else: | |||||
| img_tensor = torch.tensor(img_arr).float().permute( | |||||
| (2, 0, 1)).unsqueeze(dim=0) | |||||
| img_tensor_norm = ( | |||||
| (img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W) | |||||
| # Deal with the attention mask | |||||
| if 'cuda' == self.device.type: | |||||
| amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).to( | |||||
| self.device).unsqueeze(dim=0) # (1,H,W) | |||||
| else: | |||||
| amask_tensor = torch.from_numpy(amask_arr).to( | |||||
| torch.bool).unsqueeze(dim=0) # (1,H,W) | |||||
| return NestedTensor(img_tensor_norm, amask_tensor) | |||||
| def clip_box(box: list, H, W, margin=0): | |||||
| x1, y1, w, h = box | |||||
| x2, y2 = x1 + w, y1 + h | |||||
| x1 = min(max(0, x1), W - margin) | |||||
| x2 = min(max(margin, x2), W) | |||||
| y1 = min(max(0, y1), H - margin) | |||||
| y2 = min(max(margin, y2), H) | |||||
| w = max(margin, x2 - x1) | |||||
| h = max(margin, y2 - y1) | |||||
| if isinstance(x1, torch.Tensor): | |||||
| x1 = x1.item() | |||||
| y1 = y1.item() | |||||
| w = w.item() | |||||
| h = h.item() | |||||
| return [x1, y1, w, h] | |||||
| def generate_mask_cond(cfg, bs, device, gt_bbox): | |||||
| template_size = cfg.DATA.TEMPLATE.SIZE | |||||
| stride = cfg.MODEL.BACKBONE.STRIDE | |||||
| template_feat_size = template_size // stride | |||||
| if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT': | |||||
| if template_feat_size == 8: | |||||
| index = slice(3, 4) | |||||
| elif template_feat_size == 12: | |||||
| index = slice(5, 6) | |||||
| elif template_feat_size == 7: | |||||
| index = slice(3, 4) | |||||
| elif template_feat_size == 14: | |||||
| index = slice(6, 7) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], | |||||
| device=device) | |||||
| box_mask_z[:, index, index] = 1 | |||||
| box_mask_z = box_mask_z.flatten(1).to(torch.bool) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| return box_mask_z | |||||
| def sample_target(im, | |||||
| target_bb, | |||||
| search_area_factor, | |||||
| output_sz=None, | |||||
| mask=None): | |||||
| """ Extracts a square crop centered at target_bb box, of area search_area_factor^2 times target_bb area | |||||
| args: | |||||
| im - cv image | |||||
| target_bb - target box [x, y, w, h] | |||||
| search_area_factor - Ratio of crop size to target size | |||||
| output_sz - (float) Size to which the extracted crop is resized (always square). If None, no resizing is done. | |||||
| returns: | |||||
| cv image - extracted crop | |||||
| float - the factor by which the crop has been resized to make the crop size equal output_size | |||||
| """ | |||||
| if not isinstance(target_bb, list): | |||||
| x, y, w, h = target_bb.tolist() | |||||
| else: | |||||
| x, y, w, h = target_bb | |||||
| # Crop image | |||||
| crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor) | |||||
| if crop_sz < 1: | |||||
| raise Exception('Too small bounding box.') | |||||
| x1 = round(x + 0.5 * w - crop_sz * 0.5) | |||||
| x2 = x1 + crop_sz | |||||
| y1 = round(y + 0.5 * h - crop_sz * 0.5) | |||||
| y2 = y1 + crop_sz | |||||
| x1_pad = max(0, -x1) | |||||
| x2_pad = max(x2 - im.shape[1] + 1, 0) | |||||
| y1_pad = max(0, -y1) | |||||
| y2_pad = max(y2 - im.shape[0] + 1, 0) | |||||
| # Crop target | |||||
| im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :] | |||||
| if mask is not None: | |||||
| mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad] | |||||
| # Pad | |||||
| im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, | |||||
| x2_pad, cv2.BORDER_CONSTANT) | |||||
| # deal with attention mask | |||||
| H, W, _ = im_crop_padded.shape | |||||
| att_mask = np.ones((H, W)) | |||||
| end_x, end_y = -x2_pad, -y2_pad | |||||
| if y2_pad == 0: | |||||
| end_y = None | |||||
| if x2_pad == 0: | |||||
| end_x = None | |||||
| att_mask[y1_pad:end_y, x1_pad:end_x] = 0 | |||||
| if mask is not None: | |||||
| mask_crop_padded = F.pad( | |||||
| mask_crop, | |||||
| pad=(x1_pad, x2_pad, y1_pad, y2_pad), | |||||
| mode='constant', | |||||
| value=0) | |||||
| if output_sz is not None: | |||||
| resize_factor = output_sz / crop_sz | |||||
| im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz)) | |||||
| att_mask = cv2.resize(att_mask, | |||||
| (output_sz, output_sz)).astype(np.bool_) | |||||
| if mask is None: | |||||
| return im_crop_padded, resize_factor, att_mask | |||||
| mask_crop_padded = \ | |||||
| F.interpolate(mask_crop_padded[None, None], (output_sz, output_sz), | |||||
| mode='bilinear', align_corners=False)[0, 0] | |||||
| return im_crop_padded, resize_factor, att_mask, mask_crop_padded | |||||
| else: | |||||
| if mask is None: | |||||
| return im_crop_padded, att_mask.astype(np.bool_), 1.0 | |||||
| return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded | |||||
| def transform_image_to_crop(box_in: torch.Tensor, | |||||
| box_extract: torch.Tensor, | |||||
| resize_factor: float, | |||||
| crop_sz: torch.Tensor, | |||||
| normalize=False) -> torch.Tensor: | |||||
| """ Transform the box co-ordinates from the original image co-ordinates to the co-ordinates of the cropped image | |||||
| args: | |||||
| box_in - the box for which the co-ordinates are to be transformed | |||||
| box_extract - the box about which the image crop has been extracted. | |||||
| resize_factor - the ratio between the original image scale and the scale of the image crop | |||||
| crop_sz - size of the cropped image | |||||
| returns: | |||||
| torch.Tensor - transformed co-ordinates of box_in | |||||
| """ | |||||
| box_extract_center = box_extract[0:2] + 0.5 * box_extract[2:4] | |||||
| box_in_center = box_in[0:2] + 0.5 * box_in[2:4] | |||||
| box_out_center = (crop_sz - 1) / 2 + (box_in_center | |||||
| - box_extract_center) * resize_factor | |||||
| box_out_wh = box_in[2:4] * resize_factor | |||||
| box_out = torch.cat((box_out_center - 0.5 * box_out_wh, box_out_wh)) | |||||
| if normalize: | |||||
| return box_out / crop_sz[0] | |||||
| else: | |||||
| return box_out | |||||
| def check_box(box: list, image_height, image_width) -> bool: | |||||
| """ To check whether the box is within the image range or not | |||||
| args: | |||||
| box - the bounding box in the form of [x1, y1, x2, y2] | |||||
| image_height - the height of the image | |||||
| image_width - the width of the image | |||||
| returns: | |||||
| bool - if box is valid, return True. Otherwise, return False | |||||
| """ | |||||
| assert len(box) == 4, 'box must be in the form of: [x1, y1, x2, y2]' | |||||
| if box[0] < 0 or box[0] >= image_width: | |||||
| return False | |||||
| if box[2] < 0 or box[2] >= image_width: | |||||
| return False | |||||
| if box[1] < 0 or box[1] >= image_height: | |||||
| return False | |||||
| if box[3] < 0 or box[3] >= image_height: | |||||
| return False | |||||
| return True | |||||
| def show_tracking_result(video_in_path, bboxes, video_save_path): | |||||
| cap = cv2.VideoCapture(video_in_path) | |||||
| for i in range(len(bboxes)): | |||||
| box = bboxes[i] | |||||
| success, frame = cap.read() | |||||
| if success is False: | |||||
| raise Exception(video_in_path, | |||||
| ' can not be correctly decoded by OpenCV.') | |||||
| if i == 0: | |||||
| size = (frame.shape[1], frame.shape[0]) | |||||
| fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') | |||||
| video_writer = cv2.VideoWriter(video_save_path, fourcc, | |||||
| cap.get(cv2.CAP_PROP_FPS), size, | |||||
| True) | |||||
| cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), | |||||
| 5) | |||||
| video_writer.write(frame) | |||||
    video_writer.release()
| cap.release() | |||||
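A small, self-contained check of the crop helpers used by the tracker (dummy frame, an [x, y, w, h] box matching the test clip's init box): sample_target crops a square region 5x the target area and resizes it to 384, and hann2d builds the window that is multiplied onto the 24x24 score map:

import numpy as np
import torch

im = np.zeros((720, 1280, 3), dtype=np.uint8)  # dummy frame
crop, resize_factor, att_mask = sample_target(
    im, [414, 343, 100, 106], search_area_factor=5.0, output_sz=384)
print(crop.shape, att_mask.shape)  # (384, 384, 3) (384, 384)

window = hann2d(torch.tensor([24, 24]).long(), centered=True)
print(window.shape)  # torch.Size([1, 1, 24, 24])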
@@ -188,6 +188,16 @@ TASK_OUTPUTS = {
    Tasks.body_2d_keypoints:
    [OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES],

    # video single object tracking result for single video
    # {
    #   "boxes": [
    #       [x1, y1, x2, y2],
    #       [x1, y1, x2, y2],
    #       [x1, y1, x2, y2],
    #   ]
    # }
    Tasks.video_single_object_tracking: [OutputKeys.BOXES],

    # live category recognition result for single video
    # {
    #   "scores": [0.885272, 0.014790631, 0.014558001],

@@ -130,6 +130,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                           'damo/cv_unet_skin-retouching'),
    Tasks.crowd_counting: (Pipelines.crowd_counting,
                           'damo/cv_hrnet_crowd-counting_dcanet'),
    Tasks.video_single_object_tracking:
    (Pipelines.video_single_object_tracking,
     'damo/cv_vitb_video-single-object-tracking_ostrack'),
}
| @@ -0,0 +1,80 @@ | |||||
| import os.path as osp | |||||
| from typing import Any, Dict | |||||
| import cv2 | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.cv.video_single_object_tracking.config.ostrack import \ | |||||
| cfg | |||||
| from modelscope.models.cv.video_single_object_tracking.tracker.ostrack import \ | |||||
| OSTrack | |||||
| from modelscope.models.cv.video_single_object_tracking.utils.utils import \ | |||||
| check_box | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.video_single_object_tracking, | |||||
| module_name=Pipelines.video_single_object_tracking) | |||||
| class VideoSingleObjectTrackingPipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| use `model` to create a single object tracking pipeline | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model, **kwargs) | |||||
| self.cfg = cfg | |||||
| ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE) | |||||
| logger.info(f'loading model from {ckpt_path}') | |||||
| self.tracker = OSTrack(ckpt_path, self.device) | |||||
| logger.info('init tracker done') | |||||
| def preprocess(self, input) -> Input: | |||||
| self.video_path = input[0] | |||||
| self.init_bbox = input[1] | |||||
| return input | |||||
| def forward(self, input: Input) -> Dict[str, Any]: | |||||
| output_boxes = [] | |||||
| cap = cv2.VideoCapture(self.video_path) | |||||
| success, frame = cap.read() | |||||
| if success is False: | |||||
| raise Exception( | |||||
| 'modelscope error: %s can not be decoded by OpenCV.' % | |||||
| (self.video_path)) | |||||
| init_box = self.init_bbox | |||||
| frame_h, frame_w = frame.shape[0:2] | |||||
| if not check_box(init_box, frame_h, frame_w): | |||||
| raise Exception('modelscope error: init_box out of image range ', | |||||
| init_box) | |||||
| output_boxes.append(init_box.copy()) | |||||
| init_box[2] = init_box[2] - init_box[0] | |||||
| init_box[3] = init_box[3] - init_box[1] | |||||
| self.tracker.initialize(frame, {'init_bbox': init_box}) | |||||
| logger.info('init bbox done') | |||||
| while True: | |||||
| ret, frame = cap.read() | |||||
| if frame is None: | |||||
| break | |||||
| out = self.tracker.track(frame) | |||||
| state = [int(s) for s in out['target_bbox']] | |||||
| output_boxes.append(state) | |||||
| cap.release() | |||||
| logger.info('tracking process done') | |||||
| return { | |||||
| OutputKeys.BOXES: output_boxes, | |||||
| } | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
@@ -62,6 +62,9 @@ class CVTasks(object):
    virtual_try_on = 'virtual-try-on'
    crowd_counting = 'crowd-counting'

    # video related
    video_single_object_tracking = 'video-single-object-tracking'


class NLPTasks(object):
    # nlp tasks
@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.models.cv.video_single_object_tracking.utils.utils import \
    show_tracking_result
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class SingleObjectTracking(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_vitb_video-single-object-tracking_ostrack'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_end2end(self):
        video_single_object_tracking = pipeline(
            Tasks.video_single_object_tracking, model=self.model_id)
        video_path = 'data/test/videos/dog.avi'
        init_bbox = [414, 343, 514, 449]  # [x1, y1, x2, y2]
        result = video_single_object_tracking((video_path, init_bbox))
        print('result is : ', result[OutputKeys.BOXES])
        show_tracking_result(video_path, result[OutputKeys.BOXES],
                             './tracking_result.avi')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        video_single_object_tracking = pipeline(
            Tasks.video_single_object_tracking)
        video_path = 'data/test/videos/dog.avi'
        init_bbox = [414, 343, 514, 449]  # [x1, y1, x2, y2]
        result = video_single_object_tracking((video_path, init_bbox))
        print('result is : ', result[OutputKeys.BOXES])


if __name__ == '__main__':
    unittest.main()