diff --git a/data/test/videos/video_matting_test.mp4 b/data/test/videos/video_matting_test.mp4
new file mode 100644
index 00000000..efdd3cb0
--- /dev/null
+++ b/data/test/videos/video_matting_test.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8e4ade7a6b119e20e82a641246199b4b530759166acc1f813d7cefee65b3e1e0
+size 63944943
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index afba99a7..9ee4091f 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -52,6 +52,7 @@ class Models(object):
face_emotion = 'face-emotion'
product_segmentation = 'product-segmentation'
image_body_reshaping = 'image-body-reshaping'
+ video_human_matting = 'video-human-matting'
# EasyCV models
yolox = 'YOLOX'
@@ -230,6 +231,7 @@ class Pipelines(object):
product_segmentation = 'product-segmentation'
image_body_reshaping = 'flow-based-body-reshaping'
referring_video_object_segmentation = 'referring-video-object-segmentation'
+ video_human_matting = 'video-human-matting'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'
diff --git a/modelscope/models/cv/video_human_matting/__init__.py b/modelscope/models/cv/video_human_matting/__init__.py
new file mode 100644
index 00000000..7d47317c
--- /dev/null
+++ b/modelscope/models/cv/video_human_matting/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .model import VideoMattingNetwork
+ from .model import preprocess
+
+else:
+ _import_structure = {'model': ['VideoMattingNetwork', 'preprocess']}
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/cv/video_human_matting/model.py b/modelscope/models/cv/video_human_matting/model.py
new file mode 100644
index 00000000..98948051
--- /dev/null
+++ b/modelscope/models/cv/video_human_matting/model.py
@@ -0,0 +1,38 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+import torch
+import torchvision
+from torch.nn import functional as F
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Tensor, TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.cv.video_human_matting.models import MattingNetwork
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+
+@MODELS.register_module(
+ Tasks.video_human_matting, module_name=Models.video_human_matting)
+class VideoMattingNetwork(TorchModel):
+
+ def __init__(self, model_dir: str, *args, **kwargs):
+ super().__init__(model_dir, *args, **kwargs)
+ model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
+ params = torch.load(model_path, map_location='cpu')
+ self.model = MattingNetwork()
+ if 'model_state_dict' in params.keys():
+ params = params['model_state_dict']
+ self.model.load_state_dict(params, strict=True)
+ self.model.eval()
+
+
+def preprocess(image):
+ frame_np = np.float32(image) / 255.0
+ frame_np = frame_np.transpose(2, 0, 1)
+ frame_tensor = torch.from_numpy(frame_np)
+ image_tensor = frame_tensor[None, :, :, :]
+ return image_tensor
diff --git a/modelscope/models/cv/video_human_matting/models/__init__.py b/modelscope/models/cv/video_human_matting/models/__init__.py
new file mode 100644
index 00000000..471f0308
--- /dev/null
+++ b/modelscope/models/cv/video_human_matting/models/__init__.py
@@ -0,0 +1 @@
+from .matting import MattingNetwork
diff --git a/modelscope/models/cv/video_human_matting/models/decoder.py b/modelscope/models/cv/video_human_matting/models/decoder.py
new file mode 100644
index 00000000..ba82aa90
--- /dev/null
+++ b/modelscope/models/cv/video_human_matting/models/decoder.py
@@ -0,0 +1,330 @@
+"""
+Part of the implementation is borrowed from paper RVM
+paper publicly available at
+"""
+from typing import Optional
+
+import torch
+from torch import Tensor, nn
+
+
+class hswish(nn.Module):
+
+ def forward(self, x):
+ return torch.nn.Hardswish(inplace=True)(x)
+
+
+class scSEblock(nn.Module):
+
+ def __init__(self, out):
+ super().__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(out, int(out / 2), 3, 1, 1),
+ nn.GroupNorm(out // 8, int(out / 2)), hswish())
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(int(out / 2), out, 1, 1, 0),
+ nn.GroupNorm(out // 4, out),
+ )
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
+
+ def forward_single(self, x):
+ b, c, _, _ = x.size()
+ x2 = self.avgpool(x).view(b, c, 1, 1)
+ x2 = self.conv1(x2)
+ x2 = self.conv2(x2)
+ x2 = torch.sigmoid(x2)
+ out = x2 * x
+ return out
+
+ def forward_time(self, x):
+ B, T, _, H, W = x.shape
+ x = x.flatten(0, 1)
+ out = self.forward_single(x)
+ out = out.unflatten(0, (B, T))
+ return out
+
+ def forward(self, x):
+ if x.ndim == 5:
+ return self.forward_time(x)
+ else:
+ return self.forward_single(x)
+
+
+class RecurrentDecoder(nn.Module):
+
+ def __init__(self, feature_channels, decoder_channels):
+ super().__init__()
+ self.avgpool = AvgPool()
+ self.decode4 = BottleneckBlock(feature_channels[3])
+ self.decode3 = UpsamplingBlock(feature_channels[3],
+ feature_channels[2], 3,
+ decoder_channels[0])
+ self.sc3 = scSEblock(decoder_channels[0])
+ self.decode2 = UpsamplingBlock(decoder_channels[0],
+ feature_channels[1], 3,
+ decoder_channels[1])
+ self.sc2 = scSEblock(decoder_channels[1])
+ self.decode1 = UpsamplingBlock(decoder_channels[1],
+ feature_channels[0], 3,
+ decoder_channels[2])
+ self.sc1 = scSEblock(decoder_channels[2])
+ self.out0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
+
+ self.crosslevel1 = crossfeature(feature_channels[3],
+ feature_channels[1])
+ self.crosslevel2 = crossfeature(feature_channels[2],
+ feature_channels[0])
+
+ def forward(self, s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor,
+ f4: Tensor, r1: Optional[Tensor], r2: Optional[Tensor],
+ r3: Optional[Tensor], r4: Optional[Tensor]):
+ s2, s3, s4 = self.avgpool(s0)
+ x4, r4 = self.decode4(f4, r4)
+ x3, r3 = self.decode3(x4, f3, s4, r3)
+ x3 = self.sc3(x3)
+ f2 = self.crosslevel1(f4, f2)
+ x2, r2 = self.decode2(x3, f2, s3, r2)
+ x2 = self.sc2(x2)
+ f1 = self.crosslevel2(f3, f1)
+ x1, r1 = self.decode1(x2, f1, s2, r1)
+ x1 = self.sc1(x1)
+ out = self.out0(x1, s0)
+ return out, r1, r2, r3, r4
+
+
+class AvgPool(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.avgpool = nn.AvgPool2d(
+ 2, 2, count_include_pad=False, ceil_mode=True)
+
+ def forward_single_frame(self, s0):
+ s1 = self.avgpool(s0)
+ s2 = self.avgpool(s1)
+ s3 = self.avgpool(s2)
+ return s1, s2, s3
+
+ def forward_time_series(self, s0):
+ B, T = s0.shape[:2]
+ s0 = s0.flatten(0, 1)
+ s1, s2, s3 = self.forward_single_frame(s0)
+ s1 = s1.unflatten(0, (B, T))
+ s2 = s2.unflatten(0, (B, T))
+ s3 = s3.unflatten(0, (B, T))
+ return s1, s2, s3
+
+ def forward(self, s0):
+ if s0.ndim == 5:
+ return self.forward_time_series(s0)
+ else:
+ return self.forward_single_frame(s0)
+
+
+class crossfeature(nn.Module):
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.avg = nn.AdaptiveAvgPool2d(1)
+ self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)
+
+ def forward_single_frame(self, x1, x2):
+ b, c, _, _ = x1.size()
+ x1 = self.avg(x1).view(b, c, 1, 1)
+ x1 = self.conv(x1)
+ x1 = torch.sigmoid(x1)
+ x2 = x1 * x2
+ return x2
+
+ def forward_time_series(self, x1, x2):
+ b, t = x1.shape[:2]
+ x1 = x1.flatten(0, 1)
+ x2 = x2.flatten(0, 1)
+ x2 = self.forward_single_frame(x1, x2)
+ return x2.unflatten(0, (b, t))
+
+ def forward(self, x1, x2):
+ if x1.ndim == 5:
+ return self.forward_time_series(x1, x2)
+ else:
+ return self.forward_single_frame(x1, x2)
+
+
+class BottleneckBlock(nn.Module):
+
+ def __init__(self, channels):
+ super().__init__()
+ self.channels = channels
+ self.gru = GRU(channels // 2)
+
+ def forward(self, x, r):
+ a, b = x.split(self.channels // 2, dim=-3)
+ b, r = self.gru(b, r)
+ x = torch.cat([a, b], dim=-3)
+ return x, r
+
+
+class UpsamplingBlock(nn.Module):
+
+ def __init__(self, in_channels, skip_channels, src_channels, out_channels):
+ super().__init__()
+ self.out_channels = out_channels
+ self.upsample = nn.Upsample(
+ scale_factor=2, mode='bilinear', align_corners=False)
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(skip_channels, in_channels, 3, 1, 1, bias=False),
+ nn.GroupNorm(in_channels // 4, in_channels), hswish())
+ self.att_skip = nn.Sequential(
+ nn.Conv2d(in_channels, in_channels, 1, 1, 0, bias=False),
+ nn.Sigmoid())
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ in_channels + in_channels + src_channels,
+ out_channels,
+ 3,
+ 1,
+ 1,
+ bias=False),
+ nn.GroupNorm(out_channels // 4, out_channels),
+ hswish(),
+ )
+ self.gru = GRU(out_channels // 2)
+
+ def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
+ x = self.upsample(x)
+ x = x[:, :, :s.size(2), :s.size(3)]
+ att = self.att_skip(x)
+ f = self.shortcut(f)
+ f = att * f
+ x = torch.cat([x, f, s], dim=1)
+ x = self.conv(x)
+ a, b = x.split(self.out_channels // 2, dim=1)
+ b, r = self.gru(b, r)
+ x = torch.cat([a, b], dim=1)
+ return x, r
+
+ def forward_time_series(self, x, f, s, r: Optional[Tensor]):
+ B, T, _, H, W = s.shape
+ x = x.flatten(0, 1)
+ f = f.flatten(0, 1)
+ s = s.flatten(0, 1)
+ x = self.upsample(x)
+ x = x[:, :, :H, :W]
+ f = self.shortcut(f)
+ att = self.att_skip(x)
+ f = att * f
+ x = torch.cat([x, f, s], dim=1)
+ x = self.conv(x)
+ x = x.unflatten(0, (B, T))
+ a, b = x.split(self.out_channels // 2, dim=2)
+ b, r = self.gru(b, r)
+ x = torch.cat([a, b], dim=2)
+ return x, r
+
+ def forward(self, x, f, s, r: Optional[Tensor]):
+ if x.ndim == 5:
+ return self.forward_time_series(x, f, s, r)
+ else:
+ return self.forward_single_frame(x, f, s, r)
+
+
+class OutputBlock(nn.Module):
+
+ def __init__(self, in_channels, src_channels, out_channels):
+ super().__init__()
+ self.upsample = nn.Upsample(
+ scale_factor=2, mode='bilinear', align_corners=False)
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ in_channels + src_channels, out_channels, 3, 1, 1, bias=False),
+ nn.GroupNorm(out_channels // 2, out_channels),
+ hswish(),
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
+ nn.GroupNorm(out_channels // 2, out_channels),
+ hswish(),
+ )
+
+ def forward_single_frame(self, x, s):
+ x = self.upsample(x)
+ x = x[:, :, :s.size(2), :s.size(3)]
+ x = torch.cat([x, s], dim=1)
+ x = self.conv(x)
+ return x
+
+ def forward_time_series(self, x, s):
+ B, T, _, H, W = s.shape
+ x = x.flatten(0, 1)
+ s = s.flatten(0, 1)
+ x = self.upsample(x)
+ x = x[:, :, :H, :W]
+ x = torch.cat([x, s], dim=1)
+ x = self.conv(x)
+ x = x.unflatten(0, (B, T))
+ return x
+
+ def forward(self, x, s):
+ if x.ndim == 5:
+ return self.forward_time_series(x, s)
+ else:
+ return self.forward_single_frame(x, s)
+
+
+class Projection(nn.Module):
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, 1)
+
+ def forward_single_frame(self, x):
+ return self.conv(x)
+
+ def forward_time_series(self, x):
+ B, T = x.shape[:2]
+ return self.conv(x.flatten(0, 1)).unflatten(0, (B, T))
+
+ def forward(self, x):
+ if x.ndim == 5:
+ return self.forward_time_series(x)
+ else:
+ return self.forward_single_frame(x)
+
+
+class GRU(nn.Module):
+
+ def __init__(self, channels, kernel_size=3, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.ih = nn.Conv2d(
+ channels * 2, channels * 2, kernel_size, padding=padding)
+ self.act_ih = nn.Sigmoid()
+ self.hh = nn.Conv2d(
+ channels * 2, channels, kernel_size, padding=padding)
+ self.act_hh = nn.Tanh()
+
+ def forward_single_frame(self, x, pre_fea):
+ fea_ih = self.ih(torch.cat([x, pre_fea], dim=1))
+ r, z = self.act_ih(fea_ih).split(self.channels, dim=1)
+ fea_hh = self.hh(torch.cat([x, r * pre_fea], dim=1))
+ c = self.act_hh(fea_hh)
+ fea_gru = (1 - z) * pre_fea + z * c
+ return fea_gru, fea_gru
+
+ def forward_time_series(self, x, pre_fea):
+ o = []
+ for xt in x.unbind(dim=1):
+ ot, pre_fea = self.forward_single_frame(xt, pre_fea)
+ o.append(ot)
+ o = torch.stack(o, dim=1)
+ return o, pre_fea
+
+ def forward(self, x, pre_fea):
+ if pre_fea is None:
+ pre_fea = torch.zeros(
+ (x.size(0), x.size(-3), x.size(-2), x.size(-1)),
+ device=x.device,
+ dtype=x.dtype)
+
+ if x.ndim == 5:
+ return self.forward_time_series(x, pre_fea)
+ else:
+ return self.forward_single_frame(x, pre_fea)
diff --git a/modelscope/models/cv/video_human_matting/models/deep_guided_filter.py b/modelscope/models/cv/video_human_matting/models/deep_guided_filter.py
new file mode 100644
index 00000000..c0081026
--- /dev/null
+++ b/modelscope/models/cv/video_human_matting/models/deep_guided_filter.py
@@ -0,0 +1,64 @@
+"""
+Part of the implementation is borrowed and modified from DeepGuidedFilter
+publicly available at
+"""
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class DeepGuidedFilterRefiner(nn.Module):
+
+ def __init__(self, hid_channels=16):
+ super().__init__()
+ self.box_filter = nn.Conv2d(
+ 4, 4, kernel_size=3, padding=1, bias=False, groups=4)
+ self.box_filter.weight.data[...] = 1 / 9
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ 4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False),
+ nn.BatchNorm2d(hid_channels), nn.ReLU(True),
+ nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
+ nn.BatchNorm2d(hid_channels), nn.ReLU(True),
+ nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True))
+
+ def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha,
+ base_hid):
+ fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
+ base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
+ base_y = torch.cat([base_fgr, base_pha], dim=1)
+
+ mean_x = self.box_filter(base_x)
+ mean_y = self.box_filter(base_y)
+ cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
+ var_x = self.box_filter(base_x * base_x) - mean_x * mean_x
+
+ A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
+ b = mean_y - A * mean_x
+
+ H, W = fine_src.shape[2:]
+ A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
+ b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)
+
+ out = A * fine_x + b
+ fgr, pha = out.split([3, 1], dim=1)
+ return fgr, pha
+
+ def forward_time_series(self, fine_src, base_src, base_fgr, base_pha,
+ base_hid):
+ B, T = fine_src.shape[:2]
+ fgr, pha = self.forward_single_frame(
+ fine_src.flatten(0, 1), base_src.flatten(0, 1),
+ base_fgr.flatten(0, 1), base_pha.flatten(0, 1),
+ base_hid.flatten(0, 1))
+ fgr = fgr.unflatten(0, (B, T))
+ pha = pha.unflatten(0, (B, T))
+ return fgr, pha
+
+ def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
+ if fine_src.ndim == 5:
+ return self.forward_time_series(fine_src, base_src, base_fgr,
+ base_pha, base_hid)
+ else:
+ return self.forward_single_frame(fine_src, base_src, base_fgr,
+ base_pha, base_hid)
diff --git a/modelscope/models/cv/video_human_matting/models/effv2.py b/modelscope/models/cv/video_human_matting/models/effv2.py
new file mode 100644
index 00000000..8151e3b1
--- /dev/null
+++ b/modelscope/models/cv/video_human_matting/models/effv2.py
@@ -0,0 +1,177 @@
+"""
+Part of the implementation is borrowed and modified from EfficientNetV2
+publicly available at
+"""
+
+import torch
+import torch.nn.functional
+
+
+class SiLU(torch.nn.Module):
+ """
+ [https://arxiv.org/pdf/1710.05941.pdf]
+ """
+
+ def __init__(self, inplace: bool = False):
+ super().__init__()
+ self.silu = torch.nn.SiLU(inplace=inplace)
+
+ def forward(self, x):
+ return self.silu(x)
+
+
+class Conv(torch.nn.Module):
+
+ def __init__(self, in_ch, out_ch, activation, k=1, s=1, g=1):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(
+ in_ch, out_ch, k, s, k // 2, 1, g, bias=False)
+ self.norm = torch.nn.BatchNorm2d(out_ch, 0.001, 0.01)
+ self.silu = activation
+
+ def forward(self, x):
+ return self.silu(self.norm(self.conv(x)))
+
+
+class SE(torch.nn.Module):
+ """
+ [https://arxiv.org/pdf/1709.01507.pdf]
+ """
+
+ def __init__(self, ch, r):
+ super().__init__()
+ self.se = torch.nn.Sequential(
+ torch.nn.Conv2d(ch, ch // (4 * r), 1), torch.nn.SiLU(),
+ torch.nn.Conv2d(ch // (4 * r), ch, 1), torch.nn.Sigmoid())
+
+ def forward(self, x):
+ return x * self.se(x.mean((2, 3), keepdim=True))
+
+
+class Residual(torch.nn.Module):
+ """
+ [https://arxiv.org/pdf/1801.04381.pdf]
+ """
+
+ def __init__(self, in_ch, out_ch, s, r, fused=True):
+ super().__init__()
+ identity = torch.nn.Identity()
+ if fused:
+ if r == 1:
+ features = [Conv(in_ch, r * in_ch, torch.nn.SiLU(), 3, s)]
+ else:
+ features = [
+ Conv(in_ch, r * in_ch, torch.nn.SiLU(), 3, s),
+ Conv(r * in_ch, out_ch, identity)
+ ]
+ else:
+ if r == 1:
+ features = [
+ Conv(r * in_ch, r * in_ch, torch.nn.SiLU(), 3, s,
+ r * in_ch),
+ SE(r * in_ch, r),
+ Conv(r * in_ch, out_ch, identity)
+ ]
+ else:
+ features = [
+ Conv(in_ch, r * in_ch, torch.nn.SiLU()),
+ Conv(r * in_ch, r * in_ch, torch.nn.SiLU(), 3, s,
+ r * in_ch),
+ SE(r * in_ch, r),
+ Conv(r * in_ch, out_ch, identity)
+ ]
+ self.add = s == 1 and in_ch == out_ch
+ self.res = torch.nn.Sequential(*features)
+
+ def forward(self, x):
+ return x + self.res(x) if self.add else self.res(x)
+
+
+class EfficientNet(torch.nn.Module):
+
+ def __init__(self, pretrained: bool = False):
+ super().__init__()
+ gate_fn = [True, False]
+ filters = [24, 48, 64, 128, 160, 256]
+ feature = [Conv(3, filters[0], torch.nn.SiLU(), 3, 2)]
+ for i in range(2):
+ if i == 0:
+ feature.append(
+ Residual(filters[0], filters[0], 1, 1, gate_fn[0]))
+ else:
+ feature.append(
+ Residual(filters[0], filters[0], 1, 1, gate_fn[0]))
+
+ for i in range(4):
+ if i == 0:
+ feature.append(
+ Residual(filters[0], filters[1], 2, 4, gate_fn[0]))
+ else:
+ feature.append(
+ Residual(filters[1], filters[1], 1, 4, gate_fn[0]))
+
+ for i in range(4):
+ if i == 0:
+ feature.append(
+ Residual(filters[1], filters[2], 2, 4, gate_fn[0]))
+ else:
+ feature.append(
+ Residual(filters[2], filters[2], 1, 4, gate_fn[0]))
+
+ for i in range(6):
+ if i == 0:
+ feature.append(
+ Residual(filters[2], filters[3], 2, 4, gate_fn[1]))
+ else:
+ feature.append(
+ Residual(filters[3], filters[3], 1, 4, gate_fn[1]))
+
+ for i in range(9):
+ if i == 0:
+ feature.append(
+ Residual(filters[3], filters[4], 1, 6, gate_fn[1]))
+ else:
+ feature.append(
+ Residual(filters[4], filters[4], 1, 6, gate_fn[1]))
+
+ self.feature = torch.nn.Sequential(*feature)
+
+ def forward_single_frame(self, x):
+ x = self.feature[0](x)
+ x = self.feature[1](x)
+ x = self.feature[2](x)
+ f1 = x # 1/2 24
+ for i in range(4):
+ x = self.feature[i + 3](x)
+ f2 = x # 1/4 48
+ for i in range(4):
+ x = self.feature[i + 7](x)
+ f3 = x # 1/8 64
+ for i in range(6):
+ x = self.feature[i + 11](x)
+ for i in range(9):
+ x = self.feature[i + 17](x)
+ f5 = x # 1/16 160
+ return [f1, f2, f3, f5]
+
+ def forward_time_series(self, x):
+ B, T = x.shape[:2]
+ features = self.forward_single_frame(x.flatten(0, 1))
+ features = [f.unflatten(0, (B, T)) for f in features]
+ return features
+
+ def forward(self, x):
+ if x.ndim == 5:
+ return self.forward_time_series(x)
+ else:
+ return self.forward_single_frame(x)
+
+ def export(self):
+ for m in self.modules():
+ if type(m) is Conv and hasattr(m, 'silu'):
+ if isinstance(m.silu, torch.nn.SiLU):
+ m.silu = SiLU()
+ if type(m) is SE:
+ if isinstance(m.se[1], torch.nn.SiLU):
+ m.se[1] = SiLU()
+ return self
diff --git a/modelscope/models/cv/video_human_matting/models/lraspp.py b/modelscope/models/cv/video_human_matting/models/lraspp.py
new file mode 100644
index 00000000..234b81de
--- /dev/null
+++ b/modelscope/models/cv/video_human_matting/models/lraspp.py
@@ -0,0 +1,94 @@
+"""
+Part of the implementation is borrowed and modified from Deeplab v3
+publicly available at
+"""
+import torch
+from torch import nn
+
+
+class ASP_OC_Module(nn.Module):
+
+ def __init__(self, features, out_features=96, dilations=(2, 4, 8)):
+ super(ASP_OC_Module, self).__init__()
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(
+ features,
+ out_features,
+ kernel_size=1,
+ padding=0,
+ dilation=1,
+ bias=False), nn.BatchNorm2d(out_features))
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(
+ features,
+ out_features,
+ kernel_size=3,
+ padding=dilations[0],
+ dilation=dilations[0],
+ bias=False), nn.BatchNorm2d(out_features))
+ self.conv4 = nn.Sequential(
+ nn.Conv2d(
+ features,
+ out_features,
+ kernel_size=3,
+ padding=dilations[1],
+ dilation=dilations[1],
+ bias=False), nn.BatchNorm2d(out_features))
+ self.conv5 = nn.Sequential(
+ nn.Conv2d(
+ features,
+ out_features,
+ kernel_size=3,
+ padding=dilations[2],
+ dilation=dilations[2],
+ bias=False), nn.BatchNorm2d(out_features))
+
+ self.conv_bn_dropout = nn.Sequential(
+ nn.Conv2d(
+ out_features * 4,
+ out_features * 2,
+ kernel_size=1,
+ padding=0,
+ dilation=1,
+ bias=False), nn.InstanceNorm2d(out_features * 2),
+ nn.Dropout2d(0.05))
+
+ def _cat_each(self, feat1, feat2, feat3, feat4, feat5):
+ assert (len(feat1) == len(feat2))
+ z = []
+ for i in range(len(feat1)):
+ z.append(
+ torch.cat((feat1[i], feat2[i], feat3[i], feat4[i], feat5[i]),
+ 1))
+ return z
+
+ def forward(self, x):
+ _, _, h, w = x.size()
+ feat2 = self.conv2(x)
+ feat3 = self.conv3(x)
+ feat4 = self.conv4(x)
+ feat5 = self.conv5(x)
+ out = torch.cat((feat2, feat3, feat4, feat5), 1)
+ output = self.conv_bn_dropout(out)
+ return output
+
+
+class LRASPP(nn.Module):
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.aspp = ASP_OC_Module(in_channels, out_channels)
+
+ def forward_single_frame(self, x):
+ return self.aspp(x)
+
+ def forward_time_series(self, x):
+ B, T = x.shape[:2]
+ x = self.forward_single_frame(x.flatten(0, 1)).unflatten(0, (B, T))
+ return x
+
+ def forward(self, x):
+ if x.ndim == 5:
+ return self.forward_time_series(x)
+ else:
+ return self.forward_single_frame(x)
diff --git a/modelscope/models/cv/video_human_matting/models/matting.py b/modelscope/models/cv/video_human_matting/models/matting.py
new file mode 100644
index 00000000..95cce15f
--- /dev/null
+++ b/modelscope/models/cv/video_human_matting/models/matting.py
@@ -0,0 +1,67 @@
+from typing import Optional
+
+import torch
+from torch import Tensor
+from torch.nn import functional as F
+
+from .decoder import Projection, RecurrentDecoder
+from .deep_guided_filter import DeepGuidedFilterRefiner
+from .effv2 import EfficientNet
+from .lraspp import LRASPP
+
+
+class MattingNetwork(torch.nn.Module):
+
+ def __init__(self, pretrained_backbone: bool = False):
+ super().__init__()
+ self.backbone = EfficientNet(pretrained_backbone)
+ self.aspp = LRASPP(160, 64)
+ self.decoder = RecurrentDecoder([24, 48, 64, 128], [64, 32, 24, 16])
+ self.project_mat = Projection(16, 4)
+ self.project_seg = Projection(16, 1)
+ self.refiner = DeepGuidedFilterRefiner()
+
+ def forward(self,
+ src: Tensor,
+ r0: Optional[Tensor] = None,
+ r1: Optional[Tensor] = None,
+ r2: Optional[Tensor] = None,
+ r3: Optional[Tensor] = None,
+ downsample_ratio: float = 1,
+ segmentation_pass: bool = False):
+
+ if downsample_ratio != 1:
+ src_sm = self._interpolate(src, scale_factor=downsample_ratio)
+ else:
+ src_sm = src
+
+ f1, f2, f3, f4 = self.backbone(src_sm)
+ f4 = self.aspp(f4)
+ hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r0, r1, r2, r3)
+
+ if not segmentation_pass:
+ fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
+ if downsample_ratio != 1:
+ _, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
+ pha = pha.clamp(0., 1.)
+ return [pha, *rec]
+ else:
+ seg = self.project_seg(hid)
+ return [seg, *rec]
+
+ def _interpolate(self, x: Tensor, scale_factor: float):
+ if x.ndim == 5:
+ B, T = x.shape[:2]
+ x = F.interpolate(
+ x.flatten(0, 1),
+ scale_factor=scale_factor,
+ mode='bilinear',
+ align_corners=False)
+ x = x.unflatten(0, (B, T))
+ else:
+ x = F.interpolate(
+ x,
+ scale_factor=scale_factor,
+ mode='bilinear',
+ align_corners=False)
+ return x
diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py
index 94a8d035..acc8035b 100644
--- a/modelscope/outputs/outputs.py
+++ b/modelscope/outputs/outputs.py
@@ -443,6 +443,12 @@ TASK_OUTPUTS = {
Tasks.referring_video_object_segmentation:
[OutputKeys.MASKS, OutputKeys.TIMESTAMPS],
+ # video human matting result for a single video
+ # {
+ # "masks": [np.array # 2D array with shape [height, width]]
+ # }
+ Tasks.video_human_matting: [OutputKeys.MASKS],
+
# ============ nlp tasks ===================
# text classification result for single sample
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 68d4f0b1..4821c553 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -201,6 +201,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_fft_inpainting_lama'),
Tasks.video_inpainting: (Pipelines.video_inpainting,
'damo/cv_video-inpainting'),
+ Tasks.video_human_matting: (Pipelines.video_human_matting,
+ 'damo/cv_effnetv2_video-human-matting'),
Tasks.human_wholebody_keypoint:
(Pipelines.human_wholebody_keypoint,
'damo/cv_hrnetw48_human-wholebody-keypoint_image'),
diff --git a/modelscope/pipelines/cv/video_human_matting_pipeline.py b/modelscope/pipelines/cv/video_human_matting_pipeline.py
new file mode 100644
index 00000000..b4e6f2ba
--- /dev/null
+++ b/modelscope/pipelines/cv/video_human_matting_pipeline.py
@@ -0,0 +1,77 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+from typing import Any, Dict
+
+import cv2
+import numpy as np
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.video_human_matting import preprocess
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+ Tasks.video_human_matting, module_name=Pipelines.video_human_matting)
+class VideoHumanMattingPipeline(Pipeline):
+
+ def __init__(self, model: str, **kwargs):
+ """
+ use `model` to create a video human matting pipeline for prediction
+ Args:
+ model: model id on modelscope hub.
+ """
+ super().__init__(model=model, **kwargs)
+ if torch.cuda.is_available():
+ self.device = 'cuda'
+ else:
+ self.device = 'cpu'
+ logger.info('load model done')
+
+ def preprocess(self, input) -> Input:
+ return input
+
+ def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ video_path = input['video_input_path']
+ out_path = input['output_path']
+ video_input = cv2.VideoCapture(video_path)
+ fps = video_input.get(cv2.CAP_PROP_FPS)
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ success, frame = video_input.read()
+ h, w = frame.shape[:2]
+ scale = 512 / max(h, w)
+ video_save = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
+ masks = []
+ rec = [None] * 4
+ self.model = self.model.to(self.device)
+ logger.info('matting start using ', self.device)
+ with torch.no_grad():
+ while True:
+ if frame is None:
+ break
+ frame_tensor = preprocess(frame)
+ pha, *rec = self.model.model(
+ frame_tensor.to(self.device), *rec, downsample_ratio=scale)
+ com = pha * 255
+ com = com.repeat(1, 3, 1, 1)
+ com = com[0].data.cpu().numpy().transpose(1, 2,
+ 0).astype(np.uint8)
+ video_save.write(com)
+ masks.append(com / 255)
+ success, frame = video_input.read()
+ logger.info('matting process done')
+ video_input.release()
+ video_save.release()
+
+ return {
+ OutputKeys.MASKS: masks,
+ }
+
+ def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ return inputs
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 4d585e1a..8f8e2c6f 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -87,6 +87,7 @@ class CVTasks(object):
# video segmentation
referring_video_object_segmentation = 'referring-video-object-segmentation'
+ video_human_matting = 'video-human-matting'
# video editing
video_inpainting = 'video-inpainting'
diff --git a/tests/pipelines/test_video_human_matting.py b/tests/pipelines/test_video_human_matting.py
new file mode 100644
index 00000000..4b65c1ac
--- /dev/null
+++ b/tests/pipelines/test_video_human_matting.py
@@ -0,0 +1,39 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import sys
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class VideoHumanMattingTest(unittest.TestCase):
+
+ def setUp(self) -> None:
+ self.model = 'damo/cv_effnetv2_video-human-matting'
+ self.video_in = 'data/test/videos/video_matting_test.mp4'
+ self.video_out = 'matting_out.mp4'
+ self.input = {
+ 'video_input_path': self.video_in,
+ 'output_path': self.video_out,
+ }
+
+ def pipeline_inference(self, pipeline: Pipeline, input):
+ result = pipeline(input)
+ print('video matting over, results:', result)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_modelhub(self):
+ video_human_matting = pipeline(
+ Tasks.video_human_matting, model=self.model)
+ self.pipeline_inference(video_human_matting, self.input)
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_run_modelhub_default_model(self):
+ video_human_matting = pipeline(Tasks.video_human_matting)
+ self.pipeline_inference(video_human_matting, self.input)
+
+
+if __name__ == '__main__':
+ unittest.main()