Browse Source

add video human matting task code

add video human matting task code
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10839854
master^2
jinmao.yk yingda.chen 3 years ago
parent
commit
d84a1df65a
15 changed files with 922 additions and 0 deletions
  1. +3
    -0
      data/test/videos/video_matting_test.mp4
  2. +2
    -0
      modelscope/metainfo.py
  3. +21
    -0
      modelscope/models/cv/video_human_matting/__init__.py
  4. +38
    -0
      modelscope/models/cv/video_human_matting/model.py
  5. +1
    -0
      modelscope/models/cv/video_human_matting/models/__init__.py
  6. +330
    -0
      modelscope/models/cv/video_human_matting/models/decoder.py
  7. +64
    -0
      modelscope/models/cv/video_human_matting/models/deep_guided_filter.py
  8. +177
    -0
      modelscope/models/cv/video_human_matting/models/effv2.py
  9. +94
    -0
      modelscope/models/cv/video_human_matting/models/lraspp.py
  10. +67
    -0
      modelscope/models/cv/video_human_matting/models/matting.py
  11. +6
    -0
      modelscope/outputs/outputs.py
  12. +2
    -0
      modelscope/pipelines/builder.py
  13. +77
    -0
      modelscope/pipelines/cv/video_human_matting_pipeline.py
  14. +1
    -0
      modelscope/utils/constant.py
  15. +39
    -0
      tests/pipelines/test_video_human_matting.py

+ 3
- 0
data/test/videos/video_matting_test.mp4 View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8e4ade7a6b119e20e82a641246199b4b530759166acc1f813d7cefee65b3e1e0
size 63944943

+ 2
- 0
modelscope/metainfo.py View File

@@ -52,6 +52,7 @@ class Models(object):
face_emotion = 'face-emotion'
product_segmentation = 'product-segmentation'
image_body_reshaping = 'image-body-reshaping'
video_human_matting = 'video-human-matting'

# EasyCV models
yolox = 'YOLOX'
@@ -230,6 +231,7 @@ class Pipelines(object):
product_segmentation = 'product-segmentation'
image_body_reshaping = 'flow-based-body-reshaping'
referring_video_object_segmentation = 'referring-video-object-segmentation'
video_human_matting = 'video-human-matting'

# nlp tasks
automatic_post_editing = 'automatic-post-editing'


+ 21
- 0
modelscope/models/cv/video_human_matting/__init__.py View File

@@ -0,0 +1,21 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .model import VideoMattingNetwork
from .model import preprocess

else:
_import_structure = {'model': ['VideoMattingNetwork', 'preprocess']}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

+ 38
- 0
modelscope/models/cv/video_human_matting/model.py View File

@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Optional

import numpy as np
import torch
import torchvision
from torch.nn import functional as F

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.video_human_matting.models import MattingNetwork
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger


