diff --git a/data/test/videos/video_matting_test.mp4 b/data/test/videos/video_matting_test.mp4 new file mode 100644 index 00000000..efdd3cb0 --- /dev/null +++ b/data/test/videos/video_matting_test.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e4ade7a6b119e20e82a641246199b4b530759166acc1f813d7cefee65b3e1e0 +size 63944943 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index afba99a7..9ee4091f 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -52,6 +52,7 @@ class Models(object): face_emotion = 'face-emotion' product_segmentation = 'product-segmentation' image_body_reshaping = 'image-body-reshaping' + video_human_matting = 'video-human-matting' # EasyCV models yolox = 'YOLOX' @@ -230,6 +231,7 @@ class Pipelines(object): product_segmentation = 'product-segmentation' image_body_reshaping = 'flow-based-body-reshaping' referring_video_object_segmentation = 'referring-video-object-segmentation' + video_human_matting = 'video-human-matting' # nlp tasks automatic_post_editing = 'automatic-post-editing' diff --git a/modelscope/models/cv/video_human_matting/__init__.py b/modelscope/models/cv/video_human_matting/__init__.py new file mode 100644 index 00000000..7d47317c --- /dev/null +++ b/modelscope/models/cv/video_human_matting/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .model import VideoMattingNetwork + from .model import preprocess + +else: + _import_structure = {'model': ['VideoMattingNetwork', 'preprocess']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/video_human_matting/model.py b/modelscope/models/cv/video_human_matting/model.py new file mode 100644 index 00000000..98948051 --- /dev/null +++ b/modelscope/models/cv/video_human_matting/model.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
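+# VideoMattingNetwork below wraps MattingNetwork, loading its weights from
+# ModelFile.TORCH_MODEL_FILE and registering the model for
+# Tasks.video_human_matting; preprocess() converts an H x W x C uint8 frame
+# into a normalized 1 x C x H x W float tensor for the network.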
+import os.path as osp +from typing import Optional + +import numpy as np +import torch +import torchvision +from torch.nn import functional as F + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.video_human_matting.models import MattingNetwork +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + + +@MODELS.register_module( + Tasks.video_human_matting, module_name=Models.video_human_matting) +class VideoMattingNetwork(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + params = torch.load(model_path, map_location='cpu') + self.model = MattingNetwork() + if 'model_state_dict' in params.keys(): + params = params['model_state_dict'] + self.model.load_state_dict(params, strict=True) + self.model.eval() + + +def preprocess(image): + frame_np = np.float32(image) / 255.0 + frame_np = frame_np.transpose(2, 0, 1) + frame_tensor = torch.from_numpy(frame_np) + image_tensor = frame_tensor[None, :, :, :] + return image_tensor diff --git a/modelscope/models/cv/video_human_matting/models/__init__.py b/modelscope/models/cv/video_human_matting/models/__init__.py new file mode 100644 index 00000000..471f0308 --- /dev/null +++ b/modelscope/models/cv/video_human_matting/models/__init__.py @@ -0,0 +1 @@ +from .matting import MattingNetwork diff --git a/modelscope/models/cv/video_human_matting/models/decoder.py b/modelscope/models/cv/video_human_matting/models/decoder.py new file mode 100644 index 00000000..ba82aa90 --- /dev/null +++ b/modelscope/models/cv/video_human_matting/models/decoder.py @@ -0,0 +1,330 @@ +""" +Part of the implementation is borrowed from paper RVM +paper publicly available at +""" +from typing import Optional + +import torch +from torch import Tensor, nn + + +class hswish(nn.Module): + + def forward(self, x): + return torch.nn.Hardswish(inplace=True)(x) + + +class scSEblock(nn.Module): + + def __init__(self, out): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(out, int(out / 2), 3, 1, 1), + nn.GroupNorm(out // 8, int(out / 2)), hswish()) + self.conv2 = nn.Sequential( + nn.Conv2d(int(out / 2), out, 1, 1, 0), + nn.GroupNorm(out // 4, out), + ) + self.avgpool = nn.AdaptiveAvgPool2d(1) + + def forward_single(self, x): + b, c, _, _ = x.size() + x2 = self.avgpool(x).view(b, c, 1, 1) + x2 = self.conv1(x2) + x2 = self.conv2(x2) + x2 = torch.sigmoid(x2) + out = x2 * x + return out + + def forward_time(self, x): + B, T, _, H, W = x.shape + x = x.flatten(0, 1) + out = self.forward_single(x) + out = out.unflatten(0, (B, T)) + return out + + def forward(self, x): + if x.ndim == 5: + return self.forward_time(x) + else: + return self.forward_single(x) + + +class RecurrentDecoder(nn.Module): + + def __init__(self, feature_channels, decoder_channels): + super().__init__() + self.avgpool = AvgPool() + self.decode4 = BottleneckBlock(feature_channels[3]) + self.decode3 = UpsamplingBlock(feature_channels[3], + feature_channels[2], 3, + decoder_channels[0]) + self.sc3 = scSEblock(decoder_channels[0]) + self.decode2 = UpsamplingBlock(decoder_channels[0], + feature_channels[1], 3, + decoder_channels[1]) + self.sc2 = scSEblock(decoder_channels[1]) + self.decode1 = UpsamplingBlock(decoder_channels[1], + feature_channels[0], 3, + decoder_channels[2]) + self.sc1 = scSEblock(decoder_channels[2]) + 
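+        # out0 upsamples the finest decoder features back to input resolution
+        # and fuses them with the 3-channel source frame for the final output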
self.out0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3]) + + self.crosslevel1 = crossfeature(feature_channels[3], + feature_channels[1]) + self.crosslevel2 = crossfeature(feature_channels[2], + feature_channels[0]) + + def forward(self, s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, + f4: Tensor, r1: Optional[Tensor], r2: Optional[Tensor], + r3: Optional[Tensor], r4: Optional[Tensor]): + s2, s3, s4 = self.avgpool(s0) + x4, r4 = self.decode4(f4, r4) + x3, r3 = self.decode3(x4, f3, s4, r3) + x3 = self.sc3(x3) + f2 = self.crosslevel1(f4, f2) + x2, r2 = self.decode2(x3, f2, s3, r2) + x2 = self.sc2(x2) + f1 = self.crosslevel2(f3, f1) + x1, r1 = self.decode1(x2, f1, s2, r1) + x1 = self.sc1(x1) + out = self.out0(x1, s0) + return out, r1, r2, r3, r4 + + +class AvgPool(nn.Module): + + def __init__(self): + super().__init__() + self.avgpool = nn.AvgPool2d( + 2, 2, count_include_pad=False, ceil_mode=True) + + def forward_single_frame(self, s0): + s1 = self.avgpool(s0) + s2 = self.avgpool(s1) + s3 = self.avgpool(s2) + return s1, s2, s3 + + def forward_time_series(self, s0): + B, T = s0.shape[:2] + s0 = s0.flatten(0, 1) + s1, s2, s3 = self.forward_single_frame(s0) + s1 = s1.unflatten(0, (B, T)) + s2 = s2.unflatten(0, (B, T)) + s3 = s3.unflatten(0, (B, T)) + return s1, s2, s3 + + def forward(self, s0): + if s0.ndim == 5: + return self.forward_time_series(s0) + else: + return self.forward_single_frame(s0) + + +class crossfeature(nn.Module): + + def __init__(self, in_channels, out_channels): + super().__init__() + self.avg = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False) + + def forward_single_frame(self, x1, x2): + b, c, _, _ = x1.size() + x1 = self.avg(x1).view(b, c, 1, 1) + x1 = self.conv(x1) + x1 = torch.sigmoid(x1) + x2 = x1 * x2 + return x2 + + def forward_time_series(self, x1, x2): + b, t = x1.shape[:2] + x1 = x1.flatten(0, 1) + x2 = x2.flatten(0, 1) + x2 = self.forward_single_frame(x1, x2) + return x2.unflatten(0, (b, t)) + + def forward(self, x1, x2): + if x1.ndim == 5: + return self.forward_time_series(x1, x2) + else: + return self.forward_single_frame(x1, x2) + + +class BottleneckBlock(nn.Module): + + def __init__(self, channels): + super().__init__() + self.channels = channels + self.gru = GRU(channels // 2) + + def forward(self, x, r): + a, b = x.split(self.channels // 2, dim=-3) + b, r = self.gru(b, r) + x = torch.cat([a, b], dim=-3) + return x, r + + +class UpsamplingBlock(nn.Module): + + def __init__(self, in_channels, skip_channels, src_channels, out_channels): + super().__init__() + self.out_channels = out_channels + self.upsample = nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=False) + self.shortcut = nn.Sequential( + nn.Conv2d(skip_channels, in_channels, 3, 1, 1, bias=False), + nn.GroupNorm(in_channels // 4, in_channels), hswish()) + self.att_skip = nn.Sequential( + nn.Conv2d(in_channels, in_channels, 1, 1, 0, bias=False), + nn.Sigmoid()) + self.conv = nn.Sequential( + nn.Conv2d( + in_channels + in_channels + src_channels, + out_channels, + 3, + 1, + 1, + bias=False), + nn.GroupNorm(out_channels // 4, out_channels), + hswish(), + ) + self.gru = GRU(out_channels // 2) + + def forward_single_frame(self, x, f, s, r: Optional[Tensor]): + x = self.upsample(x) + x = x[:, :, :s.size(2), :s.size(3)] + att = self.att_skip(x) + f = self.shortcut(f) + f = att * f + x = torch.cat([x, f, s], dim=1) + x = self.conv(x) + a, b = x.split(self.out_channels // 2, dim=1) + b, r = self.gru(b, r) + x = torch.cat([a, b], dim=1) + 
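+        # half of the channels were refreshed by the GRU; r is the updated
+        # recurrent state carried over to the next frame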
return x, r + + def forward_time_series(self, x, f, s, r: Optional[Tensor]): + B, T, _, H, W = s.shape + x = x.flatten(0, 1) + f = f.flatten(0, 1) + s = s.flatten(0, 1) + x = self.upsample(x) + x = x[:, :, :H, :W] + f = self.shortcut(f) + att = self.att_skip(x) + f = att * f + x = torch.cat([x, f, s], dim=1) + x = self.conv(x) + x = x.unflatten(0, (B, T)) + a, b = x.split(self.out_channels // 2, dim=2) + b, r = self.gru(b, r) + x = torch.cat([a, b], dim=2) + return x, r + + def forward(self, x, f, s, r: Optional[Tensor]): + if x.ndim == 5: + return self.forward_time_series(x, f, s, r) + else: + return self.forward_single_frame(x, f, s, r) + + +class OutputBlock(nn.Module): + + def __init__(self, in_channels, src_channels, out_channels): + super().__init__() + self.upsample = nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=False) + self.conv = nn.Sequential( + nn.Conv2d( + in_channels + src_channels, out_channels, 3, 1, 1, bias=False), + nn.GroupNorm(out_channels // 2, out_channels), + hswish(), + nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), + nn.GroupNorm(out_channels // 2, out_channels), + hswish(), + ) + + def forward_single_frame(self, x, s): + x = self.upsample(x) + x = x[:, :, :s.size(2), :s.size(3)] + x = torch.cat([x, s], dim=1) + x = self.conv(x) + return x + + def forward_time_series(self, x, s): + B, T, _, H, W = s.shape + x = x.flatten(0, 1) + s = s.flatten(0, 1) + x = self.upsample(x) + x = x[:, :, :H, :W] + x = torch.cat([x, s], dim=1) + x = self.conv(x) + x = x.unflatten(0, (B, T)) + return x + + def forward(self, x, s): + if x.ndim == 5: + return self.forward_time_series(x, s) + else: + return self.forward_single_frame(x, s) + + +class Projection(nn.Module): + + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, 1) + + def forward_single_frame(self, x): + return self.conv(x) + + def forward_time_series(self, x): + B, T = x.shape[:2] + return self.conv(x.flatten(0, 1)).unflatten(0, (B, T)) + + def forward(self, x): + if x.ndim == 5: + return self.forward_time_series(x) + else: + return self.forward_single_frame(x) + + +class GRU(nn.Module): + + def __init__(self, channels, kernel_size=3, padding=1): + super().__init__() + self.channels = channels + self.ih = nn.Conv2d( + channels * 2, channels * 2, kernel_size, padding=padding) + self.act_ih = nn.Sigmoid() + self.hh = nn.Conv2d( + channels * 2, channels, kernel_size, padding=padding) + self.act_hh = nn.Tanh() + + def forward_single_frame(self, x, pre_fea): + fea_ih = self.ih(torch.cat([x, pre_fea], dim=1)) + r, z = self.act_ih(fea_ih).split(self.channels, dim=1) + fea_hh = self.hh(torch.cat([x, r * pre_fea], dim=1)) + c = self.act_hh(fea_hh) + fea_gru = (1 - z) * pre_fea + z * c + return fea_gru, fea_gru + + def forward_time_series(self, x, pre_fea): + o = [] + for xt in x.unbind(dim=1): + ot, pre_fea = self.forward_single_frame(xt, pre_fea) + o.append(ot) + o = torch.stack(o, dim=1) + return o, pre_fea + + def forward(self, x, pre_fea): + if pre_fea is None: + pre_fea = torch.zeros( + (x.size(0), x.size(-3), x.size(-2), x.size(-1)), + device=x.device, + dtype=x.dtype) + + if x.ndim == 5: + return self.forward_time_series(x, pre_fea) + else: + return self.forward_single_frame(x, pre_fea) diff --git a/modelscope/models/cv/video_human_matting/models/deep_guided_filter.py b/modelscope/models/cv/video_human_matting/models/deep_guided_filter.py new file mode 100644 index 00000000..c0081026 --- /dev/null +++ 
b/modelscope/models/cv/video_human_matting/models/deep_guided_filter.py @@ -0,0 +1,64 @@ +""" +Part of the implementation is borrowed and modified from DeepGuidedFilter +publicly available at +""" +import torch +from torch import nn +from torch.nn import functional as F + + +class DeepGuidedFilterRefiner(nn.Module): + + def __init__(self, hid_channels=16): + super().__init__() + self.box_filter = nn.Conv2d( + 4, 4, kernel_size=3, padding=1, bias=False, groups=4) + self.box_filter.weight.data[...] = 1 / 9 + self.conv = nn.Sequential( + nn.Conv2d( + 4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(hid_channels), nn.ReLU(True), + nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(hid_channels), nn.ReLU(True), + nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)) + + def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha, + base_hid): + fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1) + base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1) + base_y = torch.cat([base_fgr, base_pha], dim=1) + + mean_x = self.box_filter(base_x) + mean_y = self.box_filter(base_y) + cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y + var_x = self.box_filter(base_x * base_x) - mean_x * mean_x + + A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1)) + b = mean_y - A * mean_x + + H, W = fine_src.shape[2:] + A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False) + b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False) + + out = A * fine_x + b + fgr, pha = out.split([3, 1], dim=1) + return fgr, pha + + def forward_time_series(self, fine_src, base_src, base_fgr, base_pha, + base_hid): + B, T = fine_src.shape[:2] + fgr, pha = self.forward_single_frame( + fine_src.flatten(0, 1), base_src.flatten(0, 1), + base_fgr.flatten(0, 1), base_pha.flatten(0, 1), + base_hid.flatten(0, 1)) + fgr = fgr.unflatten(0, (B, T)) + pha = pha.unflatten(0, (B, T)) + return fgr, pha + + def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid): + if fine_src.ndim == 5: + return self.forward_time_series(fine_src, base_src, base_fgr, + base_pha, base_hid) + else: + return self.forward_single_frame(fine_src, base_src, base_fgr, + base_pha, base_hid) diff --git a/modelscope/models/cv/video_human_matting/models/effv2.py b/modelscope/models/cv/video_human_matting/models/effv2.py new file mode 100644 index 00000000..8151e3b1 --- /dev/null +++ b/modelscope/models/cv/video_human_matting/models/effv2.py @@ -0,0 +1,177 @@ +""" +Part of the implementation is borrowed and modified from EfficientNetV2 +publicly available at +""" + +import torch +import torch.nn.functional + + +class SiLU(torch.nn.Module): + """ + [https://arxiv.org/pdf/1710.05941.pdf] + """ + + def __init__(self, inplace: bool = False): + super().__init__() + self.silu = torch.nn.SiLU(inplace=inplace) + + def forward(self, x): + return self.silu(x) + + +class Conv(torch.nn.Module): + + def __init__(self, in_ch, out_ch, activation, k=1, s=1, g=1): + super().__init__() + self.conv = torch.nn.Conv2d( + in_ch, out_ch, k, s, k // 2, 1, g, bias=False) + self.norm = torch.nn.BatchNorm2d(out_ch, 0.001, 0.01) + self.silu = activation + + def forward(self, x): + return self.silu(self.norm(self.conv(x))) + + +class SE(torch.nn.Module): + """ + [https://arxiv.org/pdf/1709.01507.pdf] + """ + + def __init__(self, ch, r): + super().__init__() + self.se = torch.nn.Sequential( + torch.nn.Conv2d(ch, ch // (4 * r), 1), torch.nn.SiLU(), + 
torch.nn.Conv2d(ch // (4 * r), ch, 1), torch.nn.Sigmoid()) + + def forward(self, x): + return x * self.se(x.mean((2, 3), keepdim=True)) + + +class Residual(torch.nn.Module): + """ + [https://arxiv.org/pdf/1801.04381.pdf] + """ + + def __init__(self, in_ch, out_ch, s, r, fused=True): + super().__init__() + identity = torch.nn.Identity() + if fused: + if r == 1: + features = [Conv(in_ch, r * in_ch, torch.nn.SiLU(), 3, s)] + else: + features = [ + Conv(in_ch, r * in_ch, torch.nn.SiLU(), 3, s), + Conv(r * in_ch, out_ch, identity) + ] + else: + if r == 1: + features = [ + Conv(r * in_ch, r * in_ch, torch.nn.SiLU(), 3, s, + r * in_ch), + SE(r * in_ch, r), + Conv(r * in_ch, out_ch, identity) + ] + else: + features = [ + Conv(in_ch, r * in_ch, torch.nn.SiLU()), + Conv(r * in_ch, r * in_ch, torch.nn.SiLU(), 3, s, + r * in_ch), + SE(r * in_ch, r), + Conv(r * in_ch, out_ch, identity) + ] + self.add = s == 1 and in_ch == out_ch + self.res = torch.nn.Sequential(*features) + + def forward(self, x): + return x + self.res(x) if self.add else self.res(x) + + +class EfficientNet(torch.nn.Module): + + def __init__(self, pretrained: bool = False): + super().__init__() + gate_fn = [True, False] + filters = [24, 48, 64, 128, 160, 256] + feature = [Conv(3, filters[0], torch.nn.SiLU(), 3, 2)] + for i in range(2): + if i == 0: + feature.append( + Residual(filters[0], filters[0], 1, 1, gate_fn[0])) + else: + feature.append( + Residual(filters[0], filters[0], 1, 1, gate_fn[0])) + + for i in range(4): + if i == 0: + feature.append( + Residual(filters[0], filters[1], 2, 4, gate_fn[0])) + else: + feature.append( + Residual(filters[1], filters[1], 1, 4, gate_fn[0])) + + for i in range(4): + if i == 0: + feature.append( + Residual(filters[1], filters[2], 2, 4, gate_fn[0])) + else: + feature.append( + Residual(filters[2], filters[2], 1, 4, gate_fn[0])) + + for i in range(6): + if i == 0: + feature.append( + Residual(filters[2], filters[3], 2, 4, gate_fn[1])) + else: + feature.append( + Residual(filters[3], filters[3], 1, 4, gate_fn[1])) + + for i in range(9): + if i == 0: + feature.append( + Residual(filters[3], filters[4], 1, 6, gate_fn[1])) + else: + feature.append( + Residual(filters[4], filters[4], 1, 6, gate_fn[1])) + + self.feature = torch.nn.Sequential(*feature) + + def forward_single_frame(self, x): + x = self.feature[0](x) + x = self.feature[1](x) + x = self.feature[2](x) + f1 = x # 1/2 24 + for i in range(4): + x = self.feature[i + 3](x) + f2 = x # 1/4 48 + for i in range(4): + x = self.feature[i + 7](x) + f3 = x # 1/8 64 + for i in range(6): + x = self.feature[i + 11](x) + for i in range(9): + x = self.feature[i + 17](x) + f5 = x # 1/16 160 + return [f1, f2, f3, f5] + + def forward_time_series(self, x): + B, T = x.shape[:2] + features = self.forward_single_frame(x.flatten(0, 1)) + features = [f.unflatten(0, (B, T)) for f in features] + return features + + def forward(self, x): + if x.ndim == 5: + return self.forward_time_series(x) + else: + return self.forward_single_frame(x) + + def export(self): + for m in self.modules(): + if type(m) is Conv and hasattr(m, 'silu'): + if isinstance(m.silu, torch.nn.SiLU): + m.silu = SiLU() + if type(m) is SE: + if isinstance(m.se[1], torch.nn.SiLU): + m.se[1] = SiLU() + return self diff --git a/modelscope/models/cv/video_human_matting/models/lraspp.py b/modelscope/models/cv/video_human_matting/models/lraspp.py new file mode 100644 index 00000000..234b81de --- /dev/null +++ b/modelscope/models/cv/video_human_matting/models/lraspp.py @@ -0,0 +1,94 @@ +""" +Part of the 
implementation is borrowed and modified from Deeplab v3 +publicly available at +""" +import torch +from torch import nn + + +class ASP_OC_Module(nn.Module): + + def __init__(self, features, out_features=96, dilations=(2, 4, 8)): + super(ASP_OC_Module, self).__init__() + self.conv2 = nn.Sequential( + nn.Conv2d( + features, + out_features, + kernel_size=1, + padding=0, + dilation=1, + bias=False), nn.BatchNorm2d(out_features)) + self.conv3 = nn.Sequential( + nn.Conv2d( + features, + out_features, + kernel_size=3, + padding=dilations[0], + dilation=dilations[0], + bias=False), nn.BatchNorm2d(out_features)) + self.conv4 = nn.Sequential( + nn.Conv2d( + features, + out_features, + kernel_size=3, + padding=dilations[1], + dilation=dilations[1], + bias=False), nn.BatchNorm2d(out_features)) + self.conv5 = nn.Sequential( + nn.Conv2d( + features, + out_features, + kernel_size=3, + padding=dilations[2], + dilation=dilations[2], + bias=False), nn.BatchNorm2d(out_features)) + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d( + out_features * 4, + out_features * 2, + kernel_size=1, + padding=0, + dilation=1, + bias=False), nn.InstanceNorm2d(out_features * 2), + nn.Dropout2d(0.05)) + + def _cat_each(self, feat1, feat2, feat3, feat4, feat5): + assert (len(feat1) == len(feat2)) + z = [] + for i in range(len(feat1)): + z.append( + torch.cat((feat1[i], feat2[i], feat3[i], feat4[i], feat5[i]), + 1)) + return z + + def forward(self, x): + _, _, h, w = x.size() + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + out = torch.cat((feat2, feat3, feat4, feat5), 1) + output = self.conv_bn_dropout(out) + return output + + +class LRASPP(nn.Module): + + def __init__(self, in_channels, out_channels): + super().__init__() + self.aspp = ASP_OC_Module(in_channels, out_channels) + + def forward_single_frame(self, x): + return self.aspp(x) + + def forward_time_series(self, x): + B, T = x.shape[:2] + x = self.forward_single_frame(x.flatten(0, 1)).unflatten(0, (B, T)) + return x + + def forward(self, x): + if x.ndim == 5: + return self.forward_time_series(x) + else: + return self.forward_single_frame(x) diff --git a/modelscope/models/cv/video_human_matting/models/matting.py b/modelscope/models/cv/video_human_matting/models/matting.py new file mode 100644 index 00000000..95cce15f --- /dev/null +++ b/modelscope/models/cv/video_human_matting/models/matting.py @@ -0,0 +1,67 @@ +from typing import Optional + +import torch +from torch import Tensor +from torch.nn import functional as F + +from .decoder import Projection, RecurrentDecoder +from .deep_guided_filter import DeepGuidedFilterRefiner +from .effv2 import EfficientNet +from .lraspp import LRASPP + + +class MattingNetwork(torch.nn.Module): + + def __init__(self, pretrained_backbone: bool = False): + super().__init__() + self.backbone = EfficientNet(pretrained_backbone) + self.aspp = LRASPP(160, 64) + self.decoder = RecurrentDecoder([24, 48, 64, 128], [64, 32, 24, 16]) + self.project_mat = Projection(16, 4) + self.project_seg = Projection(16, 1) + self.refiner = DeepGuidedFilterRefiner() + + def forward(self, + src: Tensor, + r0: Optional[Tensor] = None, + r1: Optional[Tensor] = None, + r2: Optional[Tensor] = None, + r3: Optional[Tensor] = None, + downsample_ratio: float = 1, + segmentation_pass: bool = False): + + if downsample_ratio != 1: + src_sm = self._interpolate(src, scale_factor=downsample_ratio) + else: + src_sm = src + + f1, f2, f3, f4 = self.backbone(src_sm) + f4 = self.aspp(f4) + hid, *rec = self.decoder(src_sm, f1, 
f2, f3, f4, r0, r1, r2, r3) + + if not segmentation_pass: + fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3) + if downsample_ratio != 1: + _, pha = self.refiner(src, src_sm, fgr_residual, pha, hid) + pha = pha.clamp(0., 1.) + return [pha, *rec] + else: + seg = self.project_seg(hid) + return [seg, *rec] + + def _interpolate(self, x: Tensor, scale_factor: float): + if x.ndim == 5: + B, T = x.shape[:2] + x = F.interpolate( + x.flatten(0, 1), + scale_factor=scale_factor, + mode='bilinear', + align_corners=False) + x = x.unflatten(0, (B, T)) + else: + x = F.interpolate( + x, + scale_factor=scale_factor, + mode='bilinear', + align_corners=False) + return x diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py index 94a8d035..acc8035b 100644 --- a/modelscope/outputs/outputs.py +++ b/modelscope/outputs/outputs.py @@ -443,6 +443,12 @@ TASK_OUTPUTS = { Tasks.referring_video_object_segmentation: [OutputKeys.MASKS, OutputKeys.TIMESTAMPS], + # video human matting result for a single video + # { + # "masks": [np.array # 2D array with shape [height, width]] + # } + Tasks.video_human_matting: [OutputKeys.MASKS], + # ============ nlp tasks =================== # text classification result for single sample diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 68d4f0b1..4821c553 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -201,6 +201,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_fft_inpainting_lama'), Tasks.video_inpainting: (Pipelines.video_inpainting, 'damo/cv_video-inpainting'), + Tasks.video_human_matting: (Pipelines.video_human_matting, + 'damo/cv_effnetv2_video-human-matting'), Tasks.human_wholebody_keypoint: (Pipelines.human_wholebody_keypoint, 'damo/cv_hrnetw48_human-wholebody-keypoint_image'), diff --git a/modelscope/pipelines/cv/video_human_matting_pipeline.py b/modelscope/pipelines/cv/video_human_matting_pipeline.py new file mode 100644 index 00000000..b4e6f2ba --- /dev/null +++ b/modelscope/pipelines/cv/video_human_matting_pipeline.py @@ -0,0 +1,77 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.video_human_matting import preprocess +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_human_matting, module_name=Pipelines.video_human_matting) +class VideoHumanMattingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a video human matting pipeline for prediction + Args: + model: model id on modelscope hub. 
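+
+        The built pipeline is called with a dict holding 'video_input_path'
+        (the source video) and 'output_path' (where the rendered matte video
+        is written); per-frame masks are returned under OutputKeys.MASKS.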
+ """ + super().__init__(model=model, **kwargs) + if torch.cuda.is_available(): + self.device = 'cuda' + else: + self.device = 'cpu' + logger.info('load model done') + + def preprocess(self, input) -> Input: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + video_path = input['video_input_path'] + out_path = input['output_path'] + video_input = cv2.VideoCapture(video_path) + fps = video_input.get(cv2.CAP_PROP_FPS) + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + success, frame = video_input.read() + h, w = frame.shape[:2] + scale = 512 / max(h, w) + video_save = cv2.VideoWriter(out_path, fourcc, fps, (w, h)) + masks = [] + rec = [None] * 4 + self.model = self.model.to(self.device) + logger.info('matting start using ', self.device) + with torch.no_grad(): + while True: + if frame is None: + break + frame_tensor = preprocess(frame) + pha, *rec = self.model.model( + frame_tensor.to(self.device), *rec, downsample_ratio=scale) + com = pha * 255 + com = com.repeat(1, 3, 1, 1) + com = com[0].data.cpu().numpy().transpose(1, 2, + 0).astype(np.uint8) + video_save.write(com) + masks.append(com / 255) + success, frame = video_input.read() + logger.info('matting process done') + video_input.release() + video_save.release() + + return { + OutputKeys.MASKS: masks, + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 4d585e1a..8f8e2c6f 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -87,6 +87,7 @@ class CVTasks(object): # video segmentation referring_video_object_segmentation = 'referring-video-object-segmentation' + video_human_matting = 'video-human-matting' # video editing video_inpainting = 'video-inpainting' diff --git a/tests/pipelines/test_video_human_matting.py b/tests/pipelines/test_video_human_matting.py new file mode 100644 index 00000000..4b65c1ac --- /dev/null +++ b/tests/pipelines/test_video_human_matting.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import sys +import unittest + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class VideoHumanMattingTest(unittest.TestCase): + + def setUp(self) -> None: + self.model = 'damo/cv_effnetv2_video-human-matting' + self.video_in = 'data/test/videos/video_matting_test.mp4' + self.video_out = 'matting_out.mp4' + self.input = { + 'video_input_path': self.video_in, + 'output_path': self.video_out, + } + + def pipeline_inference(self, pipeline: Pipeline, input): + result = pipeline(input) + print('video matting over, results:', result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + video_human_matting = pipeline( + Tasks.video_human_matting, model=self.model) + self.pipeline_inference(video_human_matting, self.input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + video_human_matting = pipeline(Tasks.video_human_matting) + self.pipeline_inference(video_human_matting, self.input) + + +if __name__ == '__main__': + unittest.main()