@@ -0,0 +1,442 @@
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models import MODELS
from modelscope.utils.constant import ModelFile, Tasks


def apply_offset(offset):
    """Convert per-pixel offsets of shape (N, 2, H, W) into a normalized
    sampling grid of shape (N, H, W, 2) for F.grid_sample."""
    sizes = list(offset.size()[2:])
    grid_list = torch.meshgrid(
        [torch.arange(size, device=offset.device) for size in sizes])
    # reverse to (x, y) order as expected by grid_sample
    grid_list = reversed(grid_list)
    # apply offset
    grid_list = [
        grid.float().unsqueeze(0) + offset[:, dim, ...]
        for dim, grid in enumerate(grid_list)
    ]
    # normalize to [-1, 1]
    grid_list = [
        grid / ((size - 1.0) / 2.0) - 1.0
        for grid, size in zip(grid_list, reversed(sizes))
    ]

    return torch.stack(grid_list, dim=-1)


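# Example (illustrative): with all-zero offsets, apply_offset returns the
# pixel-center grid normalized to [-1, 1] (align_corners=True convention),
# mapping the top-left pixel to (-1, -1) and the bottom-right pixel to (1, 1):
#
#     grid = apply_offset(torch.zeros(1, 2, 4, 4))  # -> shape (1, 4, 4, 2)
#     grid[0, 0, 0]    # tensor([-1., -1.])
#     grid[0, -1, -1]  # tensor([1., 1.])

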
# backbone
class ResBlock(nn.Module):

    def __init__(self, in_channels):
        super(ResBlock, self).__init__()
        self.block = nn.Sequential(
            nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels, in_channels, kernel_size=3,
                padding=1, bias=False), nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels, in_channels, kernel_size=3, padding=1,
                bias=False))

    def forward(self, x):
        return self.block(x) + x


class Downsample(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.block = nn.Sequential(
            nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False))

    def forward(self, x):
        return self.block(x)


class FeatureEncoder(nn.Module):

    def __init__(self, in_channels, chns=[64, 128, 256, 256, 256]):
        # in_channels = 3 for images, and is larger (e.g., 17+1+1) for the
        # agnostic representation
        super(FeatureEncoder, self).__init__()
        self.encoders = []
        for i, out_chns in enumerate(chns):
            if i == 0:
                encoder = nn.Sequential(
                    Downsample(in_channels, out_chns), ResBlock(out_chns),
                    ResBlock(out_chns))
            else:
                encoder = nn.Sequential(
                    Downsample(chns[i - 1], out_chns), ResBlock(out_chns),
                    ResBlock(out_chns))

            self.encoders.append(encoder)

        self.encoders = nn.ModuleList(self.encoders)

    def forward(self, x):
        encoder_features = []
        for encoder in self.encoders:
            x = encoder(x)
            encoder_features.append(x)
        return encoder_features


class RefinePyramid(nn.Module):

    def __init__(self, chns=[64, 128, 256, 256, 256], fpn_dim=256):
        super(RefinePyramid, self).__init__()
        self.chns = chns

        # adaptive
        self.adaptive = []
        for in_chns in list(reversed(chns)):
            adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1)
            self.adaptive.append(adaptive_layer)
        self.adaptive = nn.ModuleList(self.adaptive)
        # output conv
        self.smooth = []
        for i in range(len(chns)):
            smooth_layer = nn.Conv2d(
                fpn_dim, fpn_dim, kernel_size=3, padding=1)
            self.smooth.append(smooth_layer)
        self.smooth = nn.ModuleList(self.smooth)

    def forward(self, x):
        conv_ftr_list = x

        feature_list = []
        last_feature = None
        for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))):
            # adaptive
            feature = self.adaptive[i](conv_ftr)
            # fuse
            if last_feature is not None:
                feature = feature + F.interpolate(
                    last_feature, scale_factor=2, mode='nearest')
            # smooth
            feature = self.smooth[i](feature)
            last_feature = feature
            feature_list.append(feature)

        return tuple(reversed(feature_list))


def DAWarp(feat, offsets, att_maps, sample_k, out_ch):
    """Deformable attention warp: warp `feat` with the `sample_k` sampling
    grids stacked along the batch dimension of `offsets`, then fuse the
    warped copies with the attention maps `att_maps`."""
    att_maps = torch.repeat_interleave(att_maps, out_ch, 1)
    B, C, H, W = feat.size()
    multi_feat = torch.repeat_interleave(feat, sample_k, 0)
    multi_warp_feat = F.grid_sample(
        multi_feat,
        offsets.detach().permute(0, 2, 3, 1),
        mode='bilinear',
        padding_mode='border')
    multi_att_warp_feat = multi_warp_feat.reshape(B, -1, H, W) * att_maps
    att_warp_feat = sum(torch.split(multi_att_warp_feat, out_ch, 1))
    return att_warp_feat


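# Shape sketch for DAWarp (illustrative): with B=1, sample_k=2 and
# out_ch=C=256 on an 8x8 map,
#   feat:     (1, 256, 8, 8)
#   offsets:  (2, 2, 8, 8)   - the k sampling grids stacked along the batch dim
#   att_maps: (1, 2, 8, 8)   - softmax weights over the k flows
# the k warped copies are weighted by att_maps and summed back to
# a (1, 256, 8, 8) tensor.

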
class MFEBlock(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 num_filters=[128, 64, 32]):
        super(MFEBlock, self).__init__()
        layers = []
        for i in range(len(num_filters)):
            if i == 0:
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=num_filters[i],
                        kernel_size=3,
                        stride=1,
                        padding=1))
            else:
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=num_filters[i - 1],
                        out_channels=num_filters[i],
                        kernel_size=kernel_size,
                        stride=1,
                        padding=kernel_size // 2))
            layers.append(
                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1))
        layers.append(
            torch.nn.Conv2d(
                in_channels=num_filters[-1],
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2))
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, input):
        return self.layers(input)


class DAFlowNet(nn.Module):

    def __init__(self, num_pyramid, fpn_dim=256, head_nums=1):
        super(DAFlowNet, self).__init__()
        self.Self_MFEs = []
        self.Cross_MFEs = []
        self.Refine_MFEs = []
        self.k = head_nums
        self.out_ch = fpn_dim
        for i in range(num_pyramid):
            # self-MFE for the model image: 2k channels for flows, k for att maps
            Self_MFE_layer = MFEBlock(
                in_channels=2 * fpn_dim,
                out_channels=self.k * 3,
                kernel_size=7)
            # cross-MFE for the cloth image
            Cross_MFE_layer = MFEBlock(
                in_channels=2 * fpn_dim, out_channels=self.k * 3)
            # refine-MFE for the cloth and model images
            Refine_MFE_layer = MFEBlock(
                in_channels=2 * fpn_dim, out_channels=self.k * 6)
            self.Self_MFEs.append(Self_MFE_layer)
            self.Cross_MFEs.append(Cross_MFE_layer)
            self.Refine_MFEs.append(Refine_MFE_layer)

        self.Self_MFEs = nn.ModuleList(self.Self_MFEs)
        self.Cross_MFEs = nn.ModuleList(self.Cross_MFEs)
        self.Refine_MFEs = nn.ModuleList(self.Refine_MFEs)

        self.lights_decoder = torch.nn.Sequential(
            torch.nn.Conv2d(64, out_channels=32, kernel_size=1, stride=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
            torch.nn.Conv2d(
                in_channels=32,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1))
        self.lights_encoder = torch.nn.Sequential(
            torch.nn.Conv2d(
                3, out_channels=32, kernel_size=3, stride=1, padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
            torch.nn.Conv2d(
                in_channels=32, out_channels=64, kernel_size=1, stride=1))

    def forward(self,
                source_image,
                reference_image,
                source_feats,
                reference_feats,
                return_all=False,
                warp_feature=True,
                use_light_en_de=True):
        r"""
        Args:
            source_image: cloth RGB image for try-on
            reference_image: model RGB image for try-on
            source_feats: cloth FPN features
            reference_feats: model and pose FPN features
            return_all: if True, also return all intermediate try-on results
                (used in the training phase)
            warp_feature: apply DAFlow to both features and images
            use_light_en_de: use a shallow encoder/decoder pair to project the
                images from RGB to a higher-dimensional space
        """

        # the reference branch takes the model image and uses self-DAFlow
        last_multi_self_offsets = None
        # the source branch takes the cloth image and uses cross-DAFlow
        last_multi_cross_offsets = None

        if return_all:
            results_all = []

        for i in range(len(source_feats)):

            feat_source = source_feats[len(source_feats) - 1 - i]
            feat_ref = reference_feats[len(reference_feats) - 1 - i]
            B, C, H, W = feat_source.size()

            # Pre-DAWarp for pyramid features
            if last_multi_cross_offsets is not None and warp_feature:
                att_source_feat = DAWarp(feat_source, last_multi_cross_offsets,
                                         cross_att_maps, self.k, self.out_ch)
                att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets,
                                            self_att_maps, self.k, self.out_ch)
            else:
                att_source_feat = feat_source
                att_reference_feat = feat_ref
            # Cross-MFE
            input_feat = torch.cat([att_source_feat, feat_ref], 1)
            offsets_att = self.Cross_MFEs[i](input_feat)
            cross_att_maps = F.softmax(
                offsets_att[:, self.k * 2:, :, :], dim=1)
            offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape(
                -1, 2, H, W))
            if last_multi_cross_offsets is not None:
                offsets = F.grid_sample(
                    last_multi_cross_offsets,
                    offsets,
                    mode='bilinear',
                    padding_mode='border')
            else:
                offsets = offsets.permute(0, 3, 1, 2)
            last_multi_cross_offsets = offsets
            att_source_feat = DAWarp(feat_source, last_multi_cross_offsets,
                                     cross_att_maps, self.k, self.out_ch)

            # Self-MFE
            input_feat = torch.cat([att_source_feat, att_reference_feat], 1)
            offsets_att = self.Self_MFEs[i](input_feat)
            self_att_maps = F.softmax(offsets_att[:, self.k * 2:, :, :], dim=1)
            offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape(
                -1, 2, H, W))
            if last_multi_self_offsets is not None:
                offsets = F.grid_sample(
                    last_multi_self_offsets,
                    offsets,
                    mode='bilinear',
                    padding_mode='border')
            else:
                offsets = offsets.permute(0, 3, 1, 2)
            last_multi_self_offsets = offsets
            att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets,
                                        self_att_maps, self.k, self.out_ch)

            # Refine-MFE
            input_feat = torch.cat([att_source_feat, att_reference_feat], 1)
            offsets_att = self.Refine_MFEs[i](input_feat)
            att_maps = F.softmax(offsets_att[:, self.k * 4:, :, :], dim=1)
            cross_offsets = apply_offset(
                offsets_att[:, :self.k * 2, :, :].reshape(-1, 2, H, W))
            self_offsets = apply_offset(
                offsets_att[:,
                            self.k * 2:self.k * 4, :, :].reshape(-1, 2, H, W))
            last_multi_cross_offsets = F.grid_sample(
                last_multi_cross_offsets,
                cross_offsets,
                mode='bilinear',
                padding_mode='border')
            last_multi_self_offsets = F.grid_sample(
                last_multi_self_offsets,
                self_offsets,
                mode='bilinear',
                padding_mode='border')

            # Upsampling
            last_multi_cross_offsets = F.interpolate(
                last_multi_cross_offsets, scale_factor=2, mode='bilinear')
            last_multi_self_offsets = F.interpolate(
                last_multi_self_offsets, scale_factor=2, mode='bilinear')
            self_att_maps = F.interpolate(
                att_maps[:, :self.k, :, :], scale_factor=2, mode='bilinear')
            cross_att_maps = F.interpolate(
                att_maps[:, self.k:, :, :], scale_factor=2, mode='bilinear')

            # Post-DAWarp for the source and reference images
            if return_all:
                cur_source_image = F.interpolate(
                    source_image, (H * 2, W * 2), mode='bilinear')
                cur_reference_image = F.interpolate(
                    reference_image, (H * 2, W * 2), mode='bilinear')
                if use_light_en_de:
                    cur_source_image = self.lights_encoder(cur_source_image)
                    cur_reference_image = self.lights_encoder(
                        cur_reference_image)
                    # the feature dim of the light encoder is 64
                    warp_att_source_image = DAWarp(cur_source_image,
                                                   last_multi_cross_offsets,
                                                   cross_att_maps, self.k, 64)
                    warp_att_reference_image = DAWarp(cur_reference_image,
                                                      last_multi_self_offsets,
                                                      self_att_maps, self.k,
                                                      64)
                    result_tryon = self.lights_decoder(
                        warp_att_source_image + warp_att_reference_image)
                else:
                    warp_att_source_image = DAWarp(cur_source_image,
                                                   last_multi_cross_offsets,
                                                   cross_att_maps, self.k, 3)
                    warp_att_reference_image = DAWarp(cur_reference_image,
                                                      last_multi_self_offsets,
                                                      self_att_maps, self.k, 3)
                    result_tryon = warp_att_source_image + warp_att_reference_image
                results_all.append(result_tryon)

        last_multi_self_offsets = F.interpolate(
            last_multi_self_offsets,
            reference_image.size()[2:],
            mode='bilinear')
        last_multi_cross_offsets = F.interpolate(
            last_multi_cross_offsets, source_image.size()[2:], mode='bilinear')
        self_att_maps = F.interpolate(
            self_att_maps, reference_image.size()[2:], mode='bilinear')
        cross_att_maps = F.interpolate(
            cross_att_maps, source_image.size()[2:], mode='bilinear')
        if use_light_en_de:
            source_image = self.lights_encoder(source_image)
            reference_image = self.lights_encoder(reference_image)
            warp_att_source_image = DAWarp(source_image,
                                           last_multi_cross_offsets,
                                           cross_att_maps, self.k, 64)
            warp_att_reference_image = DAWarp(reference_image,
                                              last_multi_self_offsets,
                                              self_att_maps, self.k, 64)
            result_tryon = self.lights_decoder(warp_att_source_image
                                               + warp_att_reference_image)
        else:
            warp_att_source_image = DAWarp(source_image,
                                           last_multi_cross_offsets,
                                           cross_att_maps, self.k, 3)
            warp_att_reference_image = DAWarp(reference_image,
                                              last_multi_self_offsets,
                                              self_att_maps, self.k, 3)
            result_tryon = warp_att_source_image + warp_att_reference_image

        if return_all:
            return result_tryon, results_all
        return result_tryon


class SDAFNet_Tryon(nn.Module):

    def __init__(self, ref_in_channel, source_in_channel=3, head_nums=6):
        super(SDAFNet_Tryon, self).__init__()
        num_filters = [64, 128, 256, 256, 256]
        self.source_features = FeatureEncoder(source_in_channel, num_filters)
        self.reference_features = FeatureEncoder(ref_in_channel, num_filters)
        self.source_FPN = RefinePyramid(num_filters)
        self.reference_FPN = RefinePyramid(num_filters)
        self.dafnet = DAFlowNet(len(num_filters), head_nums=head_nums)

    def forward(self,
                ref_input,
                source_image,
                ref_image,
                use_light_en_de=True,
                return_all=False,
                warp_feature=True):
        reference_feats = self.reference_FPN(
            self.reference_features(ref_input))
        source_feats = self.source_FPN(self.source_features(source_image))
        result = self.dafnet(
            source_image,
            ref_image,
            source_feats,
            reference_feats,
            use_light_en_de=use_light_en_de,
            return_all=return_all,
            warp_feature=warp_feature)
        return result
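

# Minimal smoke test (a sketch under assumed shapes): 256x192 images and
# ref_in_channel=19 (e.g. 17 pose channels + 1 body-shape map + 1 mask) are
# assumptions used only for this check, not values mandated by the model.
if __name__ == '__main__':
    net = SDAFNet_Tryon(ref_in_channel=19)
    net.eval()
    ref_input = torch.randn(1, 19, 256, 192)  # agnostic person representation
    source_image = torch.randn(1, 3, 256, 192)  # cloth RGB image
    ref_image = torch.randn(1, 3, 256, 192)  # model RGB image
    with torch.no_grad():
        tryon = net(ref_input, source_image, ref_image)
    print(tryon.shape)  # expected: torch.Size([1, 3, 256, 192])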