@MODELS.register_module(
Tasks.video_human_matting, module_name=Models.video_human_matting)
class VideoMattingNetwork(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
params = torch.load(model_path, map_location='cpu')
self.model = MattingNetwork()
if 'model_state_dict' in params.keys():
params = params['model_state_dict']
self.model.load_state_dict(params, strict=True)
self.model.eval()


def preprocess(image):
frame_np = np.float32(image) / 255.0
frame_np = frame_np.transpose(2, 0, 1)
frame_tensor = torch.from_numpy(frame_np)
image_tensor = frame_tensor[None, :, :, :]
return image_tensor

+ 1
- 0
modelscope/models/cv/video_human_matting/models/__init__.py View File

@@ -0,0 +1 @@
from .matting import MattingNetwork

+ 330
- 0
modelscope/models/cv/video_human_matting/models/decoder.py View File

@@ -0,0 +1,330 @@
"""
Part of the implementation is borrowed from paper RVM
paper publicly available at <https://arxiv.org/abs/2108.11515/>
"""
from typing import Optional

import torch
from torch import Tensor, nn


class hswish(nn.Module):

def forward(self, x):
return torch.nn.Hardswish(inplace=True)(x)


class scSEblock(nn.Module):

def __init__(self, out):
super().__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(out, int(out / 2), 3, 1, 1),
nn.GroupNorm(out // 8, int(out / 2)), hswish())
self.conv2 = nn.Sequential(
nn.Conv2d(int(out / 2), out, 1, 1, 0),
nn.GroupNorm(out // 4, out),
)
self.avgpool = nn.AdaptiveAvgPool2d(1)

def forward_single(self, x):
b, c, _, _ = x.size()
x2 = self.avgpool(x).view(b, c, 1, 1)
x2 = self.conv1(x2)
x2 = self.conv2(x2)
x2 = torch.sigmoid(x2)
out = x2 * x
return out

def forward_time(self, x):
B, T, _, H, W = x.shape
x = x.flatten(0, 1)
out = self.forward_single(x)
out = out.unflatten(0, (B, T))
return out

def forward(self, x):
if x.ndim == 5:
return self.forward_time(x)
else:
return self.forward_single(x)


class RecurrentDecoder(nn.Module):

def __init__(self, feature_channels, decoder_channels):
super().__init__()
self.avgpool = AvgPool()
self.decode4 = BottleneckBlock(feature_channels[3])
self.decode3 = UpsamplingBlock(feature_channels[3],
feature_channels[2], 3,
decoder_channels[0])
self.sc3 = scSEblock(decoder_channels[0])
self.decode2 = UpsamplingBlock(decoder_channels[0],
feature_channels[1], 3,
decoder_channels[1])
self.sc2 = scSEblock(decoder_channels[1])
self.decode1 = UpsamplingBlock(decoder_channels[1],
feature_channels[0], 3,
decoder_channels[2])
self.sc1 = scSEblock(decoder_channels[2])
self.out0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])

self.crosslevel1 = crossfeature(feature_channels[3],
feature_channels[1])
self.crosslevel2 = crossfeature(feature_channels[2],
feature_channels[0])

def forward(self, s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor,
f4: Tensor, r1: Optional[Tensor], r2: Optional[Tensor],
r3: Optional[Tensor], r4: Optional[Tensor]):
s2, s3, s4 = self.avgpool(s0)
x4, r4 = self.decode4(f4, r4)
x3, r3 = self.decode3(x4, f3, s4, r3)
x3 = self.sc3(x3)
f2 = self.crosslevel1(f4, f2)
x2, r2 = self.decode2(x3, f2, s3, r2)
x2 = self.sc2(x2)
f1 = self.crosslevel2(f3, f1)
x1, r1 = self.decode1(x2, f1, s2, r1)
x1 = self.sc1(x1)
out = self.out0(x1, s0)
return out, r1, r2, r3, r4


class AvgPool(nn.Module):

def __init__(self):
super().__init__()
self.avgpool = nn.AvgPool2d(
2, 2, count_include_pad=False, ceil_mode=True)

def forward_single_frame(self, s0):
s1 = self.avgpool(s0)
s2 = self.avgpool(s1)
s3 = self.avgpool(s2)
return s1, s2, s3

def forward_time_series(self, s0):
B, T = s0.shape[:2]
s0 = s0.flatten(0, 1)
s1, s2, s3 = self.forward_single_frame(s0)
s1 = s1.unflatten(0, (B, T))
s2 = s2.unflatten(0, (B, T))
s3 = s3.unflatten(0, (B, T))
return s1, s2, s3

def forward(self, s0):
if s0.ndim == 5:
return self.forward_time_series(s0)
else:
return self.forward_single_frame(s0)


class crossfeature(nn.Module):

def __init__(self, in_channels, out_channels):
super().__init__()
self.avg = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)

def forward_single_frame(self, x1, x2):
b, c, _, _ = x1.size()
x1 = self.avg(x1).view(b, c, 1, 1)
x1 = self.conv(x1)
x1 = torch.sigmoid(x1)
x2 = x1 * x2
return x2

def forward_time_series(self, x1, x2):
b, t = x1.shape[:2]
x1 = x1.flatten(0, 1)
x2 = x2.flatten(0, 1)
x2 = self.forward_single_frame(x1, x2)
return x2.unflatten(0, (b, t))

def forward(self, x1, x2):
if x1.ndim == 5:
return self.forward_time_series(x1, x2)
else:
return self.forward_single_frame(x1, x2)


class BottleneckBlock(nn.Module):

