add video human matting task code
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10839854
master^2
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8e4ade7a6b119e20e82a641246199b4b530759166acc1f813d7cefee65b3e1e0
+size 63944943
@@ -52,6 +52,7 @@ class Models(object):
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'
    image_body_reshaping = 'image-body-reshaping'
+    video_human_matting = 'video-human-matting'

    # EasyCV models
    yolox = 'YOLOX'
@@ -230,6 +231,7 @@ class Pipelines(object):
    product_segmentation = 'product-segmentation'
    image_body_reshaping = 'flow-based-body-reshaping'
    referring_video_object_segmentation = 'referring-video-object-segmentation'
+    video_human_matting = 'video-human-matting'

    # nlp tasks
    automatic_post_editing = 'automatic-post-editing'
@@ -0,0 +1,21 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .model import VideoMattingNetwork
    from .model import preprocess
else:
    _import_structure = {'model': ['VideoMattingNetwork', 'preprocess']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp

import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.video_human_matting.models import MattingNetwork
from modelscope.utils.constant import ModelFile, Tasks


@MODELS.register_module(
    Tasks.video_human_matting, module_name=Models.video_human_matting)
class VideoMattingNetwork(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        params = torch.load(model_path, map_location='cpu')
        self.model = MattingNetwork()
        if 'model_state_dict' in params:
            params = params['model_state_dict']
        self.model.load_state_dict(params, strict=True)
        self.model.eval()


def preprocess(image):
    # HWC uint8 image in [0, 255] -> [1, C, H, W] float tensor in [0, 1]
    frame_np = np.float32(image) / 255.0
    frame_np = frame_np.transpose(2, 0, 1)
    frame_tensor = torch.from_numpy(frame_np)
    image_tensor = frame_tensor[None, :, :, :]
    return image_tensor
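For reviewers, a quick sketch of the `preprocess` contract defined above (the frame size is an arbitrary assumption; in the pipeline the frames come from `cv2.VideoCapture`):

```python
import numpy as np

# one HWC BGR frame, as cv2 would return it (size chosen for illustration)
frame = np.zeros((480, 640, 3), dtype=np.uint8)
src = preprocess(frame)
print(src.shape, src.dtype)  # torch.Size([1, 3, 480, 640]) torch.float32
```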
@@ -0,0 +1 @@
from .matting import MattingNetwork
@@ -0,0 +1,330 @@
"""
Part of the implementation is borrowed and modified from RVM,
publicly available at <https://arxiv.org/abs/2108.11515>
"""
from typing import Optional

import torch
from torch import Tensor, nn


class hswish(nn.Module):

    def __init__(self):
        super().__init__()
        # instantiate the activation once instead of building a new
        # Hardswish module on every forward call
        self.act = nn.Hardswish(inplace=True)

    def forward(self, x):
        return self.act(x)


class scSEblock(nn.Module):
    """Channel attention block: scales each channel by a weight
    predicted from its globally average-pooled statistics."""

    def __init__(self, out):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(out, int(out / 2), 3, 1, 1),
            nn.GroupNorm(out // 8, int(out / 2)), hswish())
        self.conv2 = nn.Sequential(
            nn.Conv2d(int(out / 2), out, 1, 1, 0),
            nn.GroupNorm(out // 4, out),
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward_single(self, x):
        b, c, _, _ = x.size()
        x2 = self.avgpool(x).view(b, c, 1, 1)
        x2 = self.conv1(x2)
        x2 = self.conv2(x2)
        x2 = torch.sigmoid(x2)
        out = x2 * x
        return out

    def forward_time(self, x):
        # fold the time axis into the batch axis, then restore it
        B, T, _, H, W = x.shape
        x = x.flatten(0, 1)
        out = self.forward_single(x)
        out = out.unflatten(0, (B, T))
        return out

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time(x)
        else:
            return self.forward_single(x)


class RecurrentDecoder(nn.Module):

    def __init__(self, feature_channels, decoder_channels):
        super().__init__()
        self.avgpool = AvgPool()
        self.decode4 = BottleneckBlock(feature_channels[3])
        self.decode3 = UpsamplingBlock(feature_channels[3],
                                       feature_channels[2], 3,
                                       decoder_channels[0])
        self.sc3 = scSEblock(decoder_channels[0])
        self.decode2 = UpsamplingBlock(decoder_channels[0],
                                       feature_channels[1], 3,
                                       decoder_channels[1])
        self.sc2 = scSEblock(decoder_channels[1])
        self.decode1 = UpsamplingBlock(decoder_channels[1],
                                       feature_channels[0], 3,
                                       decoder_channels[2])
        self.sc1 = scSEblock(decoder_channels[2])
        self.out0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
        self.crosslevel1 = crossfeature(feature_channels[3],
                                        feature_channels[1])
        self.crosslevel2 = crossfeature(feature_channels[2],
                                        feature_channels[0])

    def forward(self, s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor,
                f4: Tensor, r1: Optional[Tensor], r2: Optional[Tensor],
                r3: Optional[Tensor], r4: Optional[Tensor]):
        # r1..r4 are the per-level ConvGRU states carried across frames
        s2, s3, s4 = self.avgpool(s0)
        x4, r4 = self.decode4(f4, r4)
        x3, r3 = self.decode3(x4, f3, s4, r3)
        x3 = self.sc3(x3)
        f2 = self.crosslevel1(f4, f2)
        x2, r2 = self.decode2(x3, f2, s3, r2)
        x2 = self.sc2(x2)
        f1 = self.crosslevel2(f3, f1)
        x1, r1 = self.decode1(x2, f1, s2, r1)
        x1 = self.sc1(x1)
        out = self.out0(x1, s0)
        return out, r1, r2, r3, r4


class AvgPool(nn.Module):

    def __init__(self):
        super().__init__()
        self.avgpool = nn.AvgPool2d(
            2, 2, count_include_pad=False, ceil_mode=True)

    def forward_single_frame(self, s0):
        s1 = self.avgpool(s0)
        s2 = self.avgpool(s1)
        s3 = self.avgpool(s2)
        return s1, s2, s3

    def forward_time_series(self, s0):
        B, T = s0.shape[:2]
        s0 = s0.flatten(0, 1)
        s1, s2, s3 = self.forward_single_frame(s0)
        s1 = s1.unflatten(0, (B, T))
        s2 = s2.unflatten(0, (B, T))
        s3 = s3.unflatten(0, (B, T))
        return s1, s2, s3

    def forward(self, s0):
        if s0.ndim == 5:
            return self.forward_time_series(s0)
        else:
            return self.forward_single_frame(s0)


class crossfeature(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)

    def forward_single_frame(self, x1, x2):
        b, c, _, _ = x1.size()
        x1 = self.avg(x1).view(b, c, 1, 1)
        x1 = self.conv(x1)
        x1 = torch.sigmoid(x1)
        x2 = x1 * x2
        return x2

    def forward_time_series(self, x1, x2):
        b, t = x1.shape[:2]
        x1 = x1.flatten(0, 1)
        x2 = x2.flatten(0, 1)
        x2 = self.forward_single_frame(x1, x2)
        return x2.unflatten(0, (b, t))

    def forward(self, x1, x2):
        if x1.ndim == 5:
            return self.forward_time_series(x1, x2)
        else:
            return self.forward_single_frame(x1, x2)


class BottleneckBlock(nn.Module):

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.gru = GRU(channels // 2)

    def forward(self, x, r):
        # only half of the channels pass through the GRU (RVM-style split)
        a, b = x.split(self.channels // 2, dim=-3)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=-3)
        return x, r


class UpsamplingBlock(nn.Module):

    def __init__(self, in_channels, skip_channels, src_channels, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.upsample = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False)
        self.shortcut = nn.Sequential(
            nn.Conv2d(skip_channels, in_channels, 3, 1, 1, bias=False),
            nn.GroupNorm(in_channels // 4, in_channels), hswish())
        self.att_skip = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, 1, 0, bias=False),
            nn.Sigmoid())
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels + in_channels + src_channels,
                out_channels,
                3,
                1,
                1,
                bias=False),
            nn.GroupNorm(out_channels // 4, out_channels),
            hswish(),
        )
        self.gru = GRU(out_channels // 2)

    def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
        x = self.upsample(x)
        x = x[:, :, :s.size(2), :s.size(3)]
        att = self.att_skip(x)
        f = self.shortcut(f)
        f = att * f
        x = torch.cat([x, f, s], dim=1)
        x = self.conv(x)
        a, b = x.split(self.out_channels // 2, dim=1)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=1)
        return x, r

    def forward_time_series(self, x, f, s, r: Optional[Tensor]):
        B, T, _, H, W = s.shape
        x = x.flatten(0, 1)
        f = f.flatten(0, 1)
        s = s.flatten(0, 1)
        x = self.upsample(x)
        x = x[:, :, :H, :W]
        f = self.shortcut(f)
        att = self.att_skip(x)
        f = att * f
        x = torch.cat([x, f, s], dim=1)
        x = self.conv(x)
        x = x.unflatten(0, (B, T))
        a, b = x.split(self.out_channels // 2, dim=2)
        b, r = self.gru(b, r)
        x = torch.cat([a, b], dim=2)
        return x, r

    def forward(self, x, f, s, r: Optional[Tensor]):
        if x.ndim == 5:
            return self.forward_time_series(x, f, s, r)
        else:
            return self.forward_single_frame(x, f, s, r)


class OutputBlock(nn.Module):

    def __init__(self, in_channels, src_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels + src_channels, out_channels, 3, 1, 1, bias=False),
            nn.GroupNorm(out_channels // 2, out_channels),
            hswish(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.GroupNorm(out_channels // 2, out_channels),
            hswish(),
        )

    def forward_single_frame(self, x, s):
        x = self.upsample(x)
        x = x[:, :, :s.size(2), :s.size(3)]
        x = torch.cat([x, s], dim=1)
        x = self.conv(x)
        return x

    def forward_time_series(self, x, s):
        B, T, _, H, W = s.shape
        x = x.flatten(0, 1)
        s = s.flatten(0, 1)
        x = self.upsample(x)
        x = x[:, :, :H, :W]
        x = torch.cat([x, s], dim=1)
        x = self.conv(x)
        x = x.unflatten(0, (B, T))
        return x

    def forward(self, x, s):
        if x.ndim == 5:
            return self.forward_time_series(x, s)
        else:
            return self.forward_single_frame(x, s)


class Projection(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward_single_frame(self, x):
        return self.conv(x)

    def forward_time_series(self, x):
        B, T = x.shape[:2]
        return self.conv(x.flatten(0, 1)).unflatten(0, (B, T))

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)


class GRU(nn.Module):

    def __init__(self, channels, kernel_size=3, padding=1):
        super().__init__()
        self.channels = channels
        self.ih = nn.Conv2d(
            channels * 2, channels * 2, kernel_size, padding=padding)
        self.act_ih = nn.Sigmoid()
        self.hh = nn.Conv2d(
            channels * 2, channels, kernel_size, padding=padding)
        self.act_hh = nn.Tanh()

    def forward_single_frame(self, x, pre_fea):
        fea_ih = self.ih(torch.cat([x, pre_fea], dim=1))
        r, z = self.act_ih(fea_ih).split(self.channels, dim=1)
        fea_hh = self.hh(torch.cat([x, r * pre_fea], dim=1))
        c = self.act_hh(fea_hh)
        fea_gru = (1 - z) * pre_fea + z * c
        return fea_gru, fea_gru

    def forward_time_series(self, x, pre_fea):
        o = []
        for xt in x.unbind(dim=1):
            ot, pre_fea = self.forward_single_frame(xt, pre_fea)
            o.append(ot)
        o = torch.stack(o, dim=1)
        return o, pre_fea

    def forward(self, x, pre_fea):
        # a missing state (first frame) is initialized to zeros
        if pre_fea is None:
            pre_fea = torch.zeros(
                (x.size(0), x.size(-3), x.size(-2), x.size(-1)),
                device=x.device,
                dtype=x.dtype)
        if x.ndim == 5:
            return self.forward_time_series(x, pre_fea)
        else:
            return self.forward_single_frame(x, pre_fea)
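As a sanity check on the recurrence above, a minimal sketch (channel and spatial sizes are arbitrary assumptions): on a 5D clip the ConvGRU unrolls over the time axis and returns the last state for the next call.

```python
import torch

gru = GRU(channels=8)
clip = torch.randn(2, 5, 8, 16, 16)  # [B, T, C, H, W]
out, state = gru(clip, None)         # None -> zero-initialized state
print(out.shape)    # torch.Size([2, 5, 8, 16, 16])
print(state.shape)  # torch.Size([2, 8, 16, 16]), carried into the next clip
```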
@@ -0,0 +1,64 @@
"""
Part of the implementation is borrowed and modified from DeepGuidedFilter,
publicly available at <https://github.com/wuhuikai/DeepGuidedFilter/>
"""
import torch
from torch import nn
from torch.nn import functional as F


class DeepGuidedFilterRefiner(nn.Module):

    def __init__(self, hid_channels=16):
        super().__init__()
        self.box_filter = nn.Conv2d(
            4, 4, kernel_size=3, padding=1, bias=False, groups=4)
        self.box_filter.weight.data[...] = 1 / 9
        self.conv = nn.Sequential(
            nn.Conv2d(
                4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels), nn.ReLU(True),
            nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels), nn.ReLU(True),
            nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True))

    def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha,
                             base_hid):
        fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
        base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
        base_y = torch.cat([base_fgr, base_pha], dim=1)
        mean_x = self.box_filter(base_x)
        mean_y = self.box_filter(base_y)
        cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
        var_x = self.box_filter(base_x * base_x) - mean_x * mean_x
        A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
        b = mean_y - A * mean_x
        H, W = fine_src.shape[2:]
        A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
        b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)
        out = A * fine_x + b
        fgr, pha = out.split([3, 1], dim=1)
        return fgr, pha

    def forward_time_series(self, fine_src, base_src, base_fgr, base_pha,
                            base_hid):
        B, T = fine_src.shape[:2]
        fgr, pha = self.forward_single_frame(
            fine_src.flatten(0, 1), base_src.flatten(0, 1),
            base_fgr.flatten(0, 1), base_pha.flatten(0, 1),
            base_hid.flatten(0, 1))
        fgr = fgr.unflatten(0, (B, T))
        pha = pha.unflatten(0, (B, T))
        return fgr, pha

    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        if fine_src.ndim == 5:
            return self.forward_time_series(fine_src, base_src, base_fgr,
                                            base_pha, base_hid)
        else:
            return self.forward_single_frame(fine_src, base_src, base_fgr,
                                             base_pha, base_hid)
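A shape-level sketch of how this refiner is driven (random stand-in tensors; in `MattingNetwork` below, `base_hid` is the 16-channel decoder output at the downsampled resolution): the affine coefficients `A` and `b` are estimated at the low resolution and applied at the full resolution.

```python
import torch

refiner = DeepGuidedFilterRefiner(hid_channels=16)
fine_src = torch.randn(1, 3, 512, 512)   # full-resolution source frame
base_src = torch.randn(1, 3, 256, 256)   # downsampled source
base_fgr = torch.randn(1, 3, 256, 256)   # low-res foreground prediction
base_pha = torch.randn(1, 1, 256, 256)   # low-res alpha
base_hid = torch.randn(1, 16, 256, 256)  # low-res decoder hidden features
fgr, pha = refiner(fine_src, base_src, base_fgr, base_pha, base_hid)
print(pha.shape)  # torch.Size([1, 1, 512, 512])
```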
@@ -0,0 +1,177 @@
"""
Part of the implementation is borrowed and modified from EfficientNetV2,
publicly available at <https://arxiv.org/abs/2104.00298>
"""
import torch


class SiLU(torch.nn.Module):
    """
    Export-friendly wrapper around torch.nn.SiLU
    [https://arxiv.org/pdf/1710.05941.pdf]
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.silu = torch.nn.SiLU(inplace=inplace)

    def forward(self, x):
        return self.silu(x)


class Conv(torch.nn.Module):

    def __init__(self, in_ch, out_ch, activation, k=1, s=1, g=1):
        super().__init__()
        # Conv2d(in, out, kernel, stride, padding, dilation, groups)
        self.conv = torch.nn.Conv2d(
            in_ch, out_ch, k, s, k // 2, 1, g, bias=False)
        self.norm = torch.nn.BatchNorm2d(out_ch, 0.001, 0.01)
        self.silu = activation

    def forward(self, x):
        return self.silu(self.norm(self.conv(x)))


class SE(torch.nn.Module):
    """
    Squeeze-and-excitation block
    [https://arxiv.org/pdf/1709.01507.pdf]
    """

    def __init__(self, ch, r):
        super().__init__()
        self.se = torch.nn.Sequential(
            torch.nn.Conv2d(ch, ch // (4 * r), 1), torch.nn.SiLU(),
            torch.nn.Conv2d(ch // (4 * r), ch, 1), torch.nn.Sigmoid())

    def forward(self, x):
        return x * self.se(x.mean((2, 3), keepdim=True))


class Residual(torch.nn.Module):
    """
    Fused-MBConv / MBConv residual block
    [https://arxiv.org/pdf/1801.04381.pdf]
    """

    def __init__(self, in_ch, out_ch, s, r, fused=True):
        super().__init__()
        identity = torch.nn.Identity()
        if fused:
            if r == 1:
                features = [Conv(in_ch, r * in_ch, torch.nn.SiLU(), 3, s)]
            else:
                features = [
                    Conv(in_ch, r * in_ch, torch.nn.SiLU(), 3, s),
                    Conv(r * in_ch, out_ch, identity)
                ]
        else:
            if r == 1:
                features = [
                    Conv(r * in_ch, r * in_ch, torch.nn.SiLU(), 3, s,
                         r * in_ch),
                    SE(r * in_ch, r),
                    Conv(r * in_ch, out_ch, identity)
                ]
            else:
                features = [
                    Conv(in_ch, r * in_ch, torch.nn.SiLU()),
                    Conv(r * in_ch, r * in_ch, torch.nn.SiLU(), 3, s,
                         r * in_ch),
                    SE(r * in_ch, r),
                    Conv(r * in_ch, out_ch, identity)
                ]
        self.add = s == 1 and in_ch == out_ch
        self.res = torch.nn.Sequential(*features)

    def forward(self, x):
        return x + self.res(x) if self.add else self.res(x)


class EfficientNet(torch.nn.Module):

    def __init__(self, pretrained: bool = False):
        super().__init__()
        gate_fn = [True, False]
        filters = [24, 48, 64, 128, 160, 256]
        feature = [Conv(3, filters[0], torch.nn.SiLU(), 3, 2)]
        for i in range(2):
            feature.append(
                Residual(filters[0], filters[0], 1, 1, gate_fn[0]))
        for i in range(4):
            if i == 0:
                feature.append(
                    Residual(filters[0], filters[1], 2, 4, gate_fn[0]))
            else:
                feature.append(
                    Residual(filters[1], filters[1], 1, 4, gate_fn[0]))
        for i in range(4):
            if i == 0:
                feature.append(
                    Residual(filters[1], filters[2], 2, 4, gate_fn[0]))
            else:
                feature.append(
                    Residual(filters[2], filters[2], 1, 4, gate_fn[0]))
        for i in range(6):
            if i == 0:
                feature.append(
                    Residual(filters[2], filters[3], 2, 4, gate_fn[1]))
            else:
                feature.append(
                    Residual(filters[3], filters[3], 1, 4, gate_fn[1]))
        for i in range(9):
            if i == 0:
                feature.append(
                    Residual(filters[3], filters[4], 1, 6, gate_fn[1]))
            else:
                feature.append(
                    Residual(filters[4], filters[4], 1, 6, gate_fn[1]))
        self.feature = torch.nn.Sequential(*feature)

    def forward_single_frame(self, x):
        x = self.feature[0](x)
        x = self.feature[1](x)
        x = self.feature[2](x)
        f1 = x  # 1/2 resolution, 24 channels
        for i in range(4):
            x = self.feature[i + 3](x)
        f2 = x  # 1/4 resolution, 48 channels
        for i in range(4):
            x = self.feature[i + 7](x)
        f3 = x  # 1/8 resolution, 64 channels
        for i in range(6):
            x = self.feature[i + 11](x)
        for i in range(9):
            x = self.feature[i + 17](x)
        f5 = x  # 1/16 resolution, 160 channels
        return [f1, f2, f3, f5]

    def forward_time_series(self, x):
        B, T = x.shape[:2]
        features = self.forward_single_frame(x.flatten(0, 1))
        features = [f.unflatten(0, (B, T)) for f in features]
        return features

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)

    def export(self):
        # replace inplace nn.SiLU modules with the export-friendly wrapper
        for m in self.modules():
            if type(m) is Conv and hasattr(m, 'silu'):
                if isinstance(m.silu, torch.nn.SiLU):
                    m.silu = SiLU()
            if type(m) is SE:
                if isinstance(m.se[1], torch.nn.SiLU):
                    m.se[1] = SiLU()
        return self
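A quick check of the feature pyramid this backbone returns (the input size is an arbitrary assumption):

```python
import torch

backbone = EfficientNet()
f1, f2, f3, f5 = backbone(torch.randn(1, 3, 256, 256))
for f in (f1, f2, f3, f5):
    print(f.shape)
# strides 2/4/8/16: [1, 24, 128, 128], [1, 48, 64, 64],
# [1, 64, 32, 32], [1, 160, 16, 16]
```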
@@ -0,0 +1,94 @@
"""
Part of the implementation is borrowed and modified from DeepLab v3,
publicly available at <https://arxiv.org/abs/1706.05587v3>
"""
import torch
from torch import nn


class ASP_OC_Module(nn.Module):

    def __init__(self, features, out_features=96, dilations=(2, 4, 8)):
        super(ASP_OC_Module, self).__init__()
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                features,
                out_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False), nn.BatchNorm2d(out_features))
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                features,
                out_features,
                kernel_size=3,
                padding=dilations[0],
                dilation=dilations[0],
                bias=False), nn.BatchNorm2d(out_features))
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                features,
                out_features,
                kernel_size=3,
                padding=dilations[1],
                dilation=dilations[1],
                bias=False), nn.BatchNorm2d(out_features))
        self.conv5 = nn.Sequential(
            nn.Conv2d(
                features,
                out_features,
                kernel_size=3,
                padding=dilations[2],
                dilation=dilations[2],
                bias=False), nn.BatchNorm2d(out_features))
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(
                out_features * 4,
                out_features * 2,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False), nn.InstanceNorm2d(out_features * 2),
            nn.Dropout2d(0.05))

    def _cat_each(self, feat1, feat2, feat3, feat4, feat5):
        assert len(feat1) == len(feat2)
        z = []
        for i in range(len(feat1)):
            z.append(
                torch.cat((feat1[i], feat2[i], feat3[i], feat4[i], feat5[i]),
                          1))
        return z

    def forward(self, x):
        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)
        out = torch.cat((feat2, feat3, feat4, feat5), 1)
        output = self.conv_bn_dropout(out)
        return output


class LRASPP(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.aspp = ASP_OC_Module(in_channels, out_channels)

    def forward_single_frame(self, x):
        return self.aspp(x)

    def forward_time_series(self, x):
        B, T = x.shape[:2]
        x = self.forward_single_frame(x.flatten(0, 1)).unflatten(0, (B, T))
        return x

    def forward(self, x):
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)
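One point worth calling out for review: `conv_bn_dropout` emits `out_features * 2` channels, so `LRASPP(160, 64)` produces 128 channels, which matches the `feature_channels[3] = 128` that the decoder in `MattingNetwork` below expects. A sketch:

```python
import torch

aspp = LRASPP(160, 64)                   # wraps ASP_OC_Module(160, 64)
out = aspp(torch.randn(1, 160, 16, 16))  # e.g. the 1/16-resolution f4 map
print(out.shape)                         # torch.Size([1, 128, 16, 16])
```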
@@ -0,0 +1,67 @@
from typing import Optional

import torch
from torch import Tensor
from torch.nn import functional as F

from .decoder import Projection, RecurrentDecoder
from .deep_guided_filter import DeepGuidedFilterRefiner
from .effv2 import EfficientNet
from .lraspp import LRASPP


class MattingNetwork(torch.nn.Module):

    def __init__(self, pretrained_backbone: bool = False):
        super().__init__()
        self.backbone = EfficientNet(pretrained_backbone)
        self.aspp = LRASPP(160, 64)
        self.decoder = RecurrentDecoder([24, 48, 64, 128], [64, 32, 24, 16])
        self.project_mat = Projection(16, 4)
        self.project_seg = Projection(16, 1)
        self.refiner = DeepGuidedFilterRefiner()

    def forward(self,
                src: Tensor,
                r0: Optional[Tensor] = None,
                r1: Optional[Tensor] = None,
                r2: Optional[Tensor] = None,
                r3: Optional[Tensor] = None,
                downsample_ratio: float = 1,
                segmentation_pass: bool = False):
        # run the encoder-decoder at a downsampled resolution, then refine
        # the alpha back to the source resolution with the guided filter
        if downsample_ratio != 1:
            src_sm = self._interpolate(src, scale_factor=downsample_ratio)
        else:
            src_sm = src
        f1, f2, f3, f4 = self.backbone(src_sm)
        f4 = self.aspp(f4)
        hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r0, r1, r2, r3)
        if not segmentation_pass:
            fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
            if downsample_ratio != 1:
                _, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
            pha = pha.clamp(0., 1.)
            return [pha, *rec]
        else:
            seg = self.project_seg(hid)
            return [seg, *rec]

    def _interpolate(self, x: Tensor, scale_factor: float):
        if x.ndim == 5:
            B, T = x.shape[:2]
            x = F.interpolate(
                x.flatten(0, 1),
                scale_factor=scale_factor,
                mode='bilinear',
                align_corners=False)
            x = x.unflatten(0, (B, T))
        else:
            x = F.interpolate(
                x,
                scale_factor=scale_factor,
                mode='bilinear',
                align_corners=False)
        return x
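A minimal streaming loop mirroring how the pipeline below drives the network (random frames and untrained weights, purely a shape/flow sketch):

```python
import torch

model = MattingNetwork().eval()
rec = [None] * 4  # one ConvGRU state per decoder level
with torch.no_grad():
    for _ in range(3):  # three dummy frames
        src = torch.randn(1, 3, 480, 640)
        pha, *rec = model(src, *rec, downsample_ratio=0.6)
print(pha.shape)  # torch.Size([1, 1, 480, 640]); rec feeds the next frame
```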
@@ -443,6 +443,12 @@ TASK_OUTPUTS = {
    Tasks.referring_video_object_segmentation:
    [OutputKeys.MASKS, OutputKeys.TIMESTAMPS],

+    # video human matting result for a single video
+    # {
+    #     "masks": [np.array]  # per-frame arrays of shape [height, width, 3],
+    #                          # float values in [0, 1]
+    # }
+    Tasks.video_human_matting: [OutputKeys.MASKS],

    # ============ nlp tasks ===================
    # text classification result for single sample
| @@ -201,6 +201,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/cv_fft_inpainting_lama'), | 'damo/cv_fft_inpainting_lama'), | ||||
| Tasks.video_inpainting: (Pipelines.video_inpainting, | Tasks.video_inpainting: (Pipelines.video_inpainting, | ||||
| 'damo/cv_video-inpainting'), | 'damo/cv_video-inpainting'), | ||||
| Tasks.video_human_matting: (Pipelines.video_human_matting, | |||||
| 'damo/cv_effnetv2_video-human-matting'), | |||||
| Tasks.human_wholebody_keypoint: | Tasks.human_wholebody_keypoint: | ||||
| (Pipelines.human_wholebody_keypoint, | (Pipelines.human_wholebody_keypoint, | ||||
| 'damo/cv_hrnetw48_human-wholebody-keypoint_image'), | 'damo/cv_hrnetw48_human-wholebody-keypoint_image'), | ||||
@@ -0,0 +1,77 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import cv2
import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.video_human_matting import preprocess
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.video_human_matting, module_name=Pipelines.video_human_matting)
class VideoHumanMattingPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a video human matting pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        logger.info('load model done')

    def preprocess(self, input) -> Input:
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        video_path = input['video_input_path']
        out_path = input['output_path']
        video_input = cv2.VideoCapture(video_path)
        fps = video_input.get(cv2.CAP_PROP_FPS)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        success, frame = video_input.read()
        h, w = frame.shape[:2]
        # cap the longer side at 512 pixels inside the network
        scale = 512 / max(h, w)
        video_save = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
        masks = []
        rec = [None] * 4  # recurrent states threaded across frames
        self.model = self.model.to(self.device)
        logger.info('matting start using %s', self.device)
        with torch.no_grad():
            while True:
                if frame is None:
                    break
                frame_tensor = preprocess(frame)
                pha, *rec = self.model.model(
                    frame_tensor.to(self.device), *rec, downsample_ratio=scale)
                com = pha * 255
                com = com.repeat(1, 3, 1, 1)
                com = com[0].data.cpu().numpy().transpose(1, 2,
                                                          0).astype(np.uint8)
                video_save.write(com)
                masks.append(com / 255)
                success, frame = video_input.read()
        logger.info('matting process done')
        video_input.release()
        video_save.release()
        return {
            OutputKeys.MASKS: masks,
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
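For context, the new pipeline is invoked like this (mirroring the unit test below; the file names are placeholders):

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

matting = pipeline(
    Tasks.video_human_matting, model='damo/cv_effnetv2_video-human-matting')
result = matting({
    'video_input_path': 'input.mp4',   # source video
    'output_path': 'matting_out.mp4',  # alpha video written as a side effect
})
masks = result[OutputKeys.MASKS]       # per-frame float arrays in [0, 1]
```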
@@ -87,6 +87,7 @@ class CVTasks(object):
    # video segmentation
    referring_video_object_segmentation = 'referring-video-object-segmentation'
+    video_human_matting = 'video-human-matting'

    # video editing
    video_inpainting = 'video-inpainting'
@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class VideoHumanMattingTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model = 'damo/cv_effnetv2_video-human-matting'
        self.video_in = 'data/test/videos/video_matting_test.mp4'
        self.video_out = 'matting_out.mp4'
        self.input = {
            'video_input_path': self.video_in,
            'output_path': self.video_out,
        }

    def pipeline_inference(self, pipeline: Pipeline, input):
        result = pipeline(input)
        print('video matting done, results:', result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        video_human_matting = pipeline(
            Tasks.video_human_matting, model=self.model)
        self.pipeline_inference(video_human_matting, self.input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        video_human_matting = pipeline(Tasks.video_human_matting)
        self.pipeline_inference(video_human_matting, self.input)


if __name__ == '__main__':
    unittest.main()