Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10400324master
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:a49c9bc74a60860c360a4bf4509fe9db915279aaabd953f354f2c38e9be1e6cb | |||
size 2924691 |
@@ -34,6 +34,7 @@ class Models(object): | |||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
text_driven_segmentation = 'text-driven-segmentation' | |||
resnet50_bert = 'resnet50-bert' | |||
referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' | |||
fer = 'fer' | |||
retinaface = 'retinaface' | |||
shop_segmentation = 'shop-segmentation' | |||
@@ -203,6 +204,7 @@ class Pipelines(object): | |||
face_emotion = 'face-emotion' | |||
product_segmentation = 'product-segmentation' | |||
image_body_reshaping = 'flow-based-body-reshaping' | |||
referring_video_object_segmentation = 'referring-video-object-segmentation' | |||
# nlp tasks | |||
automatic_post_editing = 'automatic-post-editing' | |||
@@ -12,7 +12,8 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
image_to_image_generation, image_to_image_translation, | |||
movie_scene_segmentation, object_detection, | |||
product_retrieval_embedding, realtime_object_detection, | |||
salient_detection, shop_segmentation, super_resolution, | |||
referring_video_object_segmentation, salient_detection, | |||
shop_segmentation, super_resolution, | |||
video_single_object_tracking, video_summarization, virual_tryon) | |||
# yapf: enable |
@@ -0,0 +1,23 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .model import ReferringVideoObjectSegmentation | |||
else: | |||
_import_structure = { | |||
'model': ['ReferringVideoObjectSegmentation'], | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
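For context, a minimal usage sketch of the lazy-import pattern above. It assumes this file lands at modelscope/models/cv/referring_video_object_segmentation/__init__.py (as the cv-level import hunk suggests) and that LazyImportModule follows the usual lazy-module convention of importing the submodule on first attribute access:

```python
# Hedged sketch: importing the sub-package is cheap; model.py (and its torch /
# MTTR dependencies) is only imported when the class attribute is first accessed.
from modelscope.models.cv import referring_video_object_segmentation as rvos

model_cls = rvos.ReferringVideoObjectSegmentation  # triggers the actual import
print(model_cls.__name__)  # 'ReferringVideoObjectSegmentation'
```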
@@ -0,0 +1,65 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
from typing import Any, Dict | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base.base_torch_model import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .utils import (MTTR, A2DSentencesPostProcess, ReferYoutubeVOSPostProcess, | |||
nested_tensor_from_videos_list) | |||
logger = get_logger() | |||
@MODELS.register_module( | |||
Tasks.referring_video_object_segmentation, | |||
module_name=Models.referring_video_object_segmentation) | |||
class ReferringVideoObjectSegmentation(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""str -- model file root.""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
config_path = osp.join(model_dir, ModelFile.CONFIGURATION) | |||
self.cfg = Config.from_file(config_path) | |||
self.model = MTTR(**self.cfg.model) | |||
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
params_dict = torch.load(model_path, map_location='cpu') | |||
if 'model_state_dict' in params_dict.keys(): | |||
params_dict = params_dict['model_state_dict'] | |||
self.model.load_state_dict(params_dict, strict=True) | |||
dataset_name = self.cfg.pipeline.dataset_name | |||
if dataset_name in ('a2d_sentences', 'jhmdb_sentences'): | |||
self.postprocessor = A2DSentencesPostProcess() | |||
elif dataset_name == 'ref_youtube_vos': | |||
self.postprocessor = ReferYoutubeVOSPostProcess() | |||
else: | |||
raise ValueError(f'postprocessing for dataset: {dataset_name} is not supported') | |||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: | |||
return inputs | |||
def inference(self, **kwargs): | |||
window = kwargs['window'] | |||
text_query = kwargs['text_query'] | |||
video_metadata = kwargs['metadata'] | |||
window = nested_tensor_from_videos_list([window]) | |||
valid_indices = torch.arange(len(window.tensors)) | |||
if self._device_name == 'gpu': | |||
valid_indices = valid_indices.cuda() | |||
outputs = self.model(window, valid_indices, [text_query]) | |||
window_masks = self.postprocessor( | |||
outputs, [video_metadata], | |||
window.tensors.shape[-2:])[0]['pred_masks'] | |||
return window_masks | |||
def postprocess(self, inputs: Dict[str, Any], **kwargs): | |||
return inputs |
@@ -0,0 +1,4 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .misc import nested_tensor_from_videos_list | |||
from .mttr import MTTR | |||
from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess |
@@ -0,0 +1,198 @@ | |||
# The implementation is adopted from MTTR, | |||
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||
import torch | |||
import torch.nn.functional as F | |||
import torchvision | |||
from einops import rearrange | |||
from torch import nn | |||
from torchvision.models._utils import IntermediateLayerGetter | |||
from .misc import NestedTensor, is_main_process | |||
from .swin_transformer import SwinTransformer3D | |||
class VideoSwinTransformerBackbone(nn.Module): | |||
""" | |||
A wrapper which allows using Video-Swin Transformer as a temporal encoder for MTTR. | |||
Check out video-swin's original paper at: https://arxiv.org/abs/2106.13230 for more info about this architecture. | |||
Only the 'tiny' version of video swin was tested and is currently supported in our project. | |||
Additionally, we slightly modify video-swin to make it output per-frame embeddings as required by MTTR (check our | |||
paper's supplementary for more details), and completely discard its 4th block. | |||
""" | |||
def __init__(self, backbone_pretrained, backbone_pretrained_path, | |||
train_backbone, running_mode, **kwargs): | |||
super(VideoSwinTransformerBackbone, self).__init__() | |||
# patch_size is (1, 4, 4) instead of the original (2, 4, 4). | |||
# this prevents swinT's original temporal downsampling so we can get per-frame features. | |||
swin_backbone = SwinTransformer3D( | |||
patch_size=(1, 4, 4), | |||
embed_dim=96, | |||
depths=(2, 2, 6, 2), | |||
num_heads=(3, 6, 12, 24), | |||
window_size=(8, 7, 7), | |||
drop_path_rate=0.1, | |||
patch_norm=True) | |||
if backbone_pretrained and running_mode == 'train': | |||
state_dict = torch.load(backbone_pretrained_path)['state_dict'] | |||
# extract swinT's kinetics-400 pretrained weights and ignore the rest (prediction head etc.) | |||
state_dict = { | |||
k[9:]: v | |||
for k, v in state_dict.items() if 'backbone.' in k | |||
} | |||
# sum over the patch embedding weight temporal dim [96, 3, 2, 4, 4] --> [96, 3, 1, 4, 4] | |||
patch_embed_weight = state_dict['patch_embed.proj.weight'] | |||
patch_embed_weight = patch_embed_weight.sum(dim=2, keepdims=True) | |||
state_dict['patch_embed.proj.weight'] = patch_embed_weight | |||
swin_backbone.load_state_dict(state_dict) | |||
self.patch_embed = swin_backbone.patch_embed | |||
self.pos_drop = swin_backbone.pos_drop | |||
self.layers = swin_backbone.layers[:-1] | |||
self.downsamples = nn.ModuleList() | |||
for layer in self.layers: | |||
self.downsamples.append(layer.downsample) | |||
layer.downsample = None | |||
self.downsamples[ | |||
-1] = None # downsampling after the last layer is not necessary | |||
self.layer_output_channels = [ | |||
swin_backbone.embed_dim * 2**i for i in range(len(self.layers)) | |||
] | |||
self.train_backbone = train_backbone | |||
if not train_backbone: | |||
for parameter in self.parameters(): | |||
parameter.requires_grad_(False) | |||
def forward(self, samples: NestedTensor): | |||
vid_frames = rearrange(samples.tensors, 't b c h w -> b c t h w') | |||
vid_embeds = self.patch_embed(vid_frames) | |||
vid_embeds = self.pos_drop(vid_embeds) | |||
layer_outputs = [] # layer outputs before downsampling | |||
for layer, downsample in zip(self.layers, self.downsamples): | |||
vid_embeds = layer(vid_embeds.contiguous()) | |||
layer_outputs.append(vid_embeds) | |||
if downsample: | |||
vid_embeds = rearrange(vid_embeds, 'b c t h w -> b t h w c') | |||
vid_embeds = downsample(vid_embeds) | |||
vid_embeds = rearrange(vid_embeds, 'b t h w c -> b c t h w') | |||
layer_outputs = [ | |||
rearrange(o, 'b c t h w -> t b c h w') for o in layer_outputs | |||
] | |||
outputs = [] | |||
orig_pad_mask = samples.mask | |||
for l_out in layer_outputs: | |||
pad_mask = F.interpolate( | |||
orig_pad_mask.float(), size=l_out.shape[-2:]).to(torch.bool) | |||
outputs.append(NestedTensor(l_out, pad_mask)) | |||
return outputs | |||
def num_parameters(self): | |||
return sum(p.numel() for p in self.parameters() if p.requires_grad) | |||
class FrozenBatchNorm2d(torch.nn.Module): | |||
""" | |||
Modified from DETR https://github.com/facebookresearch/detr | |||
BatchNorm2d where the batch statistics and the affine parameters are fixed. | |||
Copy-paste from torchvision.misc.ops with added eps before rsqrt, | |||
without which models other than torchvision.models.resnet[18,34,50,101] | |||
produce nans. | |||
""" | |||
def __init__(self, n): | |||
super(FrozenBatchNorm2d, self).__init__() | |||
self.register_buffer('weight', torch.ones(n)) | |||
self.register_buffer('bias', torch.zeros(n)) | |||
self.register_buffer('running_mean', torch.zeros(n)) | |||
self.register_buffer('running_var', torch.ones(n)) | |||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | |||
missing_keys, unexpected_keys, error_msgs): | |||
num_batches_tracked_key = prefix + 'num_batches_tracked' | |||
if num_batches_tracked_key in state_dict: | |||
del state_dict[num_batches_tracked_key] | |||
super(FrozenBatchNorm2d, | |||
self)._load_from_state_dict(state_dict, prefix, local_metadata, | |||
strict, missing_keys, | |||
unexpected_keys, error_msgs) | |||
def forward(self, x): | |||
# move reshapes to the beginning | |||
# to make it fuser-friendly | |||
w = self.weight.reshape(1, -1, 1, 1) | |||
b = self.bias.reshape(1, -1, 1, 1) | |||
rv = self.running_var.reshape(1, -1, 1, 1) | |||
rm = self.running_mean.reshape(1, -1, 1, 1) | |||
eps = 1e-5 | |||
scale = w * (rv + eps).rsqrt() | |||
bias = b - rm * scale | |||
return x * scale + bias | |||
class ResNetBackbone(nn.Module): | |||
""" | |||
Modified from DETR https://github.com/facebookresearch/detr | |||
ResNet backbone with frozen BatchNorm. | |||
""" | |||
def __init__(self, | |||
backbone_name: str = 'resnet50', | |||
train_backbone: bool = True, | |||
dilation: bool = True, | |||
**kwargs): | |||
super(ResNetBackbone, self).__init__() | |||
backbone = getattr(torchvision.models, backbone_name)( | |||
replace_stride_with_dilation=[False, False, dilation], | |||
pretrained=is_main_process(), | |||
norm_layer=FrozenBatchNorm2d) | |||
for name, parameter in backbone.named_parameters(): | |||
if not train_backbone or ('layer2' not in name and 'layer3' not in name and 'layer4' not in name): | |||
parameter.requires_grad_(False) | |||
return_layers = { | |||
'layer1': '0', | |||
'layer2': '1', | |||
'layer3': '2', | |||
'layer4': '3' | |||
} | |||
self.body = IntermediateLayerGetter( | |||
backbone, return_layers=return_layers) | |||
output_channels = 512 if backbone_name in ('resnet18', | |||
'resnet34') else 2048 | |||
self.layer_output_channels = [ | |||
output_channels // 8, output_channels // 4, output_channels // 2, | |||
output_channels | |||
] | |||
def forward(self, tensor_list: NestedTensor): | |||
t, b, _, _, _ = tensor_list.tensors.shape | |||
video_frames = rearrange(tensor_list.tensors, | |||
't b c h w -> (t b) c h w') | |||
padding_masks = rearrange(tensor_list.mask, 't b h w -> (t b) h w') | |||
features_list = self.body(video_frames) | |||
out = [] | |||
for _, f in features_list.items(): | |||
resized_padding_masks = F.interpolate( | |||
padding_masks[None].float(), | |||
size=f.shape[-2:]).to(torch.bool)[0] | |||
f = rearrange(f, '(t b) c h w -> t b c h w', t=t, b=b) | |||
resized_padding_masks = rearrange( | |||
resized_padding_masks, '(t b) h w -> t b h w', t=t, b=b) | |||
out.append(NestedTensor(f, resized_padding_masks)) | |||
return out | |||
def num_parameters(self): | |||
return sum(p.numel() for p in self.parameters() if p.requires_grad) | |||
def init_backbone(backbone_name, **kwargs): | |||
if backbone_name == 'swin-t': | |||
return VideoSwinTransformerBackbone(**kwargs) | |||
elif 'resnet' in backbone_name: | |||
return ResNetBackbone(backbone_name, **kwargs) | |||
raise ValueError(f'error: backbone "{backbone_name}" is not supported') |
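A small worked check of the patch-embedding weight collapse performed above when loading the Kinetics-400 checkpoint (pure torch, illustrative sizes only): the pretrained kernel spans two frames because the original patch_size was (2, 4, 4), and summing over the temporal dim adapts it to the (1, 4, 4) patches used here.

```python
import torch

pretrained_weight = torch.rand(96, 3, 2, 4, 4)          # [embed_dim, C, T, kH, kW]
collapsed = pretrained_weight.sum(dim=2, keepdim=True)  # temporal dim 2 -> 1
print(collapsed.shape)  # torch.Size([96, 3, 1, 4, 4])
```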
@@ -0,0 +1,234 @@ | |||
# Modified from DETR https://github.com/facebookresearch/detr | |||
# Misc functions. | |||
# Mostly copy-paste from torchvision references. | |||
import pickle | |||
from typing import List, Optional | |||
import torch | |||
import torch.distributed as dist | |||
# needed due to empty tensor bug in pytorch and torchvision 0.5 | |||
import torchvision | |||
from torch import Tensor | |||
if float(torchvision.__version__.split('.')[1]) < 7.0: | |||
from torchvision.ops import _new_empty_tensor | |||
from torchvision.ops.misc import _output_size | |||
def all_gather(data): | |||
""" | |||
Run all_gather on arbitrary picklable data (not necessarily tensors) | |||
Args: | |||
data: any picklable object | |||
Returns: | |||
list[data]: list of data gathered from each rank | |||
""" | |||
world_size = get_world_size() | |||
if world_size == 1: | |||
return [data] | |||
# serialized to a Tensor | |||
buffer = pickle.dumps(data) | |||
storage = torch.ByteStorage.from_buffer(buffer) | |||
tensor = torch.ByteTensor(storage).to('cuda') | |||
# obtain Tensor size of each rank | |||
local_size = torch.tensor([tensor.numel()], device='cuda') | |||
size_list = [torch.tensor([0], device='cuda') for _ in range(world_size)] | |||
dist.all_gather(size_list, local_size) | |||
size_list = [int(size.item()) for size in size_list] | |||
max_size = max(size_list) | |||
# receiving Tensor from all ranks | |||
# we pad the tensor because torch all_gather does not support | |||
# gathering tensors of different shapes | |||
tensor_list = [] | |||
for _ in size_list: | |||
tensor_list.append( | |||
torch.empty((max_size, ), dtype=torch.uint8, device='cuda')) | |||
if local_size != max_size: | |||
padding = torch.empty( | |||
size=(max_size - local_size, ), dtype=torch.uint8, device='cuda') | |||
tensor = torch.cat((tensor, padding), dim=0) | |||
dist.all_gather(tensor_list, tensor) | |||
data_list = [] | |||
for size, tensor in zip(size_list, tensor_list): | |||
buffer = tensor.cpu().numpy().tobytes()[:size] | |||
data_list.append(pickle.loads(buffer)) | |||
return data_list | |||
def reduce_dict(input_dict, average=True): | |||
""" | |||
Args: | |||
input_dict (dict): all the values will be reduced | |||
average (bool): whether to do average or sum | |||
Reduce the values in the dictionary from all processes so that all processes | |||
have the averaged results. Returns a dict with the same fields as | |||
input_dict, after reduction. | |||
""" | |||
world_size = get_world_size() | |||
if world_size < 2: | |||
return input_dict | |||
with torch.no_grad(): | |||
names = [] | |||
values = [] | |||
# sort the keys so that they are consistent across processes | |||
for k in sorted(input_dict.keys()): | |||
names.append(k) | |||
values.append(input_dict[k]) | |||
values = torch.stack(values, dim=0) | |||
dist.all_reduce(values) | |||
if average: | |||
values /= world_size | |||
reduced_dict = {k: v for k, v in zip(names, values)} | |||
return reduced_dict | |||
def _max_by_axis(the_list): | |||
# type: (List[List[int]]) -> List[int] | |||
maxes = the_list[0] | |||
for sublist in the_list[1:]: | |||
for index, item in enumerate(sublist): | |||
maxes[index] = max(maxes[index], item) | |||
return maxes | |||
class NestedTensor(object): | |||
def __init__(self, tensors, mask: Optional[Tensor]): | |||
self.tensors = tensors | |||
self.mask = mask | |||
def to(self, device): | |||
# type: (Device) -> NestedTensor # noqa | |||
cast_tensor = self.tensors.to(device) | |||
mask = self.mask | |||
if mask is not None: | |||
assert mask is not None | |||
cast_mask = mask.to(device) | |||
else: | |||
cast_mask = None | |||
return NestedTensor(cast_tensor, cast_mask) | |||
def decompose(self): | |||
return self.tensors, self.mask | |||
def __repr__(self): | |||
return str(self.tensors) | |||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): | |||
""" | |||
This function receives a list of image tensors and returns a NestedTensor of the padded images, along with their | |||
padding masks (true for padding areas, false otherwise). | |||
""" | |||
max_size = _max_by_axis([list(img.shape) for img in tensor_list]) | |||
batch_shape = [len(tensor_list)] + max_size | |||
b, c, h, w = batch_shape | |||
dtype = tensor_list[0].dtype | |||
device = tensor_list[0].device | |||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device) | |||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device) | |||
for img, pad_img, m in zip(tensor_list, tensor, mask): | |||
pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img) | |||
m[:img.shape[1], :img.shape[2]] = False | |||
return NestedTensor(tensor, mask) | |||
def nested_tensor_from_videos_list(videos_list: List[Tensor]): | |||
""" | |||
This function receives a list of videos (each of shape [T, C, H, W]) and returns a NestedTensor of the padded | |||
videos (shape [T, B, C, PH, PW]), along with their padding masks (true for padding areas, false otherwise, of shape | |||
[T, B, PH, PW]). | |||
""" | |||
max_size = _max_by_axis([list(img.shape) for img in videos_list]) | |||
padded_batch_shape = [len(videos_list)] + max_size | |||
b, t, c, h, w = padded_batch_shape | |||
dtype = videos_list[0].dtype | |||
device = videos_list[0].device | |||
padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device) | |||
videos_pad_masks = torch.ones((b, t, h, w), | |||
dtype=torch.bool, | |||
device=device) | |||
for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, | |||
padded_videos, | |||
videos_pad_masks): | |||
pad_vid_frames[:vid_frames.shape[0], :, :vid_frames. | |||
shape[2], :vid_frames.shape[3]].copy_(vid_frames) | |||
vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames. | |||
shape[3]] = False | |||
# transpose the temporal and batch dims and create a NestedTensor: | |||
return NestedTensor( | |||
padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1)) | |||
def setup_for_distributed(is_master): | |||
""" | |||
This function disables printing when not in master process | |||
""" | |||
import builtins as __builtin__ | |||
builtin_print = __builtin__.print | |||
def print(*args, **kwargs): | |||
force = kwargs.pop('force', False) | |||
if is_master or force: | |||
builtin_print(*args, **kwargs) | |||
__builtin__.print = print | |||
def is_dist_avail_and_initialized(): | |||
if not dist.is_available(): | |||
return False | |||
if not dist.is_initialized(): | |||
return False | |||
return True | |||
def get_world_size(): | |||
if not is_dist_avail_and_initialized(): | |||
return 1 | |||
return dist.get_world_size() | |||
def get_rank(): | |||
if not is_dist_avail_and_initialized(): | |||
return 0 | |||
return dist.get_rank() | |||
def is_main_process(): | |||
return get_rank() == 0 | |||
def save_on_master(*args, **kwargs): | |||
if is_main_process(): | |||
torch.save(*args, **kwargs) | |||
def interpolate(input, | |||
size=None, | |||
scale_factor=None, | |||
mode='nearest', | |||
align_corners=None): | |||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor | |||
""" | |||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes. | |||
This will eventually be supported natively by PyTorch, and this | |||
function can go away. | |||
""" | |||
if float(torchvision.__version__.split('.')[1]) < 7.0: | |||
if input.numel() > 0: | |||
return torch.nn.functional.interpolate(input, size, scale_factor, | |||
mode, align_corners) | |||
output_shape = _output_size(2, input, size, scale_factor) | |||
output_shape = list(input.shape[:-2]) + list(output_shape) | |||
return _new_empty_tensor(input, output_shape) | |||
else: | |||
return torchvision.ops.misc.interpolate(input, size, scale_factor, | |||
mode, align_corners) |
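To illustrate what nested_tensor_from_videos_list produces, a short sketch with two clips of different spatial sizes (the import path is an assumption based on the package __init__ hunks above):

```python
import torch
from modelscope.models.cv.referring_video_object_segmentation.utils import \
    nested_tensor_from_videos_list

v1 = torch.rand(6, 3, 200, 320)   # [T, C, H, W]
v2 = torch.rand(6, 3, 180, 352)   # same T, different spatial size
nt = nested_tensor_from_videos_list([v1, v2])
print(nt.tensors.shape)              # torch.Size([6, 2, 3, 200, 352]) = [T, B, C, PH, PW]
print(nt.mask.shape)                 # torch.Size([6, 2, 200, 352])    = [T, B, PH, PW]
print(nt.mask[0, 0, :, 320:].all())  # True: the right-hand columns of v1 are padding
```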
@@ -0,0 +1,128 @@ | |||
# The implementation is adopted from MTTR, | |||
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||
import torch | |||
import torch.nn.functional as F | |||
from einops import rearrange | |||
from torch import nn | |||
from .backbone import init_backbone | |||
from .misc import NestedTensor | |||
from .multimodal_transformer import MultimodalTransformer | |||
from .segmentation import FPNSpatialDecoder | |||
class MTTR(nn.Module): | |||
""" The main module of the Multimodal Tracking Transformer """ | |||
def __init__(self, | |||
num_queries, | |||
mask_kernels_dim=8, | |||
aux_loss=False, | |||
**kwargs): | |||
""" | |||
Parameters: | |||
num_queries: number of object queries, i.e., detection slots. This is the maximal number of objects | |||
MTTR can detect in a single image. In our paper we use 50 in all settings. | |||
mask_kernels_dim: dim of the segmentation kernels and of the feature maps output by the spatial decoder. | |||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. | |||
""" | |||
super().__init__() | |||
self.backbone = init_backbone(**kwargs) | |||
self.transformer = MultimodalTransformer(**kwargs) | |||
d_model = self.transformer.d_model | |||
self.is_referred_head = nn.Linear( | |||
d_model, | |||
2) # binary 'is referred?' prediction head for object queries | |||
self.instance_kernels_head = MLP( | |||
d_model, d_model, output_dim=mask_kernels_dim, num_layers=2) | |||
self.obj_queries = nn.Embedding( | |||
num_queries, d_model) # pos embeddings for the object queries | |||
self.vid_embed_proj = nn.Conv2d( | |||
self.backbone.layer_output_channels[-1], d_model, kernel_size=1) | |||
self.spatial_decoder = FPNSpatialDecoder( | |||
d_model, self.backbone.layer_output_channels[:-1][::-1], | |||
mask_kernels_dim) | |||
self.aux_loss = aux_loss | |||
def forward(self, samples: NestedTensor, valid_indices, text_queries): | |||
"""The forward expects a NestedTensor, which consists of: | |||
- samples.tensor: Batched frames of shape [time x batch_size x 3 x H x W] | |||
- samples.mask: A binary mask of shape [time x batch_size x H x W], containing 1 on padded pixels | |||
It returns a dict with the following elements: | |||
- "pred_is_referred": The reference prediction logits for all queries. | |||
Shape: [time x batch_size x num_queries x 2] | |||
- "pred_masks": The mask logits for all queries. | |||
Shape: [time x batch_size x num_queries x H_mask x W_mask] | |||
- "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of | |||
dictionaries containing the two above keys for each decoder layer. | |||
""" | |||
backbone_out = self.backbone(samples) | |||
# keep only the valid frames (frames which are annotated): | |||
# (for example, in a2d-sentences only the center frame in each window is annotated). | |||
for layer_out in backbone_out: | |||
layer_out.tensors = layer_out.tensors.index_select( | |||
0, valid_indices) | |||
layer_out.mask = layer_out.mask.index_select(0, valid_indices) | |||
bbone_final_layer_output = backbone_out[-1] | |||
vid_embeds, vid_pad_mask = bbone_final_layer_output.decompose() | |||
T, B, _, _, _ = vid_embeds.shape | |||
vid_embeds = rearrange(vid_embeds, 't b c h w -> (t b) c h w') | |||
vid_embeds = self.vid_embed_proj(vid_embeds) | |||
vid_embeds = rearrange( | |||
vid_embeds, '(t b) c h w -> t b c h w', t=T, b=B) | |||
transformer_out = self.transformer(vid_embeds, vid_pad_mask, | |||
text_queries, | |||
self.obj_queries.weight) | |||
# hs is: [L, T, B, N, D] where L is number of decoder layers | |||
# vid_memory is: [T, B, D, H, W] | |||
# txt_memory is a list of length T*B of [S, C] where S might be different for each sentence | |||
# encoder_middle_layer_outputs is a list of [T, B, H, W, D] | |||
hs, vid_memory, txt_memory = transformer_out | |||
vid_memory = rearrange(vid_memory, 't b d h w -> (t b) d h w') | |||
bbone_middle_layer_outputs = [ | |||
rearrange(o.tensors, 't b d h w -> (t b) d h w') | |||
for o in backbone_out[:-1][::-1] | |||
] | |||
decoded_frame_features = self.spatial_decoder( | |||
vid_memory, bbone_middle_layer_outputs) | |||
decoded_frame_features = rearrange( | |||
decoded_frame_features, '(t b) d h w -> t b d h w', t=T, b=B) | |||
instance_kernels = self.instance_kernels_head(hs) # [L, T, B, N, C] | |||
# output masks is: [L, T, B, N, H_mask, W_mask] | |||
output_masks = torch.einsum('ltbnc,tbchw->ltbnhw', instance_kernels, | |||
decoded_frame_features) | |||
outputs_is_referred = self.is_referred_head(hs) # [L, T, B, N, 2] | |||
layer_outputs = [] | |||
for pm, pir in zip(output_masks, outputs_is_referred): | |||
layer_out = {'pred_masks': pm, 'pred_is_referred': pir} | |||
layer_outputs.append(layer_out) | |||
out = layer_outputs[ | |||
-1] # the output for the last decoder layer is used by default | |||
if self.aux_loss: | |||
out['aux_outputs'] = layer_outputs[:-1] | |||
return out | |||
def num_parameters(self): | |||
return sum(p.numel() for p in self.parameters() if p.requires_grad) | |||
class MLP(nn.Module): | |||
""" Very simple multi-layer perceptron (also called FFN)""" | |||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers): | |||
super().__init__() | |||
self.num_layers = num_layers | |||
h = [hidden_dim] * (num_layers - 1) | |||
self.layers = nn.ModuleList( | |||
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) | |||
def forward(self, x): | |||
for i, layer in enumerate(self.layers): | |||
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) | |||
return x |
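The core of the mask-generation step above is the einsum between per-query segmentation kernels and the decoded frame features; a standalone shape check with random tensors and illustrative sizes:

```python
import torch

L_dec, T, B, N, C, H, W = 3, 2, 1, 50, 8, 36, 64
instance_kernels = torch.rand(L_dec, T, B, N, C)     # one C-dim kernel per query
decoded_frame_features = torch.rand(T, B, C, H, W)   # spatial-decoder output
output_masks = torch.einsum('ltbnc,tbchw->ltbnhw',
                            instance_kernels, decoded_frame_features)
print(output_masks.shape)  # torch.Size([3, 2, 1, 50, 36, 64]) = [L, T, B, N, H, W]
```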
@@ -0,0 +1,440 @@ | |||
# The implementation is adopted from MTTR, | |||
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||
# MTTR Multimodal Transformer class. | |||
# Modified from DETR https://github.com/facebookresearch/detr | |||
import copy | |||
import os | |||
from typing import Optional | |||
import torch | |||
import torch.nn.functional as F | |||
from einops import rearrange, repeat | |||
from torch import Tensor, nn | |||
from transformers import RobertaModel, RobertaTokenizerFast | |||
from .position_encoding_2d import PositionEmbeddingSine2D | |||
os.environ[ | |||
'TOKENIZERS_PARALLELISM'] = 'false' # this disables a huggingface tokenizer warning (printed every epoch) | |||
class MultimodalTransformer(nn.Module): | |||
def __init__(self, | |||
num_encoder_layers=3, | |||
num_decoder_layers=3, | |||
text_encoder_type='roberta-base', | |||
freeze_text_encoder=True, | |||
**kwargs): | |||
super().__init__() | |||
self.d_model = kwargs['d_model'] | |||
encoder_layer = TransformerEncoderLayer(**kwargs) | |||
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers) | |||
decoder_layer = TransformerDecoderLayer(**kwargs) | |||
self.decoder = TransformerDecoder( | |||
decoder_layer, | |||
num_decoder_layers, | |||
norm=nn.LayerNorm(self.d_model), | |||
return_intermediate=True) | |||
self.pos_encoder_2d = PositionEmbeddingSine2D() | |||
self._reset_parameters() | |||
self.text_encoder = RobertaModel.from_pretrained(text_encoder_type) | |||
self.text_encoder.pooler = None # this pooler is never used, this is a hack to avoid DDP problems... | |||
self.tokenizer = RobertaTokenizerFast.from_pretrained( | |||
text_encoder_type) | |||
self.freeze_text_encoder = freeze_text_encoder | |||
if freeze_text_encoder: | |||
for p in self.text_encoder.parameters(): | |||
p.requires_grad_(False) | |||
self.txt_proj = FeatureResizer( | |||
input_feat_size=self.text_encoder.config.hidden_size, | |||
output_feat_size=self.d_model, | |||
dropout=kwargs['dropout'], | |||
) | |||
def _reset_parameters(self): | |||
for p in self.parameters(): | |||
if p.dim() > 1: | |||
nn.init.xavier_uniform_(p) | |||
def forward(self, vid_embeds, vid_pad_mask, text_queries, obj_queries): | |||
device = vid_embeds.device | |||
t, b, _, h, w = vid_embeds.shape | |||
txt_memory, txt_pad_mask = self.forward_text(text_queries, device) | |||
# add temporal dim to txt memory & padding mask: | |||
txt_memory = repeat(txt_memory, 's b c -> s (t b) c', t=t) | |||
txt_pad_mask = repeat(txt_pad_mask, 'b s -> (t b) s', t=t) | |||
vid_embeds = rearrange(vid_embeds, 't b c h w -> (h w) (t b) c') | |||
# Concat the image & text embeddings on the sequence dimension | |||
encoder_src_seq = torch.cat((vid_embeds, txt_memory), dim=0) | |||
seq_mask = torch.cat( | |||
(rearrange(vid_pad_mask, 't b h w -> (t b) (h w)'), txt_pad_mask), | |||
dim=1) | |||
# vid_pos_embed is: [T*B, H, W, d_model] | |||
vid_pos_embed = self.pos_encoder_2d( | |||
rearrange(vid_pad_mask, 't b h w -> (t b) h w'), self.d_model) | |||
# use zeros in place of pos embeds for the text sequence: | |||
pos_embed = torch.cat( | |||
(rearrange(vid_pos_embed, 't_b h w c -> (h w) t_b c'), | |||
torch.zeros_like(txt_memory)), | |||
dim=0) | |||
memory = self.encoder( | |||
encoder_src_seq, src_key_padding_mask=seq_mask, | |||
pos=pos_embed) # [S, T*B, C] | |||
vid_memory = rearrange( | |||
memory[:h * w, :, :], | |||
'(h w) (t b) c -> t b c h w', | |||
h=h, | |||
w=w, | |||
t=t, | |||
b=b) | |||
txt_memory = memory[h * w:, :, :] | |||
txt_memory = rearrange(txt_memory, 's t_b c -> t_b s c') | |||
txt_memory = [ | |||
t_mem[~pad_mask] | |||
for t_mem, pad_mask in zip(txt_memory, txt_pad_mask) | |||
] # remove padding | |||
# add T*B dims to query embeds (was: [N, C], where N is the number of object queries): | |||
obj_queries = repeat(obj_queries, 'n c -> n (t b) c', t=t, b=b) | |||
tgt = torch.zeros_like(obj_queries) # [N, T*B, C] | |||
# hs is [L, N, T*B, C] where L is number of layers in the decoder | |||
hs = self.decoder( | |||
tgt, | |||
memory, | |||
memory_key_padding_mask=seq_mask, | |||
pos=pos_embed, | |||
query_pos=obj_queries) | |||
hs = rearrange(hs, 'l n (t b) c -> l t b n c', t=t, b=b) | |||
return hs, vid_memory, txt_memory | |||
def forward_text(self, text_queries, device): | |||
tokenized_queries = self.tokenizer.batch_encode_plus( | |||
text_queries, padding='longest', return_tensors='pt') | |||
tokenized_queries = tokenized_queries.to(device) | |||
with torch.inference_mode(mode=self.freeze_text_encoder): | |||
encoded_text = self.text_encoder(**tokenized_queries) | |||
# Transpose memory because pytorch's attention expects sequence first | |||
txt_memory = rearrange(encoded_text.last_hidden_state, | |||
'b s c -> s b c') | |||
txt_memory = self.txt_proj( | |||
txt_memory) # change text embeddings dim to model dim | |||
# Invert attention mask that we get from huggingface because it's the opposite in pytorch transformer | |||
txt_pad_mask = tokenized_queries.attention_mask.ne(1).bool() # [B, S] | |||
return txt_memory, txt_pad_mask | |||
def num_parameters(self): | |||
return sum(p.numel() for p in self.parameters() if p.requires_grad) | |||
class TransformerEncoder(nn.Module): | |||
def __init__(self, encoder_layer, num_layers, norm=None): | |||
super().__init__() | |||
self.layers = _get_clones(encoder_layer, num_layers) | |||
self.num_layers = num_layers | |||
self.norm = norm | |||
def forward(self, | |||
src, | |||
mask: Optional[Tensor] = None, | |||
src_key_padding_mask: Optional[Tensor] = None, | |||
pos: Optional[Tensor] = None): | |||
output = src | |||
for layer in self.layers: | |||
output = layer( | |||
output, | |||
src_mask=mask, | |||
src_key_padding_mask=src_key_padding_mask, | |||
pos=pos) | |||
if self.norm is not None: | |||
output = self.norm(output) | |||
return output | |||
class TransformerDecoder(nn.Module): | |||
def __init__(self, | |||
decoder_layer, | |||
num_layers, | |||
norm=None, | |||
return_intermediate=False): | |||
super().__init__() | |||
self.layers = _get_clones(decoder_layer, num_layers) | |||
self.num_layers = num_layers | |||
self.norm = norm | |||
self.return_intermediate = return_intermediate | |||
def forward(self, | |||
tgt, | |||
memory, | |||
tgt_mask: Optional[Tensor] = None, | |||
memory_mask: Optional[Tensor] = None, | |||
tgt_key_padding_mask: Optional[Tensor] = None, | |||
memory_key_padding_mask: Optional[Tensor] = None, | |||
pos: Optional[Tensor] = None, | |||
query_pos: Optional[Tensor] = None): | |||
output = tgt | |||
intermediate = [] | |||
for layer in self.layers: | |||
output = layer( | |||
output, | |||
memory, | |||
tgt_mask=tgt_mask, | |||
memory_mask=memory_mask, | |||
tgt_key_padding_mask=tgt_key_padding_mask, | |||
memory_key_padding_mask=memory_key_padding_mask, | |||
pos=pos, | |||
query_pos=query_pos) | |||
if self.return_intermediate: | |||
intermediate.append(self.norm(output)) | |||
if self.norm is not None: | |||
output = self.norm(output) | |||
if self.return_intermediate: | |||
intermediate.pop() | |||
intermediate.append(output) | |||
if self.return_intermediate: | |||
return torch.stack(intermediate) | |||
return output.unsqueeze(0) | |||
class TransformerEncoderLayer(nn.Module): | |||
def __init__(self, | |||
d_model, | |||
nheads, | |||
dim_feedforward=2048, | |||
dropout=0.1, | |||
activation='relu', | |||
normalize_before=False, | |||
**kwargs): | |||
super().__init__() | |||
self.self_attn = nn.MultiheadAttention( | |||
d_model, nheads, dropout=dropout) | |||
# Implementation of Feedforward model | |||
self.linear1 = nn.Linear(d_model, dim_feedforward) | |||
self.dropout = nn.Dropout(dropout) | |||
self.linear2 = nn.Linear(dim_feedforward, d_model) | |||
self.norm1 = nn.LayerNorm(d_model) | |||
self.norm2 = nn.LayerNorm(d_model) | |||
self.dropout1 = nn.Dropout(dropout) | |||
self.dropout2 = nn.Dropout(dropout) | |||
self.activation = _get_activation_fn(activation) | |||
self.normalize_before = normalize_before | |||
def with_pos_embed(self, tensor, pos: Optional[Tensor]): | |||
return tensor if pos is None else tensor + pos | |||
def forward_post(self, | |||
src, | |||
src_mask: Optional[Tensor] = None, | |||
src_key_padding_mask: Optional[Tensor] = None, | |||
pos: Optional[Tensor] = None): | |||
q = k = self.with_pos_embed(src, pos) | |||
src2 = self.self_attn( | |||
q, | |||
k, | |||
value=src, | |||
attn_mask=src_mask, | |||
key_padding_mask=src_key_padding_mask)[0] | |||
src = src + self.dropout1(src2) | |||
src = self.norm1(src) | |||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) | |||
src = src + self.dropout2(src2) | |||
src = self.norm2(src) | |||
return src | |||
def forward_pre(self, | |||
src, | |||
src_mask: Optional[Tensor] = None, | |||
src_key_padding_mask: Optional[Tensor] = None, | |||
pos: Optional[Tensor] = None): | |||
src2 = self.norm1(src) | |||
q = k = self.with_pos_embed(src2, pos) | |||
src2 = self.self_attn( | |||
q, | |||
k, | |||
value=src2, | |||
attn_mask=src_mask, | |||
key_padding_mask=src_key_padding_mask)[0] | |||
src = src + self.dropout1(src2) | |||
src2 = self.norm2(src) | |||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) | |||
src = src + self.dropout2(src2) | |||
return src | |||
def forward(self, | |||
src, | |||
src_mask: Optional[Tensor] = None, | |||
src_key_padding_mask: Optional[Tensor] = None, | |||
pos: Optional[Tensor] = None): | |||
if self.normalize_before: | |||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos) | |||
return self.forward_post(src, src_mask, src_key_padding_mask, pos) | |||
class TransformerDecoderLayer(nn.Module): | |||
def __init__(self, | |||
d_model, | |||
nheads, | |||
dim_feedforward=2048, | |||
dropout=0.1, | |||
activation='relu', | |||
normalize_before=False, | |||
**kwargs): | |||
super().__init__() | |||
self.self_attn = nn.MultiheadAttention( | |||
d_model, nheads, dropout=dropout) | |||
self.multihead_attn = nn.MultiheadAttention( | |||
d_model, nheads, dropout=dropout) | |||
# Implementation of Feedforward model | |||
self.linear1 = nn.Linear(d_model, dim_feedforward) | |||
self.dropout = nn.Dropout(dropout) | |||
self.linear2 = nn.Linear(dim_feedforward, d_model) | |||
self.norm1 = nn.LayerNorm(d_model) | |||
self.norm2 = nn.LayerNorm(d_model) | |||
self.norm3 = nn.LayerNorm(d_model) | |||
self.dropout1 = nn.Dropout(dropout) | |||
self.dropout2 = nn.Dropout(dropout) | |||
self.dropout3 = nn.Dropout(dropout) | |||
self.activation = _get_activation_fn(activation) | |||
self.normalize_before = normalize_before | |||
def with_pos_embed(self, tensor, pos: Optional[Tensor]): | |||
return tensor if pos is None else tensor + pos | |||
def forward_post(self, | |||
tgt, | |||
memory, | |||
tgt_mask: Optional[Tensor] = None, | |||
memory_mask: Optional[Tensor] = None, | |||
tgt_key_padding_mask: Optional[Tensor] = None, | |||
memory_key_padding_mask: Optional[Tensor] = None, | |||
pos: Optional[Tensor] = None, | |||
query_pos: Optional[Tensor] = None): | |||
q = k = self.with_pos_embed(tgt, query_pos) | |||
tgt2 = self.self_attn( | |||
q, | |||
k, | |||
value=tgt, | |||
attn_mask=tgt_mask, | |||
key_padding_mask=tgt_key_padding_mask)[0] | |||
tgt = tgt + self.dropout1(tgt2) | |||
tgt = self.norm1(tgt) | |||
tgt2 = self.multihead_attn( | |||
query=self.with_pos_embed(tgt, query_pos), | |||
key=self.with_pos_embed(memory, pos), | |||
value=memory, | |||
attn_mask=memory_mask, | |||
key_padding_mask=memory_key_padding_mask)[0] | |||
tgt = tgt + self.dropout2(tgt2) | |||
tgt = self.norm2(tgt) | |||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) | |||
tgt = tgt + self.dropout3(tgt2) | |||
tgt = self.norm3(tgt) | |||
return tgt | |||
def forward_pre(self, | |||
tgt, | |||
memory, | |||
tgt_mask: Optional[Tensor] = None, | |||
memory_mask: Optional[Tensor] = None, | |||
tgt_key_padding_mask: Optional[Tensor] = None, | |||
memory_key_padding_mask: Optional[Tensor] = None, | |||
pos: Optional[Tensor] = None, | |||
query_pos: Optional[Tensor] = None): | |||
tgt2 = self.norm1(tgt) | |||
q = k = self.with_pos_embed(tgt2, query_pos) | |||
tgt2 = self.self_attn( | |||
q, | |||
k, | |||
value=tgt2, | |||
attn_mask=tgt_mask, | |||
key_padding_mask=tgt_key_padding_mask)[0] | |||
tgt = tgt + self.dropout1(tgt2) | |||
tgt2 = self.norm2(tgt) | |||
tgt2 = self.multihead_attn( | |||
query=self.with_pos_embed(tgt2, query_pos), | |||
key=self.with_pos_embed(memory, pos), | |||
value=memory, | |||
attn_mask=memory_mask, | |||
key_padding_mask=memory_key_padding_mask)[0] | |||
tgt = tgt + self.dropout2(tgt2) | |||
tgt2 = self.norm3(tgt) | |||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) | |||
tgt = tgt + self.dropout3(tgt2) | |||
return tgt | |||
def forward(self, | |||
tgt, | |||
memory, | |||
tgt_mask: Optional[Tensor] = None, | |||
memory_mask: Optional[Tensor] = None, | |||
tgt_key_padding_mask: Optional[Tensor] = None, | |||
memory_key_padding_mask: Optional[Tensor] = None, | |||
pos: Optional[Tensor] = None, | |||
query_pos: Optional[Tensor] = None): | |||
if self.normalize_before: | |||
return self.forward_pre(tgt, memory, tgt_mask, memory_mask, | |||
tgt_key_padding_mask, | |||
memory_key_padding_mask, pos, query_pos) | |||
return self.forward_post(tgt, memory, tgt_mask, memory_mask, | |||
tgt_key_padding_mask, memory_key_padding_mask, | |||
pos, query_pos) | |||
def _get_clones(module, N): | |||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) | |||
class FeatureResizer(nn.Module): | |||
""" | |||
This class takes as input a set of embeddings of dimension C1 and outputs a set of | |||
embedding of dimension C2, after a linear transformation, dropout and normalization (LN). | |||
""" | |||
def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True): | |||
super().__init__() | |||
self.do_ln = do_ln | |||
# Object feature encoding | |||
self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True) | |||
self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12) | |||
self.dropout = nn.Dropout(dropout) | |||
def forward(self, encoder_features): | |||
x = self.fc(encoder_features) | |||
if self.do_ln: | |||
x = self.layer_norm(x) | |||
output = self.dropout(x) | |||
return output | |||
def _get_activation_fn(activation): | |||
"""Return an activation function given a string""" | |||
if activation == 'relu': | |||
return F.relu | |||
if activation == 'gelu': | |||
return F.gelu | |||
if activation == 'glu': | |||
return F.glu | |||
raise RuntimeError(f'activation should be relu/gelu/glu, not {activation}.') |
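One detail worth spelling out from forward_text above: HuggingFace attention masks use 1 for real tokens, while nn.MultiheadAttention's key_padding_mask expects True for positions to ignore, hence the .ne(1) inversion. A tiny standalone check:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],    # HuggingFace: 1 = real token
                               [1, 1, 1, 1, 1]])
txt_pad_mask = attention_mask.ne(1).bool()         # PyTorch: True = padding
print(txt_pad_mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
```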
@@ -0,0 +1,57 @@ | |||
# The implementation is adopted from MTTR, | |||
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||
# Modified from DETR https://github.com/facebookresearch/detr | |||
# 2D sine positional encodings for the visual features in the multimodal transformer. | |||
import math | |||
import torch | |||
from torch import Tensor, nn | |||
class PositionEmbeddingSine2D(nn.Module): | |||
""" | |||
This is a more standard version of the position embedding, very similar to the one | |||
used by the Attention is all you need paper, generalized to work on images. | |||
""" | |||
def __init__(self, temperature=10000, normalize=True, scale=None): | |||
super().__init__() | |||
self.temperature = temperature | |||
self.normalize = normalize | |||
if scale is not None and normalize is False: | |||
raise ValueError('normalize should be True if scale is passed') | |||
if scale is None: | |||
scale = 2 * math.pi | |||
self.scale = scale | |||
def forward(self, mask: Tensor, hidden_dim: int): | |||
""" | |||
@param mask: a tensor of shape [B, H, W] | |||
@param hidden_dim: int | |||
@return: | |||
""" | |||
num_pos_feats = hidden_dim // 2 | |||
not_mask = ~mask | |||
y_embed = not_mask.cumsum(1, dtype=torch.float32) | |||
x_embed = not_mask.cumsum(2, dtype=torch.float32) | |||
if self.normalize: | |||
eps = 1e-6 | |||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale | |||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale | |||
dim_t = torch.arange( | |||
num_pos_feats, dtype=torch.float32, device=mask.device) | |||
dim_t = self.temperature**(2 * (dim_t // 2) / num_pos_feats) | |||
pos_x = x_embed[:, :, :, None] / dim_t | |||
pos_y = y_embed[:, :, :, None] / dim_t | |||
pos_x = torch.stack( | |||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), | |||
dim=4).flatten(3) | |||
pos_y = torch.stack( | |||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), | |||
dim=4).flatten(3) | |||
pos = torch.cat((pos_y, pos_x), dim=3) | |||
return pos |
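To make the coordinate construction above concrete, a small sketch of the cumsum-based x/y embeddings on an unpadded mask (eps and the 2*pi scale are omitted here for readability):

```python
import torch

mask = torch.zeros(1, 3, 4, dtype=torch.bool)        # [B, H, W], no padding
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)    # 1..H down each column
x_embed = not_mask.cumsum(2, dtype=torch.float32)    # 1..W along each row
print(x_embed[0, 0])                                 # tensor([1., 2., 3., 4.])
print((x_embed / x_embed[:, :, -1:])[0, 0])          # tensor([0.2500, 0.5000, 0.7500, 1.0000])
```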
@@ -0,0 +1,119 @@ | |||
# The implementation is adopted from MTTR, | |||
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||
import numpy as np | |||
import pycocotools.mask as mask_util | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from einops import rearrange | |||
class A2DSentencesPostProcess(nn.Module): | |||
""" | |||
This module converts the model's output into the format expected by the coco api for the given task | |||
""" | |||
def __init__(self): | |||
super(A2DSentencesPostProcess, self).__init__() | |||
@torch.inference_mode() | |||
def forward(self, outputs, resized_padded_sample_size, | |||
resized_sample_sizes, orig_sample_sizes): | |||
""" Perform the computation | |||
Parameters: | |||
outputs: raw outputs of the model | |||
resized_padded_sample_size: size of samples (input to model) after size augmentation + padding. | |||
resized_sample_sizes: size of samples after size augmentation but without padding. | |||
orig_sample_sizes: original size of the samples (no augmentations or padding) | |||
""" | |||
pred_is_referred = outputs['pred_is_referred'] | |||
prob = F.softmax(pred_is_referred, dim=-1) | |||
scores = prob[..., 0] | |||
pred_masks = outputs['pred_masks'] | |||
pred_masks = F.interpolate( | |||
pred_masks, | |||
size=resized_padded_sample_size, | |||
mode='bilinear', | |||
align_corners=False) | |||
pred_masks = (pred_masks.sigmoid() > 0.5) | |||
processed_pred_masks, rle_masks = [], [] | |||
for f_pred_masks, resized_size, orig_size in zip( | |||
pred_masks, resized_sample_sizes, orig_sample_sizes): | |||
f_mask_h, f_mask_w = resized_size # resized shape without padding | |||
# remove the samples' padding | |||
f_pred_masks_no_pad = f_pred_masks[:, :f_mask_h, : | |||
f_mask_w].unsqueeze(1) | |||
# resize the samples back to their original dataset (target) size for evaluation | |||
f_pred_masks_processed = F.interpolate( | |||
f_pred_masks_no_pad.float(), size=orig_size, mode='nearest') | |||
f_pred_rle_masks = [ | |||
mask_util.encode( | |||
np.array( | |||
mask[0, :, :, np.newaxis], dtype=np.uint8, | |||
order='F'))[0] | |||
for mask in f_pred_masks_processed.cpu() | |||
] | |||
processed_pred_masks.append(f_pred_masks_processed) | |||
rle_masks.append(f_pred_rle_masks) | |||
predictions = [{ | |||
'scores': s, | |||
'masks': m, | |||
'rle_masks': rle | |||
} for s, m, rle in zip(scores, processed_pred_masks, rle_masks)] | |||
return predictions | |||
class ReferYoutubeVOSPostProcess(nn.Module): | |||
""" | |||
This module converts the model's output into the format expected by the coco api for the given task | |||
""" | |||
def __init__(self): | |||
super(ReferYoutubeVOSPostProcess, self).__init__() | |||
@torch.inference_mode() | |||
def forward(self, outputs, videos_metadata, samples_shape_with_padding): | |||
""" Perform the computation | |||
Parameters: | |||
outputs: raw outputs of the model | |||
videos_metadata: a dictionary with each video's metadata. | |||
samples_shape_with_padding: size of the batch frames with padding. | |||
""" | |||
pred_is_referred = outputs['pred_is_referred'] | |||
prob_is_referred = F.softmax(pred_is_referred, dim=-1) | |||
# note we average on the temporal dim to compute score per trajectory: | |||
trajectory_scores = prob_is_referred[..., 0].mean(dim=0) | |||
pred_trajectory_indices = torch.argmax(trajectory_scores, dim=-1) | |||
pred_masks = rearrange(outputs['pred_masks'], | |||
't b nq h w -> b t nq h w') | |||
# keep only the masks of the chosen trajectories: | |||
b = pred_masks.shape[0] | |||
pred_masks = pred_masks[torch.arange(b), :, pred_trajectory_indices] | |||
# resize the predicted masks to the size of the model input (which might include padding) | |||
pred_masks = F.interpolate( | |||
pred_masks, | |||
size=samples_shape_with_padding, | |||
mode='bilinear', | |||
align_corners=False) | |||
# apply a threshold to create binary masks: | |||
pred_masks = (pred_masks.sigmoid() > 0.5) | |||
# remove the padding per video (as videos might have different resolutions and thus different padding): | |||
preds_by_video = [] | |||
for video_pred_masks, video_metadata in zip(pred_masks, | |||
videos_metadata): | |||
# size of the model input batch frames without padding: | |||
resized_h, resized_w = video_metadata['resized_frame_size'] | |||
video_pred_masks = video_pred_masks[:, :resized_h, : | |||
resized_w].unsqueeze( | |||
1) # remove the padding | |||
# resize the masks back to their original frames dataset size for evaluation: | |||
original_frames_size = video_metadata['original_frame_size'] | |||
tuple_size = tuple(original_frames_size.cpu().numpy()) | |||
video_pred_masks = F.interpolate( | |||
video_pred_masks.float(), size=tuple_size, mode='nearest') | |||
video_pred_masks = video_pred_masks.to(torch.uint8).cpu() | |||
# combine the predicted masks and the video metadata to create a final predictions dict: | |||
video_pred = {**video_metadata, **{'pred_masks': video_pred_masks}} | |||
preds_by_video.append(video_pred) | |||
return preds_by_video |
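A2DSentencesPostProcess encodes each predicted mask with pycocotools RLE, passing a Fortran-ordered uint8 array of shape [H, W, 1] and keeping element 0 of the result. A standalone round-trip check of that encoding step:

```python
import numpy as np
import pycocotools.mask as mask_util

mask = np.zeros((4, 6, 1), dtype=np.uint8, order='F')
mask[1:3, 2:5, 0] = 1                      # a 2x3 foreground block
rle = mask_util.encode(mask)[0]            # same call pattern as above
print(rle['size'])                         # [4, 6]
print(mask_util.area(rle))                 # 6 foreground pixels
restored = mask_util.decode(rle)           # back to a [4, 6] uint8 array
print(np.array_equal(restored, mask[:, :, 0]))  # True
```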
@@ -0,0 +1,137 @@ | |||
# The implementation is adopted from MTTR, | |||
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||
# Modified from DETR https://github.com/facebookresearch/detr | |||
from typing import List | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torch import Tensor | |||
class FPNSpatialDecoder(nn.Module): | |||
""" | |||
An FPN-like spatial decoder. Generates high-res, semantically rich features which serve as the base for creating | |||
instance segmentation masks. | |||
""" | |||
def __init__(self, context_dim, fpn_dims, mask_kernels_dim=8): | |||
super().__init__() | |||
inter_dims = [ | |||
context_dim, context_dim // 2, context_dim // 4, context_dim // 8, | |||
context_dim // 16 | |||
] | |||
self.lay1 = torch.nn.Conv2d(context_dim, inter_dims[0], 3, padding=1) | |||
self.gn1 = torch.nn.GroupNorm(8, inter_dims[0]) | |||
self.lay2 = torch.nn.Conv2d(inter_dims[0], inter_dims[1], 3, padding=1) | |||
self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) | |||
self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) | |||
self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) | |||
self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) | |||
self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) | |||
self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) | |||
self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) | |||
self.context_dim = context_dim | |||
self.add_extra_layer = len(fpn_dims) == 3 | |||
if self.add_extra_layer: | |||
self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) | |||
self.lay5 = torch.nn.Conv2d( | |||
inter_dims[3], inter_dims[4], 3, padding=1) | |||
self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) | |||
self.out_lay = torch.nn.Conv2d( | |||
inter_dims[4], mask_kernels_dim, 3, padding=1) | |||
else: | |||
self.out_lay = torch.nn.Conv2d( | |||
inter_dims[3], mask_kernels_dim, 3, padding=1) | |||
for m in self.modules(): | |||
if isinstance(m, nn.Conv2d): | |||
nn.init.kaiming_uniform_(m.weight, a=1) | |||
nn.init.constant_(m.bias, 0) | |||
def forward(self, x: Tensor, layer_features: List[Tensor]): | |||
x = self.lay1(x) | |||
x = self.gn1(x) | |||
x = F.relu(x) | |||
x = self.lay2(x) | |||
x = self.gn2(x) | |||
x = F.relu(x) | |||
cur_fpn = self.adapter1(layer_features[0]) | |||
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode='nearest') | |||
x = self.lay3(x) | |||
x = self.gn3(x) | |||
x = F.relu(x) | |||
cur_fpn = self.adapter2(layer_features[1]) | |||
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode='nearest') | |||
x = self.lay4(x) | |||
x = self.gn4(x) | |||
x = F.relu(x) | |||
if self.add_extra_layer: | |||
cur_fpn = self.adapter3(layer_features[2]) | |||
x = cur_fpn + F.interpolate( | |||
x, size=cur_fpn.shape[-2:], mode='nearest') | |||
x = self.lay5(x) | |||
x = self.gn5(x) | |||
x = F.relu(x) | |||
x = self.out_lay(x) | |||
return x | |||
def num_parameters(self): | |||
return sum(p.numel() for p in self.parameters() if p.requires_grad) | |||
def dice_loss(inputs, targets, num_masks): | |||
""" | |||
Compute the DICE loss, similar to generalized IOU for masks | |||
Args: | |||
inputs: A float tensor of arbitrary shape. | |||
The predictions for each example. | |||
targets: A float tensor with the same shape as inputs. Stores the binary | |||
classification label for each element in inputs | |||
(0 for the negative class and 1 for the positive class). | |||
num_masks: number of masks by which the summed loss is normalized. | |||
""" | |||
inputs = inputs.sigmoid() | |||
numerator = 2 * (inputs * targets).sum(1) | |||
denominator = inputs.sum(-1) + targets.sum(-1) | |||
loss = 1 - (numerator + 1) / (denominator + 1) | |||
return loss.sum() / num_masks | |||
def sigmoid_focal_loss(inputs, | |||
targets, | |||
num_masks, | |||
alpha: float = 0.25, | |||
gamma: float = 2): | |||
""" | |||
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. | |||
Args: | |||
inputs: A float tensor of arbitrary shape. | |||
The predictions for each example. | |||
targets: A float tensor with the same shape as inputs. Stores the binary | |||
classification label for each element in inputs | |||
(0 for the negative class and 1 for the positive class). | |||
alpha: (optional) Weighting factor in range (0,1) to balance | |||
positive vs negative examples. Default: 0.25 (a negative value disables weighting). | |||
gamma: Exponent of the modulating factor (1 - p_t) to | |||
balance easy vs hard examples. | |||
Returns: | |||
Loss tensor | |||
""" | |||
prob = inputs.sigmoid() | |||
ce_loss = F.binary_cross_entropy_with_logits( | |||
inputs, targets, reduction='none') | |||
p_t = prob * targets + (1 - prob) * (1 - targets) | |||
loss = ce_loss * ((1 - p_t)**gamma) | |||
if alpha >= 0: | |||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets) | |||
loss = alpha_t * loss | |||
return loss.mean(1).sum() / num_masks |
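A worked numeric illustration of the focal term used above (mirroring the per-element formula rather than importing the function, which additionally normalizes by num_masks): with alpha=0.25 and gamma=2, a confidently correct prediction contributes almost nothing, while a confidently wrong one keeps nearly its full cross-entropy weight.

```python
import torch
import torch.nn.functional as F

def focal_term(logit, target, alpha=0.25, gamma=2.0):
    prob = torch.sigmoid(logit)
    ce = F.binary_cross_entropy_with_logits(logit, target, reduction='none')
    p_t = prob * target + (1 - prob) * (1 - target)
    alpha_t = alpha * target + (1 - alpha) * (1 - target)
    return alpha_t * ce * (1 - p_t) ** gamma

easy = focal_term(torch.tensor(4.0), torch.tensor(1.0))   # confident and correct
hard = focal_term(torch.tensor(-4.0), torch.tensor(1.0))  # confident and wrong
print(easy.item(), hard.item())  # ~1.5e-06 vs ~0.97
```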
@@ -0,0 +1,731 @@ | |||
# The implementation is adopted from MTTR, | |||
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||
# Modified from Video-Swin-Transformer https://github.com/SwinTransformer/Video-Swin-Transformer | |||
from functools import lru_cache, reduce | |||
from operator import mul | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torch.utils.checkpoint as checkpoint | |||
from einops import rearrange | |||
from timm.models.layers import DropPath, trunc_normal_ | |||
class Mlp(nn.Module): | |||
""" Multilayer perceptron.""" | |||
def __init__(self, | |||
in_features, | |||
hidden_features=None, | |||
out_features=None, | |||
act_layer=nn.GELU, | |||
drop=0.): | |||
super().__init__() | |||
out_features = out_features or in_features | |||
hidden_features = hidden_features or in_features | |||
self.fc1 = nn.Linear(in_features, hidden_features) | |||
self.act = act_layer() | |||
self.fc2 = nn.Linear(hidden_features, out_features) | |||
self.drop = nn.Dropout(drop) | |||
def forward(self, x): | |||
x = self.fc1(x) | |||
x = self.act(x) | |||
x = self.drop(x) | |||
x = self.fc2(x) | |||
x = self.drop(x) | |||
return x | |||
def window_partition(x, window_size): | |||
""" | |||
Args: | |||
x: (B, D, H, W, C) | |||
window_size (tuple[int]): window size | |||
Returns: | |||
windows: (B*num_windows, Wd*Wh*Ww, C) | |||
""" | |||
B, D, H, W, C = x.shape | |||
x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], | |||
window_size[1], W // window_size[2], window_size[2], C) | |||
windows = x.permute(0, 1, 3, 5, 2, 4, 6, | |||
7).contiguous().view(-1, reduce(mul, window_size), C) | |||
return windows | |||
def window_reverse(windows, window_size, B, D, H, W): | |||
""" | |||
Args: | |||
windows: (B*num_windows, Wd*Wh*Ww, C) | |||
window_size (tuple[int]): Window size | |||
B (int): Batch size | |||
D (int): Temporal length of the clip | |||
H (int): Height of image | |||
W (int): Width of image | |||
Returns: | |||
x: (B, D, H, W, C) | |||
""" | |||
x = windows.view(B, D // window_size[0], H // window_size[1], | |||
W // window_size[2], window_size[0], window_size[1], | |||
window_size[2], -1) | |||
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) | |||
return x | |||
def get_window_size(x_size, window_size, shift_size=None): | |||
use_window_size = list(window_size) | |||
if shift_size is not None: | |||
use_shift_size = list(shift_size) | |||
for i in range(len(x_size)): | |||
if x_size[i] <= window_size[i]: | |||
use_window_size[i] = x_size[i] | |||
if shift_size is not None: | |||
use_shift_size[i] = 0 | |||
if shift_size is None: | |||
return tuple(use_window_size) | |||
else: | |||
return tuple(use_window_size), tuple(use_shift_size) | |||
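window_partition and window_reverse above are exact inverses; a quick shape and round-trip check (the import path is an assumption based on the relative imports in backbone.py):

```python
import torch
from modelscope.models.cv.referring_video_object_segmentation.utils.swin_transformer import (
    window_partition, window_reverse)

B, D, H, W, C = 2, 8, 14, 14, 96
window_size = (8, 7, 7)
x = torch.rand(B, D, H, W, C)
windows = window_partition(x, window_size)
print(windows.shape)             # torch.Size([8, 392, 96]) = [B*num_windows, Wd*Wh*Ww, C]
x_back = window_reverse(windows, window_size, B, D, H, W)
print(torch.equal(x, x_back))    # True
```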
class WindowAttention3D(nn.Module): | |||
""" Window based multi-head self attention (W-MSA) module with relative position bias. | |||
It supports both of shifted and non-shifted window. | |||
Args: | |||
dim (int): Number of input channels. | |||
window_size (tuple[int]): The temporal length, height and width of the window. | |||
num_heads (int): Number of attention heads. | |||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set | |||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 | |||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0 | |||
""" | |||
def __init__(self, | |||
dim, | |||
window_size, | |||
num_heads, | |||
qkv_bias=False, | |||
qk_scale=None, | |||
attn_drop=0., | |||
proj_drop=0.): | |||
super().__init__() | |||
self.dim = dim | |||
self.window_size = window_size # Wd, Wh, Ww | |||
self.num_heads = num_heads | |||
head_dim = dim // num_heads | |||
self.scale = qk_scale or head_dim**-0.5 | |||
# define a parameter table of relative position bias | |||
wd, wh, ww = window_size | |||
self.relative_position_bias_table = nn.Parameter( | |||
torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), num_heads)) | |||
# get pair-wise relative position index for each token inside the window | |||
coords_d = torch.arange(self.window_size[0]) | |||
coords_h = torch.arange(self.window_size[1]) | |||
coords_w = torch.arange(self.window_size[2]) | |||
coords = torch.stack(torch.meshgrid(coords_d, coords_h, | |||
coords_w)) # 3, Wd, Wh, Ww | |||
coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww | |||
relative_coords = coords_flatten[:, :, | |||
None] - coords_flatten[:, | |||
None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww | |||
relative_coords = relative_coords.permute( | |||
1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 | |||
relative_coords[:, :, | |||
0] += self.window_size[0] - 1 # shift to start from 0 | |||
relative_coords[:, :, 1] += self.window_size[1] - 1 | |||
relative_coords[:, :, 2] += self.window_size[2] - 1 | |||
relative_coords[:, :, 0] *= (2 * self.window_size[1] | |||
- 1) * (2 * self.window_size[2] - 1) | |||
relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) | |||
relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww | |||
self.register_buffer('relative_position_index', | |||
relative_position_index) | |||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |||
self.attn_drop = nn.Dropout(attn_drop) | |||
self.proj = nn.Linear(dim, dim) | |||
self.proj_drop = nn.Dropout(proj_drop) | |||
trunc_normal_(self.relative_position_bias_table, std=.02) | |||
self.softmax = nn.Softmax(dim=-1) | |||
def forward(self, x, mask=None): | |||
""" Forward function. | |||
Args: | |||
x: input features with shape of (num_windows*B, N, C) | |||
mask: (0/-inf) mask with shape of (num_windows, N, N) or None | |||
""" | |||
B_, N, C = x.shape | |||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, | |||
C // self.num_heads).permute(2, 0, 3, 1, 4) | |||
q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C | |||
q = q * self.scale | |||
attn = q @ k.transpose(-2, -1) | |||
relative_position_bias = self.relative_position_bias_table[ | |||
self.relative_position_index[:N, :N].reshape(-1)].reshape( | |||
N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH | |||
relative_position_bias = relative_position_bias.permute( | |||
2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww | |||
attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N | |||
if mask is not None: | |||
nW = mask.shape[0] | |||
attn = attn.view(B_ // nW, nW, self.num_heads, N, | |||
N) + mask.unsqueeze(1).unsqueeze(0) | |||
attn = attn.view(-1, self.num_heads, N, N) | |||
attn = self.softmax(attn) | |||
else: | |||
attn = self.softmax(attn) | |||
attn = self.attn_drop(attn) | |||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C) | |||
x = self.proj(x) | |||
x = self.proj_drop(x) | |||
return x | |||
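# Illustrative sketch (not part of the original file): window attention on
# randomly generated windows. With window_size (2, 7, 7) each window holds
# 2*7*7 = 98 tokens and the bias table has (2*2-1)*(2*7-1)*(2*7-1) = 507 rows.
_attn = WindowAttention3D(dim=96, window_size=(2, 7, 7), num_heads=3)
assert _attn.relative_position_bias_table.shape == (507, 3)
_out = _attn(torch.randn(16, 98, 96))              # (num_windows*B, N, C)
assert _out.shape == (16, 98, 96)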
class SwinTransformerBlock3D(nn.Module): | |||
""" Swin Transformer Block. | |||
Args: | |||
dim (int): Number of input channels. | |||
num_heads (int): Number of attention heads. | |||
window_size (tuple[int]): Window size. | |||
shift_size (tuple[int]): Shift size for SW-MSA. | |||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | |||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||
drop (float, optional): Dropout rate. Default: 0.0 | |||
attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||
drop_path (float, optional): Stochastic depth rate. Default: 0.0 | |||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU | |||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||
""" | |||
def __init__(self, | |||
dim, | |||
num_heads, | |||
window_size=(2, 7, 7), | |||
shift_size=(0, 0, 0), | |||
mlp_ratio=4., | |||
qkv_bias=True, | |||
qk_scale=None, | |||
drop=0., | |||
attn_drop=0., | |||
drop_path=0., | |||
act_layer=nn.GELU, | |||
norm_layer=nn.LayerNorm, | |||
use_checkpoint=False): | |||
super().__init__() | |||
self.dim = dim | |||
self.num_heads = num_heads | |||
self.window_size = window_size | |||
self.shift_size = shift_size | |||
self.mlp_ratio = mlp_ratio | |||
self.use_checkpoint = use_checkpoint | |||
assert 0 <= self.shift_size[0] < self.window_size[ | |||
            0], 'shift_size must be in range [0, window_size)' | |||
assert 0 <= self.shift_size[1] < self.window_size[ | |||
            1], 'shift_size must be in range [0, window_size)' | |||
assert 0 <= self.shift_size[2] < self.window_size[ | |||
            2], 'shift_size must be in range [0, window_size)' | |||
self.norm1 = norm_layer(dim) | |||
self.attn = WindowAttention3D( | |||
dim, | |||
window_size=self.window_size, | |||
num_heads=num_heads, | |||
qkv_bias=qkv_bias, | |||
qk_scale=qk_scale, | |||
attn_drop=attn_drop, | |||
proj_drop=drop) | |||
self.drop_path = DropPath( | |||
drop_path) if drop_path > 0. else nn.Identity() | |||
self.norm2 = norm_layer(dim) | |||
mlp_hidden_dim = int(dim * mlp_ratio) | |||
self.mlp = Mlp( | |||
in_features=dim, | |||
hidden_features=mlp_hidden_dim, | |||
act_layer=act_layer, | |||
drop=drop) | |||
def forward_part1(self, x, mask_matrix): | |||
B, D, H, W, C = x.shape | |||
window_size, shift_size = get_window_size((D, H, W), self.window_size, | |||
self.shift_size) | |||
x = self.norm1(x) | |||
# pad feature maps to multiples of window size | |||
pad_l = pad_t = pad_d0 = 0 | |||
pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] | |||
pad_b = (window_size[1] - H % window_size[1]) % window_size[1] | |||
pad_r = (window_size[2] - W % window_size[2]) % window_size[2] | |||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) | |||
_, Dp, Hp, Wp, _ = x.shape | |||
# cyclic shift | |||
if any(i > 0 for i in shift_size): | |||
shifted_x = torch.roll( | |||
x, | |||
shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), | |||
dims=(1, 2, 3)) | |||
attn_mask = mask_matrix | |||
else: | |||
shifted_x = x | |||
attn_mask = None | |||
# partition windows | |||
x_windows = window_partition(shifted_x, | |||
window_size) # B*nW, Wd*Wh*Ww, C | |||
# W-MSA/SW-MSA | |||
attn_windows = self.attn( | |||
x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C | |||
# merge windows | |||
attn_windows = attn_windows.view(-1, *(window_size + (C, ))) | |||
shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, | |||
Wp) # B D' H' W' C | |||
# reverse cyclic shift | |||
if any(i > 0 for i in shift_size): | |||
x = torch.roll( | |||
shifted_x, | |||
shifts=(shift_size[0], shift_size[1], shift_size[2]), | |||
dims=(1, 2, 3)) | |||
else: | |||
x = shifted_x | |||
if pad_d1 > 0 or pad_r > 0 or pad_b > 0: | |||
x = x[:, :D, :H, :W, :].contiguous() | |||
return x | |||
def forward_part2(self, x): | |||
return self.drop_path(self.mlp(self.norm2(x))) | |||
def forward(self, x, mask_matrix): | |||
""" Forward function. | |||
Args: | |||
x: Input feature, tensor size (B, D, H, W, C). | |||
mask_matrix: Attention mask for cyclic shift. | |||
""" | |||
shortcut = x | |||
if self.use_checkpoint: | |||
x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) | |||
else: | |||
x = self.forward_part1(x, mask_matrix) | |||
x = shortcut + self.drop_path(x) | |||
if self.use_checkpoint: | |||
x = x + checkpoint.checkpoint(self.forward_part2, x) | |||
else: | |||
x = x + self.forward_part2(x) | |||
return x | |||
class PatchMerging(nn.Module): | |||
""" Patch Merging Layer | |||
Args: | |||
dim (int): Number of input channels. | |||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||
""" | |||
def __init__(self, dim, norm_layer=nn.LayerNorm): | |||
super().__init__() | |||
self.dim = dim | |||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) | |||
self.norm = norm_layer(4 * dim) | |||
def forward(self, x): | |||
""" Forward function. | |||
Args: | |||
x: Input feature, tensor size (B, D, H, W, C). | |||
""" | |||
B, D, H, W, C = x.shape | |||
# padding | |||
pad_input = (H % 2 == 1) or (W % 2 == 1) | |||
if pad_input: | |||
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) | |||
x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C | |||
x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C | |||
x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C | |||
x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C | |||
x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C | |||
x = self.norm(x) | |||
x = self.reduction(x) | |||
return x | |||
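# Illustrative sketch (not part of the original file): PatchMerging halves the
# spatial resolution and doubles the channels, leaving the temporal axis
# untouched (sizes below are made up).
_merge = PatchMerging(dim=96)
assert _merge(torch.randn(2, 4, 56, 56, 96)).shape == (2, 4, 28, 28, 192)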
# cache each stage results | |||
@lru_cache() | |||
def compute_mask(D, H, W, window_size, shift_size, device): | |||
img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 | |||
cnt = 0 | |||
for d in slice(-window_size[0]), slice(-window_size[0], | |||
-shift_size[0]), slice( | |||
-shift_size[0], None): | |||
for h in slice(-window_size[1]), slice(-window_size[1], | |||
-shift_size[1]), slice( | |||
-shift_size[1], None): | |||
for w in slice(-window_size[2]), slice(-window_size[2], | |||
-shift_size[2]), slice( | |||
-shift_size[2], None): | |||
img_mask[:, d, h, w, :] = cnt | |||
cnt += 1 | |||
mask_windows = window_partition(img_mask, | |||
window_size) # nW, ws[0]*ws[1]*ws[2], 1 | |||
mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] | |||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) | |||
attn_mask = attn_mask.masked_fill(attn_mask != 0, | |||
float(-100.0)).masked_fill( | |||
attn_mask == 0, float(0.0)) | |||
return attn_mask | |||
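# Illustrative sketch (not part of the original file): the SW-MSA mask marks
# token pairs that originate from different pre-shift regions with -100 so that
# softmax suppresses their attention weights. For a 4x14x14 padded volume and
# window_size (2, 7, 7) there are 2*2*2 = 8 windows of 98 tokens each.
_mask = compute_mask(4, 14, 14, (2, 7, 7), (1, 3, 3), torch.device('cpu'))
assert _mask.shape == (8, 98, 98)
assert set(_mask.unique().tolist()) == {-100.0, 0.0}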
class BasicLayer(nn.Module): | |||
""" A basic Swin Transformer layer for one stage. | |||
Args: | |||
dim (int): Number of feature channels | |||
        depth (int): Depth of this stage. | |||
        num_heads (int): Number of attention heads. | |||
window_size (tuple[int]): Local window size. Default: (1,7,7). | |||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||
drop (float, optional): Dropout rate. Default: 0.0 | |||
attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 | |||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None | |||
""" | |||
def __init__(self, | |||
dim, | |||
depth, | |||
num_heads, | |||
window_size=(1, 7, 7), | |||
mlp_ratio=4., | |||
qkv_bias=False, | |||
qk_scale=None, | |||
drop=0., | |||
attn_drop=0., | |||
drop_path=0., | |||
norm_layer=nn.LayerNorm, | |||
downsample=None, | |||
use_checkpoint=False): | |||
super().__init__() | |||
self.window_size = window_size | |||
self.shift_size = tuple(i // 2 for i in window_size) | |||
self.depth = depth | |||
self.use_checkpoint = use_checkpoint | |||
# build blocks | |||
self.blocks = nn.ModuleList([ | |||
SwinTransformerBlock3D( | |||
dim=dim, | |||
num_heads=num_heads, | |||
window_size=window_size, | |||
shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, | |||
mlp_ratio=mlp_ratio, | |||
qkv_bias=qkv_bias, | |||
qk_scale=qk_scale, | |||
drop=drop, | |||
attn_drop=attn_drop, | |||
drop_path=drop_path[i] | |||
if isinstance(drop_path, list) else drop_path, | |||
norm_layer=norm_layer, | |||
use_checkpoint=use_checkpoint, | |||
) for i in range(depth) | |||
]) | |||
self.downsample = downsample | |||
if self.downsample is not None: | |||
self.downsample = downsample(dim=dim, norm_layer=norm_layer) | |||
def forward(self, x): | |||
""" Forward function. | |||
Args: | |||
x: Input feature, tensor size (B, C, D, H, W). | |||
""" | |||
# calculate attention mask for SW-MSA | |||
B, C, D, H, W = x.shape | |||
window_size, shift_size = get_window_size((D, H, W), self.window_size, | |||
self.shift_size) | |||
x = rearrange(x, 'b c d h w -> b d h w c') | |||
Dp = int(np.ceil(D / window_size[0])) * window_size[0] | |||
Hp = int(np.ceil(H / window_size[1])) * window_size[1] | |||
Wp = int(np.ceil(W / window_size[2])) * window_size[2] | |||
attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) | |||
for blk in self.blocks: | |||
x = blk(x, attn_mask) | |||
x = x.view(B, D, H, W, -1) | |||
if self.downsample is not None: | |||
x = self.downsample(x) | |||
x = rearrange(x, 'b d h w c -> b c d h w') | |||
return x | |||
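# Illustrative sketch (not part of the original file): one stage of two blocks
# (regular + shifted window attention) followed by PatchMerging. H and W are
# halved and the channel count doubles; the temporal length is preserved.
_stage = BasicLayer(
    dim=96, depth=2, num_heads=3, window_size=(2, 7, 7),
    downsample=PatchMerging)
_stage.eval()
with torch.no_grad():
    assert _stage(torch.randn(1, 96, 4, 14, 14)).shape == (1, 192, 4, 7, 7)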
class PatchEmbed3D(nn.Module): | |||
""" Video to Patch Embedding. | |||
Args: | |||
patch_size (int): Patch token size. Default: (2,4,4). | |||
in_chans (int): Number of input video channels. Default: 3. | |||
embed_dim (int): Number of linear projection output channels. Default: 96. | |||
norm_layer (nn.Module, optional): Normalization layer. Default: None | |||
""" | |||
def __init__(self, | |||
patch_size=(2, 4, 4), | |||
in_chans=3, | |||
embed_dim=96, | |||
norm_layer=None): | |||
super().__init__() | |||
self.patch_size = patch_size | |||
self.in_chans = in_chans | |||
self.embed_dim = embed_dim | |||
self.proj = nn.Conv3d( | |||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) | |||
if norm_layer is not None: | |||
self.norm = norm_layer(embed_dim) | |||
else: | |||
self.norm = None | |||
def forward(self, x): | |||
"""Forward function.""" | |||
# padding | |||
_, _, D, H, W = x.size() | |||
if W % self.patch_size[2] != 0: | |||
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) | |||
if H % self.patch_size[1] != 0: | |||
x = F.pad(x, | |||
(0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) | |||
if D % self.patch_size[0] != 0: | |||
x = F.pad( | |||
x, | |||
(0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) | |||
x = self.proj(x) # B C D Wh Ww | |||
if self.norm is not None: | |||
D, Wh, Ww = x.size(2), x.size(3), x.size(4) | |||
x = x.flatten(2).transpose(1, 2) | |||
x = self.norm(x) | |||
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) | |||
return x | |||
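# Illustrative sketch (not part of the original file): with the default
# patch_size (2, 4, 4), a clip of 8 RGB frames at 224x224 becomes a
# (B, embed_dim, 4, 56, 56) feature volume.
_embed = PatchEmbed3D(patch_size=(2, 4, 4), in_chans=3, embed_dim=96)
assert _embed(torch.randn(1, 3, 8, 224, 224)).shape == (1, 96, 4, 56, 56)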
class SwinTransformer3D(nn.Module): | |||
""" Swin Transformer backbone. | |||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - | |||
https://arxiv.org/pdf/2103.14030 | |||
Args: | |||
patch_size (int | tuple(int)): Patch size. Default: (4,4,4). | |||
in_chans (int): Number of input image channels. Default: 3. | |||
embed_dim (int): Number of linear projection output channels. Default: 96. | |||
depths (tuple[int]): Depths of each Swin Transformer stage. | |||
        num_heads (tuple[int]): Number of attention heads of each stage. | |||
        window_size (tuple[int]): Window size. Default: (2,7,7). | |||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True | |||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. | |||
drop_rate (float): Dropout rate. | |||
attn_drop_rate (float): Attention dropout rate. Default: 0. | |||
drop_path_rate (float): Stochastic depth rate. Default: 0.2. | |||
norm_layer: Normalization layer. Default: nn.LayerNorm. | |||
patch_norm (bool): If True, add normalization after patch embedding. Default: False. | |||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode). | |||
-1 means not freezing any parameters. | |||
""" | |||
def __init__(self, | |||
pretrained=None, | |||
pretrained2d=True, | |||
patch_size=(4, 4, 4), | |||
in_chans=3, | |||
embed_dim=96, | |||
depths=[2, 2, 6, 2], | |||
num_heads=[3, 6, 12, 24], | |||
window_size=(2, 7, 7), | |||
mlp_ratio=4., | |||
qkv_bias=True, | |||
qk_scale=None, | |||
drop_rate=0., | |||
attn_drop_rate=0., | |||
drop_path_rate=0.2, | |||
norm_layer=nn.LayerNorm, | |||
patch_norm=False, | |||
frozen_stages=-1, | |||
use_checkpoint=False): | |||
super().__init__() | |||
self.pretrained = pretrained | |||
self.pretrained2d = pretrained2d | |||
self.num_layers = len(depths) | |||
self.embed_dim = embed_dim | |||
self.patch_norm = patch_norm | |||
self.frozen_stages = frozen_stages | |||
self.window_size = window_size | |||
self.patch_size = patch_size | |||
# split image into non-overlapping patches | |||
self.patch_embed = PatchEmbed3D( | |||
patch_size=patch_size, | |||
in_chans=in_chans, | |||
embed_dim=embed_dim, | |||
norm_layer=norm_layer if self.patch_norm else None) | |||
self.pos_drop = nn.Dropout(p=drop_rate) | |||
# stochastic depth | |||
dpr = [ | |||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) | |||
] # stochastic depth decay rule | |||
# build layers | |||
self.layers = nn.ModuleList() | |||
for i_layer in range(self.num_layers): | |||
layer = BasicLayer( | |||
dim=int(embed_dim * 2**i_layer), | |||
depth=depths[i_layer], | |||
num_heads=num_heads[i_layer], | |||
window_size=window_size, | |||
mlp_ratio=mlp_ratio, | |||
qkv_bias=qkv_bias, | |||
qk_scale=qk_scale, | |||
drop=drop_rate, | |||
attn_drop=attn_drop_rate, | |||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], | |||
norm_layer=norm_layer, | |||
downsample=PatchMerging | |||
if i_layer < self.num_layers - 1 else None, | |||
use_checkpoint=use_checkpoint) | |||
self.layers.append(layer) | |||
self.num_features = int(embed_dim * 2**(self.num_layers - 1)) | |||
# add a norm layer for each output | |||
self.norm = norm_layer(self.num_features) | |||
self._freeze_stages() | |||
def _freeze_stages(self): | |||
if self.frozen_stages >= 0: | |||
self.patch_embed.eval() | |||
for param in self.patch_embed.parameters(): | |||
param.requires_grad = False | |||
if self.frozen_stages >= 1: | |||
self.pos_drop.eval() | |||
for i in range(0, self.frozen_stages): | |||
m = self.layers[i] | |||
m.eval() | |||
for param in m.parameters(): | |||
param.requires_grad = False | |||
def inflate_weights(self, logger): | |||
"""Inflate the swin2d parameters to swin3d. | |||
The differences between swin3d and swin2d mainly lie in an extra | |||
axis. To utilize the pretrained parameters in 2d model, | |||
the weight of swin2d models should be inflated to fit in the shapes of | |||
the 3d counterpart. | |||
Args: | |||
logger (logging.Logger): The logger used to print | |||
                debugging information. | |||
""" | |||
checkpoint = torch.load(self.pretrained, map_location='cpu') | |||
state_dict = checkpoint['model'] | |||
# delete relative_position_index since we always re-init it | |||
relative_position_index_keys = [ | |||
k for k in state_dict.keys() if 'relative_position_index' in k | |||
] | |||
for k in relative_position_index_keys: | |||
del state_dict[k] | |||
# delete attn_mask since we always re-init it | |||
attn_mask_keys = [k for k in state_dict.keys() if 'attn_mask' in k] | |||
for k in attn_mask_keys: | |||
del state_dict[k] | |||
state_dict['patch_embed.proj.weight'] = state_dict[ | |||
'patch_embed.proj.weight'].unsqueeze(2).repeat( | |||
1, 1, self.patch_size[0], 1, 1) / self.patch_size[0] | |||
# bicubic interpolate relative_position_bias_table if not match | |||
relative_position_bias_table_keys = [ | |||
k for k in state_dict.keys() if 'relative_position_bias_table' in k | |||
] | |||
for k in relative_position_bias_table_keys: | |||
relative_position_bias_table_pretrained = state_dict[k] | |||
relative_position_bias_table_current = self.state_dict()[k] | |||
L1, nH1 = relative_position_bias_table_pretrained.size() | |||
L2, nH2 = relative_position_bias_table_current.size() | |||
L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) | |||
wd = self.window_size[0] | |||
if nH1 != nH2: | |||
logger.warning(f'Error in loading {k}, passing') | |||
else: | |||
if L1 != L2: | |||
S1 = int(L1**0.5) | |||
relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( | |||
relative_position_bias_table_pretrained.permute( | |||
1, 0).view(1, nH1, S1, S1), | |||
size=(2 * self.window_size[1] - 1, | |||
2 * self.window_size[2] - 1), | |||
mode='bicubic') | |||
relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view( | |||
nH2, L2).permute(1, 0) | |||
state_dict[k] = relative_position_bias_table_pretrained.repeat( | |||
2 * wd - 1, 1) | |||
msg = self.load_state_dict(state_dict, strict=False) | |||
logger.info(msg) | |||
logger.info(f"=> loaded successfully '{self.pretrained}'") | |||
del checkpoint | |||
torch.cuda.empty_cache() | |||
def forward(self, x): | |||
"""Forward function.""" | |||
x = self.patch_embed(x) | |||
x = self.pos_drop(x) | |||
for layer in self.layers: | |||
x = layer(x.contiguous()) | |||
x = rearrange(x, 'n c d h w -> n d h w c') | |||
x = self.norm(x) | |||
x = rearrange(x, 'n d h w c -> n c d h w') | |||
return x | |||
def train(self, mode=True): | |||
"""Convert the model into training mode while keep layers freezed.""" | |||
super(SwinTransformer3D, self).train(mode) | |||
self._freeze_stages() |
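# Illustrative sketch (not part of the original file): running the backbone on
# a short dummy clip with a Swin-T style configuration. These hyperparameters
# are illustrative; the values actually used by the pipeline come from the
# model's configuration file. With patch_size (4, 4, 4) and four stages the
# output is (B, 8 * embed_dim, T / 4, H / 32, W / 32).
_backbone = SwinTransformer3D(
    patch_size=(4, 4, 4), embed_dim=96, depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24], window_size=(2, 7, 7))
_backbone.eval()
with torch.no_grad():
    _feat = _backbone(torch.randn(1, 3, 8, 224, 224))
assert _feat.shape == (1, 768, 2, 7, 7)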
@@ -417,6 +417,12 @@ TASK_OUTPUTS = { | |||
# } | |||
Tasks.video_summarization: [OutputKeys.OUTPUT], | |||
# referring video object segmentation result for a single video | |||
# { | |||
# "masks": [np.array # 2D array with shape [height, width]] | |||
# } | |||
Tasks.referring_video_object_segmentation: [OutputKeys.MASKS], | |||
# ============ nlp tasks =================== | |||
# text classification result for single sample | |||
@@ -202,6 +202,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'), | |||
Tasks.product_segmentation: (Pipelines.product_segmentation, | |||
'damo/cv_F3Net_product-segmentation'), | |||
Tasks.referring_video_object_segmentation: | |||
(Pipelines.referring_video_object_segmentation, | |||
'damo/cv_swin-t_referring_video-object-segmentation'), | |||
} | |||
@@ -58,6 +58,7 @@ if TYPE_CHECKING: | |||
from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline | |||
    from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline | |||
from .hand_static_pipeline import HandStaticPipeline | |||
from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline | |||
else: | |||
_import_structure = { | |||
@@ -128,6 +129,9 @@ else: | |||
['FacialExpressionRecognitionPipeline'], | |||
'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'], | |||
'hand_static_pipeline': ['HandStaticPipeline'], | |||
'referring_video_object_segmentation_pipeline': [ | |||
'ReferringVideoObjectSegmentationPipeline' | |||
], | |||
} | |||
import sys | |||
@@ -0,0 +1,193 @@ | |||
# The implementation here is modified based on MTTR, | |||
# originally Apache 2.0 License and publicly available at https://github.com/mttr2021/MTTR | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
import numpy as np | |||
import torch | |||
import torchvision | |||
import torchvision.transforms.functional as F | |||
from einops import rearrange | |||
from moviepy.editor import AudioFileClip, ImageSequenceClip, VideoFileClip | |||
from PIL import Image, ImageDraw, ImageFont, ImageOps | |||
from tqdm import tqdm | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.referring_video_object_segmentation, | |||
module_name=Pipelines.referring_video_object_segmentation) | |||
class ReferringVideoObjectSegmentationPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
"""use `model` to create a referring video object segmentation pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub | |||
""" | |||
_device = kwargs.pop('device', 'gpu') | |||
if torch.cuda.is_available() and _device == 'gpu': | |||
self.device = 'gpu' | |||
else: | |||
self.device = 'cpu' | |||
super().__init__(model=model, device=self.device, **kwargs) | |||
        logger.info('model loaded successfully.') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
""" | |||
Args: | |||
input: path of the input video | |||
""" | |||
        assert isinstance(input, tuple) and len( | |||
            input | |||
        ) == 4, 'error - input must be a tuple of (video_path, text_queries, start_pt, end_pt)' | |||
self.input_video_pth, text_queries, start_pt, end_pt = input | |||
        assert 0 < end_pt - start_pt <= 10, 'error - the subclip length must be between 0 and 10 seconds' | |||
assert 1 <= len( | |||
text_queries) <= 2, 'error - 1-2 input text queries are expected' | |||
# extract the relevant subclip: | |||
self.input_clip_pth = 'input_clip.mp4' | |||
with VideoFileClip(self.input_video_pth) as video: | |||
subclip = video.subclip(start_pt, end_pt) | |||
subclip.write_videofile(self.input_clip_pth) | |||
self.window_length = 24 # length of window during inference | |||
self.window_overlap = 6 # overlap (in frames) between consecutive windows | |||
self.video, audio, self.meta = torchvision.io.read_video( | |||
filename=self.input_clip_pth) | |||
self.video = rearrange(self.video, 't h w c -> t c h w') | |||
input_video = F.resize(self.video, size=360, max_size=640) | |||
if self.device_name == 'gpu': | |||
input_video = input_video.cuda() | |||
input_video = input_video.to(torch.float).div_(255) | |||
input_video = F.normalize( | |||
input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |||
video_metadata = { | |||
'resized_frame_size': input_video.shape[-2:], | |||
'original_frame_size': self.video.shape[-2:] | |||
} | |||
# partition the clip into overlapping windows of frames: | |||
windows = [ | |||
input_video[i:i + self.window_length] | |||
for i in range(0, len(input_video), self.window_length | |||
- self.window_overlap) | |||
] | |||
# clean up the text queries: | |||
self.text_queries = [' '.join(q.lower().split()) for q in text_queries] | |||
result = { | |||
'text_queries': self.text_queries, | |||
'windows': windows, | |||
'video_metadata': video_metadata | |||
} | |||
return result | |||
def forward(self, input: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
with torch.no_grad(): | |||
pred_masks_per_query = [] | |||
t, _, h, w = self.video.shape | |||
for text_query in tqdm(input['text_queries'], desc='text queries'): | |||
pred_masks = torch.zeros(size=(t, 1, h, w)) | |||
for i, window in enumerate( | |||
tqdm(input['windows'], desc='windows')): | |||
window_masks = self.model.inference( | |||
window=window, | |||
text_query=text_query, | |||
metadata=input['video_metadata']) | |||
win_start_idx = i * ( | |||
self.window_length - self.window_overlap) | |||
pred_masks[win_start_idx:win_start_idx | |||
+ self.window_length] = window_masks | |||
pred_masks_per_query.append(pred_masks) | |||
return pred_masks_per_query | |||
def postprocess(self, inputs) -> Dict[str, Any]: | |||
if self.model.cfg.pipeline.save_masked_video: | |||
# RGB colors for instance masks: | |||
light_blue = (41, 171, 226) | |||
purple = (237, 30, 121) | |||
dark_green = (35, 161, 90) | |||
orange = (255, 148, 59) | |||
colors = np.array([light_blue, purple, dark_green, orange]) | |||
# width (in pixels) of the black strip above the video on which the text queries will be displayed: | |||
text_border_height_per_query = 36 | |||
video_np = rearrange(self.video, | |||
't c h w -> t h w c').numpy() / 255.0 | |||
# del video | |||
pred_masks_per_frame = rearrange( | |||
torch.stack(inputs), 'q t 1 h w -> t q h w').numpy() | |||
masked_video = [] | |||
for vid_frame, frame_masks in tqdm( | |||
zip(video_np, pred_masks_per_frame), | |||
total=len(video_np), | |||
desc='applying masks...'): | |||
# apply the masks: | |||
for inst_mask, color in zip(frame_masks, colors): | |||
vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0) | |||
vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8)) | |||
# visualize the text queries: | |||
vid_frame = ImageOps.expand( | |||
vid_frame, | |||
border=(0, len(self.text_queries) | |||
* text_border_height_per_query, 0, 0)) | |||
W, H = vid_frame.size | |||
draw = ImageDraw.Draw(vid_frame) | |||
font = ImageFont.truetype(font='DejaVuSansMono.ttf', size=30) | |||
for i, (text_query, color) in enumerate( | |||
zip(self.text_queries, colors), start=1): | |||
w, h = draw.textsize(text_query, font=font) | |||
draw.text(((W - w) / 2, | |||
(text_border_height_per_query * i) - h - 3), | |||
text_query, | |||
fill=tuple(color) + (255, ), | |||
font=font) | |||
masked_video.append(np.array(vid_frame)) | |||
# generate and save the output clip: | |||
assert self.model.cfg.pipeline.output_path | |||
output_clip_path = self.model.cfg.pipeline.output_path | |||
clip = ImageSequenceClip( | |||
sequence=masked_video, fps=self.meta['video_fps']) | |||
clip = clip.set_audio(AudioFileClip(self.input_clip_pth)) | |||
clip.write_videofile( | |||
output_clip_path, fps=self.meta['video_fps'], audio=True) | |||
del masked_video | |||
result = {OutputKeys.MASKS: inputs} | |||
return result | |||
def apply_mask(image, mask, color, transparency=0.7): | |||
mask = mask[..., np.newaxis].repeat(repeats=3, axis=2) | |||
mask = mask * transparency | |||
    color_matrix = np.ones(image.shape, dtype=np.float64) * color | |||
out_image = color_matrix * mask + image * (1.0 - mask) | |||
return out_image |
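# Illustrative sketch (not part of the original file): how the pipeline tiles a
# clip into overlapping windows. The frame count below is made up;
# window_length=24 and window_overlap=6 match the values hard-coded in
# preprocess(). In forward(), later windows overwrite the overlapping frames of
# the previous window when the predicted masks are written back.
def _window_layout(num_frames, window_length=24, window_overlap=6):
    stride = window_length - window_overlap
    return [(start, min(start + window_length, num_frames))
            for start in range(0, num_frames, stride)]

print(_window_layout(60))  # [(0, 24), (18, 42), (36, 60), (54, 60)]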
@@ -80,6 +80,9 @@ class CVTasks(object): | |||
virtual_try_on = 'virtual-try-on' | |||
movie_scene_segmentation = 'movie-scene-segmentation' | |||
# video segmentation | |||
referring_video_object_segmentation = 'referring-video-object-segmentation' | |||
# video editing | |||
video_inpainting = 'video-inpainting' | |||
@@ -1,4 +1,5 @@ | |||
albumentations>=1.0.3 | |||
av>=9.2.0 | |||
easydict | |||
fairscale>=0.4.1 | |||
fastai>=1.0.51 | |||
@@ -14,6 +15,7 @@ lpips | |||
ml_collections | |||
mmcls>=0.21.0 | |||
mmdet>=2.25.0 | |||
moviepy>=1.0.3 | |||
networkx>=2.5 | |||
numba | |||
onnxruntime>=1.10 | |||
@@ -0,0 +1,56 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
from modelscope.utils.test_utils import test_level | |||
class ReferringVideoObjectSegmentationTest(unittest.TestCase, | |||
DemoCompatibilityCheck): | |||
def setUp(self) -> None: | |||
self.task = Tasks.referring_video_object_segmentation | |||
self.model_id = 'damo/cv_swin-t_referring_video-object-segmentation' | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_referring_video_object_segmentation(self): | |||
input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4' | |||
text_queries = [ | |||
'guy in black performing tricks on a bike', | |||
'a black bike used to perform tricks' | |||
] | |||
start_pt, end_pt = 4, 14 | |||
input_tuple = (input_location, text_queries, start_pt, end_pt) | |||
pp = pipeline( | |||
Tasks.referring_video_object_segmentation, model=self.model_id) | |||
result = pp(input_tuple) | |||
if result: | |||
print(result) | |||
else: | |||
raise ValueError('process error') | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_referring_video_object_segmentation_with_default_task(self): | |||
input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4' | |||
text_queries = [ | |||
'guy in black performing tricks on a bike', | |||
'a black bike used to perform tricks' | |||
] | |||
start_pt, end_pt = 4, 14 | |||
input_tuple = (input_location, text_queries, start_pt, end_pt) | |||
pp = pipeline(Tasks.referring_video_object_segmentation) | |||
result = pp(input_tuple) | |||
if result: | |||
print(result) | |||
else: | |||
raise ValueError('process error') | |||
@unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
def test_demo_compatibility(self): | |||
self.compatibility_check() | |||
if __name__ == '__main__': | |||
unittest.main() |