def __init__(self, channels):
super().__init__()
self.channels = channels
self.gru = GRU(channels // 2)

def forward(self, x, r):
a, b = x.split(self.channels // 2, dim=-3)
b, r = self.gru(b, r)
x = torch.cat([a, b], dim=-3)
return x, r


class UpsamplingBlock(nn.Module):

def __init__(self, in_channels, skip_channels, src_channels, out_channels):
super().__init__()
self.out_channels = out_channels
self.upsample = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=False)
self.shortcut = nn.Sequential(
nn.Conv2d(skip_channels, in_channels, 3, 1, 1, bias=False),
nn.GroupNorm(in_channels // 4, in_channels), hswish())
self.att_skip = nn.Sequential(
nn.Conv2d(in_channels, in_channels, 1, 1, 0, bias=False),
nn.Sigmoid())
self.conv = nn.Sequential(
nn.Conv2d(
in_channels + in_channels + src_channels,
out_channels,
3,
1,
1,
bias=False),
nn.GroupNorm(out_channels // 4, out_channels),
hswish(),
)
self.gru = GRU(out_channels // 2)

def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
x = self.upsample(x)
x = x[:, :, :s.size(2), :s.size(3)]
att = self.att_skip(x)
f = self.shortcut(f)
f = att * f
x = torch.cat([x, f, s], dim=1)
x = self.conv(x)
a, b = x.split(self.out_channels // 2, dim=1)
b, r = self.gru(b, r)
x = torch.cat([a, b], dim=1)
return x, r

def forward_time_series(self, x, f, s, r: Optional[Tensor]):
B, T, _, H, W = s.shape
x = x.flatten(0, 1)
f = f.flatten(0, 1)
s = s.flatten(0, 1)
x = self.upsample(x)
x = x[:, :, :H, :W]
f = self.shortcut(f)
att = self.att_skip(x)
f = att * f
x = torch.cat([x, f, s], dim=1)
x = self.conv(x)
x = x.unflatten(0, (B, T))
a, b = x.split(self.out_channels // 2, dim=2)
b, r = self.gru(b, r)
x = torch.cat([a, b], dim=2)
return x, r

def forward(self, x, f, s, r: Optional[Tensor]):
if x.ndim == 5:
return self.forward_time_series(x, f, s, r)
else:
return self.forward_single_frame(x, f, s, r)


class OutputBlock(nn.Module):

def __init__(self, in_channels, src_channels, out_channels):
super().__init__()
self.upsample = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=False)
self.conv = nn.Sequential(
nn.Conv2d(
in_channels + src_channels, out_channels, 3, 1, 1, bias=False),
nn.GroupNorm(out_channels // 2, out_channels),
hswish(),
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
nn.GroupNorm(out_channels // 2, out_channels),
hswish(),
)

def forward_single_frame(self, x, s):
x = self.upsample(x)
x = x[:, :, :s.size(2), :s.size(3)]
x = torch.cat([x, s], dim=1)
x = self.conv(x)
return x

def forward_time_series(self, x, s):
B, T, _, H, W = s.shape
x = x.flatten(0, 1)
s = s.flatten(0, 1)
x = self.upsample(x)
x = x[:, :, :H, :W]
x = torch.cat([x, s], dim=1)
x = self.conv(x)
x = x.unflatten(0, (B, T))
return x

def forward(self, x, s):
if x.ndim == 5:
return self.forward_time_series(x, s)
else:
return self.forward_single_frame(x, s)


class Projection(nn.Module):

def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1)

def forward_single_frame(self, x):
return self.conv(x)

def forward_time_series(self, x):
B, T = x.shape[:2]
return self.conv(x.flatten(0, 1)).unflatten(0, (B, T))

def forward(self, x):
if x.ndim == 5:
return self.forward_time_series(x)
else:
return self.forward_single_frame(x)


class GRU(nn.Module):

def __init__(self, channels, kernel_size=3, padding=1):
super().__init__()
self.channels = channels
self.ih = nn.Conv2d(
channels * 2, channels * 2, kernel_size, padding=padding)
self.act_ih = nn.Sigmoid()
self.hh = nn.Conv2d(
channels * 2, channels, kernel_size, padding=padding)
self.act_hh = nn.Tanh()

def forward_single_frame(self, x, pre_fea):
fea_ih = self.ih(torch.cat([x, pre_fea], dim=1))
r, z = self.act_ih(fea_ih).split(self.channels, dim=1)
fea_hh = self.hh(torch.cat([x, r * pre_fea], dim=1))
c = self.act_hh(fea_hh)
fea_gru = (1 - z) * pre_fea + z * c
return fea_gru, fea_gru

def forward_time_series(self, x, pre_fea):
o = []
for xt in x.unbind(dim=1):
ot, pre_fea = self.forward_single_frame(xt, pre_fea)
o.append(ot)
o = torch.stack(o, dim=1)
return o, pre_fea

def forward(self, x, pre_fea):
if pre_fea is None:
pre_fea = torch.zeros(
(x.size(0), x.size(-3), x.size(-2), x.size(-1)),
device=x.device,
dtype=x.dtype)

if x.ndim == 5:
return self.forward_time_series(x, pre_fea)
else:
return self.forward_single_frame(x, pre_fea)

+ 64
- 0
modelscope/models/cv/video_human_matting/models/deep_guided_filter.py View File

@@ -0,0 +1,64 @@
"""
Part of the implementation is borrowed and modified from DeepGuidedFilter
publicly available at <https://github.com/wuhuikai/DeepGuidedFilter/>
"""
import torch
from torch import nn
from torch.nn import functional as F


class DeepGuidedFilterRefiner(nn.Module):

def __init__(self, hid_channels=16):
super().__init__()
self.box_filter = nn.Conv2d(
4, 4, kernel_size=3, padding=1, bias=False, groups=4)
self.box_filter.weight.data[...] = 1 / 9
self.conv = nn.Sequential(
nn.Conv2d(
4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(hid_channels), nn.ReLU(True),
nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(hid_channels), nn.ReLU(True),
nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True))

def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha,
base_hid):
fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
base_y = torch.cat([base_fgr, base_pha], dim=1)

mean_x = self.box_filter(base_x)
mean_y = self.box_filter(base_y)
cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
var_x = self.box_filter(base_x * base_x) - mean_x * mean_x

A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
b = mean_y - A * mean_x

H, W = fine_src.shape[2:]
A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)

out = A * fine_x + b
fgr, pha = out.split([3, 1], dim=1)
return fgr, pha

def forward_time_series(self, fine_src, base_src, base_fgr, base_pha,
base_hid):
B, T = fine_src.shape[:2]
fgr, pha = self.forward_single_frame(
fine_src.flatten(0, 1), base_src.flatten(0, 1),
base_fgr.flatten(0, 1), base_pha.flatten(0, 1),
base_hid.flatten(0, 1))
fgr = fgr.unflatten(0, (B, T))
pha = pha.unflatten(0, (B, T))
return fgr, pha

def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
if fine_src.ndim == 5:
return self.forward_time_series(fine_src, base_src, base_fgr,
base_pha, base_hid)
else:
return self.forward_single_frame(fine_src, base_src, base_fgr,
base_pha, base_hid)

+ 177
- 0
modelscope/models/cv/video_human_matting/models/effv2.py View File

@@ -0,0 +1,177 @@
"""
Part of the implementation is borrowed and modified from EfficientNetV2
publicly available at <https://arxiv.org/abs/2104.00298>
"""

import torch
import torch.nn.functional


class SiLU(torch.nn.Module):
"""
[https://arxiv.org/pdf/1710.05941.pdf]
"""

def __init__(self, inplace: bool = False):
super().__init__()
self.silu = torch.nn.SiLU(inplace=inplace)

def forward(self, x):
return self.silu(x)


class Conv(torch.nn.Module):

def __init__(self, in_ch, out_ch, activation, k=1, s=1, g=1):
super().__init__()
self.conv = torch.nn.Conv2d(
in_ch, out_ch, k, s, k // 2, 1, g, bias=False)
self.norm = torch.nn.BatchNorm2d(out_ch, 0.001, 0.01)
self.silu = activation

def forward(self, x):
return self.silu(self.norm(self.conv(x)))


class SE(torch.nn.Module):
"""
[https://arxiv.org/pdf/1709.01507.pdf]
"""

def __init__(self, ch, r):
super().__init__()
self.se = torch.nn.Sequential(
torch.nn.Conv2d(ch, ch // (4 * r), 1), torch.nn.SiLU(),
torch.nn.Conv2d(ch // (4 * r), ch, 1), torch.nn.Sigmoid())

def forward(self, x):
return x * self.se(x.mean((2, 3), keepdim=True))


class Residual(torch.nn.Module):
"""
[https://arxiv.org/pdf/1801.04381.pdf]
"""

def __init__(self, in_ch, out_ch, s, r, fused=True):
super().__init__()
identity = torch.nn.Identity()
if fused:
if r == 1:
features = [Conv(in_ch, r * in_ch, torch.nn.SiLU(), 3, s)]
else:
features = [
Conv(in_ch, r * in_ch, torch.nn.SiLU(), 3, s),
Conv(r * in_ch, out_ch, identity)
]
else:
if r == 1:
features = [
Conv(r * in_ch, r * in_ch, torch.nn.SiLU(), 3, s,
r * in_ch),
SE(r * in_ch, r),
Conv(r * in_ch, out_ch, identity)
]
else:
features = [
Conv(in_ch, r * in_ch, torch.nn.SiLU()),
Conv(r * in_ch, r * in_ch, torch.nn.SiLU(), 3, s,
r * in_ch),
SE(r * in_ch, r),
Conv(r * in_ch, out_ch, identity)
]
self.add = s == 1 and in_ch == out_ch
self.res = torch.nn.Sequential(*features)

def forward(self, x):
return x + self.res(x) if self.add else self.res(x)


class EfficientNet(torch.nn.Module):

def __init__(self, pretrained: bool = False):
super().__init__()
gate_fn = [True, False]
filters = [24, 48, 64, 128, 160, 256]
feature = [Conv(3, filters[0], torch.nn.SiLU(), 3, 2)]
for i in range(2):
if i == 0:
feature.append(
Residual(filters[0], filters[0], 1, 1, gate_fn[0]))
else:
feature.append(
Residual(filters[0], filters[0], 1, 1, gate_fn[0]))

for i in range(4):
if i == 0:
feature.append(
Residual(filters[0], filters[1], 2, 4, gate_fn[0]))
else:
feature.append(
Residual(filters[1], filters[1], 1, 4, gate_fn[0]))

for i in range(4):
if i == 0:
feature.append(
Residual(filters[1], filters[2], 2, 4, gate_fn[0]))
else:
feature.append(
Residual(filters[2], filters[2], 1, 4, gate_fn[0]))

for i in range(6):
if i == 0:
feature.append(
Residual(filters[2], filters[3], 2, 4, gate_fn[1]))
else:
feature.append(
Residual(filters[3], filters[3], 1, 4, gate_fn[1]))

for i in range(9):
if i == 0:
feature.append(
Residual(filters[3], filters[4], 1, 6, gate_fn[1]))
else:
feature.append(
Residual(filters[4], filters[4], 1, 6, gate_fn[1]))

self.feature = torch.nn.Sequential(*feature)

def forward_single_frame(self, x):
x = self.feature[0](x)
x = self.feature[1](x)
x = self.feature[2](x)
f1 = x # 1/2 24
for i in range(4):
x = self.feature[i + 3](x)
f2 = x # 1/4 48
for i in range(4):
x = self.feature[i + 7](x)
f3 = x # 1/8 64
for i in range(6):
x = self.feature[i + 11](x)
for i in range(9):
x = self.feature[i + 17](x)
f5 = x # 1/16 160
return [f1, f2, f3, f5]

def forward_time_series(self, x):
B, T = x.shape[:2]
features = self.forward_single_frame(x.flatten(0, 1))
features = [f.unflatten(0, (B, T)) for f in features]
return features

def forward(self, x):
if x.ndim == 5:
return self.forward_time_series(x)
else:
return self.forward_single_frame(x)

def export(self):
for m in self.modules():
if type(m) is Conv and hasattr(m, 'silu'):
if isinstance(m.silu, torch.nn.SiLU):
m.silu = SiLU()
if type(m) is SE:
if isinstance(m.se[1], torch.nn.SiLU):
m.se[1] = SiLU()
return self

+ 94
- 0
modelscope/models/cv/video_human_matting/models/lraspp.py View File

@@ -0,0 +1,94 @@
"""
Part of the implementation is borrowed and modified from Deeplab v3
publicly available at <https://arxiv.org/abs/1706.05587v3>
"""
import torch
from torch import nn


class ASP_OC_Module(nn.Module):

def __init__(self, features, out_features=96, dilations=(2, 4, 8)):
super(ASP_OC_Module, self).__init__()
self.conv2 = nn.Sequential(
nn.Conv2d(
features,
out_features,
kernel_size=1,
padding=0,
dilation=1,
bias=False), nn.BatchNorm2d(out_features))
self.conv3 = nn.Sequential(
nn.Conv2d(
features,
out_features,
kernel_size=3,
padding=dilations[0],
dilation=dilations[0],
bias=False), nn.BatchNorm2d(out_features))
self.conv4 = nn.Sequential(
nn.Conv2d(
features,
out_features,
kernel_size=3,
padding=dilations[1],
dilation=dilations[1],
bias=False), nn.BatchNorm2d(out_features))
self.conv5 = nn.Sequential(
nn.Conv2d(
features,
out_features,
kernel_size=3,
padding=dilations[2],
dilation=dilations[2],
bias=False), nn.BatchNorm2d(out_features))

self.conv_bn_dropout = nn.Sequential(
nn.Conv2d(
out_features * 4,
out_features * 2,
kernel_size=1,
padding=0,
dilation=1,
bias=False), nn.InstanceNorm2d(out_features * 2),
nn.Dropout2d(0.05))

def _cat_each(self, feat1, feat2, feat3, feat4, feat5):
assert (len(feat1) == len(feat2))
z = []
for i in range(len(feat1)):
z.append(
torch.cat((feat1[i], feat2[i], feat3[i], feat4[i], feat5[i]),
1))
return z

def forward(self, x):
_, _, h, w = x.size()
feat2 = self.conv2(x)
feat3 = self.conv3(x)
feat4 = self.conv4(x)
feat5 = self.conv5(x)
out = torch.cat((feat2, feat3, feat4, feat5), 1)
output = self.conv_bn_dropout(out)
return output


class LRASPP(nn.Module):

def __init__(self, in_channels, out_channels):
super().__init__()
self.aspp = ASP_OC_Module(in_channels, out_channels)

def forward_single_frame(self, x):
return self.aspp(x)

def forward_time_series(self, x):
B, T = x.shape[:2]
x = self.forward_single_frame(x.flatten(0, 1)).unflatten(0, (B, T))
return x

def forward(self, x):
if x.ndim == 5:
return self.forward_time_series(x)
else:
return self.forward_single_frame(x)

+ 67
- 0
modelscope/models/cv/video_human_matting/models/matting.py View File

@@ -0,0 +1,67 @@
from typing import Optional

import torch
from torch import Tensor
from torch.nn import functional as F

from .decoder import Projection, RecurrentDecoder
from .deep_guided_filter import DeepGuidedFilterRefiner
from .effv2 import EfficientNet
from .lraspp import LRASPP


class MattingNetwork(torch.nn.Module):

def __init__(self, pretrained_backbone: bool = False):
super().__init__()
self.backbone = EfficientNet(pretrained_backbone)
self.aspp = LRASPP(160, 64)
self.decoder = RecurrentDecoder([24, 48, 64, 128], [64, 32, 24, 16])
self.project_mat = Projection(16, 4)
self.project_seg = Projection(16, 1)
self.refiner = DeepGuidedFilterRefiner()

def forward(self,
src: Tensor,
r0: Optional[Tensor] = None,
r1: Optional[Tensor] = None,
r2: Optional[Tensor] = None,
r3: Optional[Tensor] = None,
downsample_ratio: float = 1,
segmentation_pass: bool = False):

if downsample_ratio != 1:
src_sm = self._interpolate(src, scale_factor=downsample_ratio)
else:
src_sm = src

f1, f2, f3, f4 = self.backbone(src_sm)
f4 = self.aspp(f4)
hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r0, r1, r2, r3)

if not segmentation_pass:
fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
if downsample_ratio != 1:
_, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
pha = pha.clamp(0., 1.)
return [pha, *rec]
else:
seg = self.project_seg(hid)
return [seg, *rec]

def _interpolate(self, x: Tensor, scale_factor: float):
if x.ndim == 5:
B, T = x.shape[:2]
x = F.interpolate(
x.flatten(0, 1),
scale_factor=scale_factor,
mode='bilinear',
align_corners=False)
x = x.unflatten(0, (B, T))
else:
x = F.interpolate(
x,
scale_factor=scale_factor,
mode='bilinear',
align_corners=False)
return x

+ 6
- 0
modelscope/outputs/outputs.py View File

@@ -443,6 +443,12 @@ TASK_OUTPUTS = {
Tasks.referring_video_object_segmentation:
[OutputKeys.MASKS, OutputKeys.TIMESTAMPS],

# video human matting result for a single video
# {
# "masks": [np.array # 2D array with shape [height, width]]
# }
Tasks.video_human_matting: [OutputKeys.MASKS],

# ============ nlp tasks ===================

# text classification result for single sample


+ 2
- 0
modelscope/pipelines/builder.py View File

@@ -201,6 +201,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_fft_inpainting_lama'),
Tasks.video_inpainting: (Pipelines.video_inpainting,
'damo/cv_video-inpainting'),
Tasks.video_human_matting: (Pipelines.video_human_matting,
'damo/cv_effnetv2_video-human-matting'),
Tasks.human_wholebody_keypoint:
(Pipelines.human_wholebody_keypoint,
'damo/cv_hrnetw48_human-wholebody-keypoint_image'),


+ 77
- 0
modelscope/pipelines/cv/video_human_matting_pipeline.py View File

@@ -0,0 +1,77 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.video_human_matting import preprocess
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.video_human_matting, module_name=Pipelines.video_human_matting)
class VideoHumanMattingPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` to create a video human matting pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
if torch.cuda.is_available():
self.device = 'cuda'
else:
self.device = 'cpu'
logger.info('load model done')

def preprocess(self, input) -> Input:
return input

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
video_path = input['video_input_path']
out_path = input['output_path']
video_input = cv2.VideoCapture(video_path)
fps = video_input.get(cv2.CAP_PROP_FPS)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
success, frame = video_input.read()
h, w = frame.shape[:2]
scale = 512 / max(h, w)
video_save = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
masks = []
rec = [None] * 4
self.model = self.model.to(self.device)
logger.info('matting start using ', self.device)
with torch.no_grad():
while True:
if frame is None:
break
frame_tensor = preprocess(frame)
pha, *rec = self.model.model(
frame_tensor.to(self.device), *rec, downsample_ratio=scale)
com = pha * 255
com = com.repeat(1, 3, 1, 1)
com = com[0].data.cpu().numpy().transpose(1, 2,
0).astype(np.uint8)
video_save.write(com)
masks.append(com / 255)
success, frame = video_input.read()
logger.info('matting process done')
video_input.release()
video_save.release()

return {
OutputKeys.MASKS: masks,
}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

+ 1
- 0
modelscope/utils/constant.py View File

@@ -87,6 +87,7 @@ class CVTasks(object):

# video segmentation
referring_video_object_segmentation = 'referring-video-object-segmentation'
video_human_matting = 'video-human-matting'

# video editing
video_inpainting = 'video-inpainting'


+ 39
- 0
tests/pipelines/test_video_human_matting.py View File

@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import sys
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class VideoHumanMattingTest(unittest.TestCase):

def setUp(self) -> None:
self.model = 'damo/cv_effnetv2_video-human-matting'
self.video_in = 'data/test/videos/video_matting_test.mp4'
self.video_out = 'matting_out.mp4'
self.input = {
'video_input_path': self.video_in,
'output_path': self.video_out,
}

def pipeline_inference(self, pipeline: Pipeline, input):
result = pipeline(input)
print('video matting over, results:', result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
video_human_matting = pipeline(
Tasks.video_human_matting, model=self.model)
self.pipeline_inference(video_human_matting, self.input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
video_human_matting = pipeline(Tasks.video_human_matting)
self.pipeline_inference(video_human_matting, self.input)


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save