From 2d1ce75dcc987dab7928a8b6db086ca137e957c8 Mon Sep 17 00:00:00 2001
From: "baishuai.bs"
Date: Mon, 25 Jul 2022 13:07:20 +0800
Subject: [PATCH] [to #43259593]add cv tryon task
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add the virtual try-on task: given a model image, a pose (skeleton) image and a garment display image, generate a try-on result image.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9401415
---
 data/test/images/virtual_tryon_cloth.jpg        |   3 +
 data/test/images/virtual_tryon_model.jpg        |   3 +
 data/test/images/virtual_tryon_pose.jpg         |   3 +
 modelscope/metainfo.py                          |   1 +
 modelscope/models/cv/virtual_tryon/__init__.py  |   0
 modelscope/models/cv/virtual_tryon/sdafnet.py   | 442 ++++++++++++++++++
 modelscope/outputs.py                           |   7 +-
 modelscope/pipelines/builder.py                 |   2 +
 modelscope/pipelines/cv/__init__.py             |   1 +
 .../pipelines/cv/virtual_tryon_pipeline.py      | 124 +++++
 modelscope/utils/constant.py                    |   1 +
 tests/pipelines/test_virtual_tryon.py           |  36 ++
 12 files changed, 622 insertions(+), 1 deletion(-)
 create mode 100644 data/test/images/virtual_tryon_cloth.jpg
 create mode 100644 data/test/images/virtual_tryon_model.jpg
 create mode 100644 data/test/images/virtual_tryon_pose.jpg
 create mode 100644 modelscope/models/cv/virtual_tryon/__init__.py
 create mode 100644 modelscope/models/cv/virtual_tryon/sdafnet.py
 create mode 100644 modelscope/pipelines/cv/virtual_tryon_pipeline.py
 create mode 100644 tests/pipelines/test_virtual_tryon.py

diff --git a/data/test/images/virtual_tryon_cloth.jpg b/data/test/images/virtual_tryon_cloth.jpg
new file mode 100644
index 00000000..baa4d3aa
--- /dev/null
+++ b/data/test/images/virtual_tryon_cloth.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ce0d25b3392f140bf35fba9c6711fdcfc2efde536600aa48dace35462e81adf
+size 8825
diff --git a/data/test/images/virtual_tryon_model.jpg b/data/test/images/virtual_tryon_model.jpg
new file mode 100644
index 00000000..2862a8be
--- /dev/null
+++ b/data/test/images/virtual_tryon_model.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb76a61306d3d311d440c5c695958909166e04fb34c827d74d766ba830945d6f
+size 5034
diff --git a/data/test/images/virtual_tryon_pose.jpg b/data/test/images/virtual_tryon_pose.jpg
new file mode 100644
index 00000000..41804706
--- /dev/null
+++ b/data/test/images/virtual_tryon_pose.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ab9baf18074b6b5655ee546794789395757486d6e2180c2627aad47b819e505
+size 11778
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 3cb10d65..511b786d 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -60,6 +60,7 @@ class Pipelines(object):
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognation = 'resnet101-animal_recog'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+    virtual_tryon = 'virtual_tryon'
     image_colorization = 'unet-image-colorization'
     image_super_resolution = 'rrdb-image-super-resolution'
     face_image_generation = 'gan-face-image-generation'
diff --git a/modelscope/models/cv/virtual_tryon/__init__.py b/modelscope/models/cv/virtual_tryon/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/models/cv/virtual_tryon/sdafnet.py b/modelscope/models/cv/virtual_tryon/sdafnet.py
new file mode 100644
index 00000000..f98a5e7d
--- /dev/null
+++ b/modelscope/models/cv/virtual_tryon/sdafnet.py
@@ -0,0 +1,442 @@
+import random
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modelscope.metainfo import Models
+from modelscope.models import 
MODELS +from modelscope.utils.constant import ModelFile, Tasks + + +def apply_offset(offset): + sizes = list(offset.size()[2:]) + grid_list = torch.meshgrid( + [torch.arange(size, device=offset.device) for size in sizes]) + grid_list = reversed(grid_list) + # apply offset + grid_list = [ + grid.float().unsqueeze(0) + offset[:, dim, ...] + for dim, grid in enumerate(grid_list) + ] + # normalize + grid_list = [ + grid / ((size - 1.0) / 2.0) - 1.0 + for grid, size in zip(grid_list, reversed(sizes)) + ] + + return torch.stack(grid_list, dim=-1) + + +# backbone +class ResBlock(nn.Module): + + def __init__(self, in_channels): + super(ResBlock, self).__init__() + self.block = nn.Sequential( + nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True), + nn.Conv2d( + in_channels, in_channels, kernel_size=3, + padding=1, bias=False), nn.BatchNorm2d(in_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + in_channels, in_channels, kernel_size=3, padding=1, + bias=False)) + + def forward(self, x): + return self.block(x) + x + + +class Downsample(nn.Module): + + def __init__(self, in_channels, out_channels): + super(Downsample, self).__init__() + self.block = nn.Sequential( + nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False)) + + def forward(self, x): + return self.block(x) + + +class FeatureEncoder(nn.Module): + + def __init__(self, in_channels, chns=[64, 128, 256, 256, 256]): + # in_channels = 3 for images, and is larger (e.g., 17+1+1) for agnositc representation + super(FeatureEncoder, self).__init__() + self.encoders = [] + for i, out_chns in enumerate(chns): + if i == 0: + encoder = nn.Sequential( + Downsample(in_channels, out_chns), ResBlock(out_chns), + ResBlock(out_chns)) + else: + encoder = nn.Sequential( + Downsample(chns[i - 1], out_chns), ResBlock(out_chns), + ResBlock(out_chns)) + + self.encoders.append(encoder) + + self.encoders = nn.ModuleList(self.encoders) + + def forward(self, x): + encoder_features = [] + for encoder in self.encoders: + x = encoder(x) + encoder_features.append(x) + return encoder_features + + +class RefinePyramid(nn.Module): + + def __init__(self, chns=[64, 128, 256, 256, 256], fpn_dim=256): + super(RefinePyramid, self).__init__() + self.chns = chns + + # adaptive + self.adaptive = [] + for in_chns in list(reversed(chns)): + adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1) + self.adaptive.append(adaptive_layer) + self.adaptive = nn.ModuleList(self.adaptive) + # output conv + self.smooth = [] + for i in range(len(chns)): + smooth_layer = nn.Conv2d( + fpn_dim, fpn_dim, kernel_size=3, padding=1) + self.smooth.append(smooth_layer) + self.smooth = nn.ModuleList(self.smooth) + + def forward(self, x): + conv_ftr_list = x + + feature_list = [] + last_feature = None + for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))): + # adaptive + feature = self.adaptive[i](conv_ftr) + # fuse + if last_feature is not None: + feature = feature + F.interpolate( + last_feature, scale_factor=2, mode='nearest') + # smooth + feature = self.smooth[i](feature) + last_feature = feature + feature_list.append(feature) + + return tuple(reversed(feature_list)) + + +def DAWarp(feat, offsets, att_maps, sample_k, out_ch): + att_maps = torch.repeat_interleave(att_maps, out_ch, 1) + B, C, H, W = feat.size() + multi_feat = torch.repeat_interleave(feat, sample_k, 0) + multi_warp_feat = F.grid_sample( + multi_feat, + offsets.detach().permute(0, 2, 3, 1), + mode='bilinear', + padding_mode='border') + 
multi_att_warp_feat = multi_warp_feat.reshape(B, -1, H, W) * att_maps + att_warp_feat = sum(torch.split(multi_att_warp_feat, out_ch, 1)) + return att_warp_feat + + +class MFEBlock(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + num_filters=[128, 64, 32]): + super(MFEBlock, self).__init__() + layers = [] + for i in range(len(num_filters)): + if i == 0: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=num_filters[i], + kernel_size=3, + stride=1, + padding=1)) + else: + layers.append( + torch.nn.Conv2d( + in_channels=num_filters[i - 1], + out_channels=num_filters[i], + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2)) + layers.append( + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)) + layers.append( + torch.nn.Conv2d( + in_channels=num_filters[-1], + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2)) + self.layers = torch.nn.Sequential(*layers) + + def forward(self, input): + return self.layers(input) + + +class DAFlowNet(nn.Module): + + def __init__(self, num_pyramid, fpn_dim=256, head_nums=1): + super(DAFlowNet, self).__init__() + self.Self_MFEs = [] + + self.Cross_MFEs = [] + self.Refine_MFEs = [] + self.k = head_nums + self.out_ch = fpn_dim + for i in range(num_pyramid): + # self-MFE for model img 2k:flow 1k:att_map + Self_MFE_layer = MFEBlock( + in_channels=2 * fpn_dim, + out_channels=self.k * 3, + kernel_size=7) + # cross-MFE for cloth img + Cross_MFE_layer = MFEBlock( + in_channels=2 * fpn_dim, out_channels=self.k * 3) + # refine-MFE for cloth and model imgs + Refine_MFE_layer = MFEBlock( + in_channels=2 * fpn_dim, out_channels=self.k * 6) + self.Self_MFEs.append(Self_MFE_layer) + self.Cross_MFEs.append(Cross_MFE_layer) + self.Refine_MFEs.append(Refine_MFE_layer) + + self.Self_MFEs = nn.ModuleList(self.Self_MFEs) + self.Cross_MFEs = nn.ModuleList(self.Cross_MFEs) + self.Refine_MFEs = nn.ModuleList(self.Refine_MFEs) + + self.lights_decoder = torch.nn.Sequential( + torch.nn.Conv2d(64, out_channels=32, kernel_size=1, stride=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=32, + out_channels=3, + kernel_size=3, + stride=1, + padding=1)) + self.lights_encoder = torch.nn.Sequential( + torch.nn.Conv2d( + 3, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=32, out_channels=64, kernel_size=1, stride=1)) + + def forward(self, + source_image, + reference_image, + source_feats, + reference_feats, + return_all=False, + warp_feature=True, + use_light_en_de=True): + r""" + Args: + source_image: cloth rgb image for tryon + reference_image: model rgb image for try on + source_feats: cloth FPN features + reference_feats: model and pose features + return_all: bool return all intermediate try-on results in training phase + warp_feature: use DAFlow for both features and images + use_light_en_de: use shallow encoder and decoder to project the images from RGB to high dimensional space + + """ + + # reference branch inputs model img using self-DAFlow + last_multi_self_offsets = None + # source branch inputs cloth img using cross-DAFlow + last_multi_cross_offsets = None + + if return_all: + results_all = [] + + for i in range(len(source_feats)): + + feat_source = source_feats[len(source_feats) - 1 - i] + feat_ref = reference_feats[len(reference_feats) - 1 - i] + B, C, H, W = feat_source.size() + + # Pre-DAWarp for Pyramid feature + if 
last_multi_cross_offsets is not None and warp_feature: + att_source_feat = DAWarp(feat_source, last_multi_cross_offsets, + cross_att_maps, self.k, self.out_ch) + att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets, + self_att_maps, self.k, self.out_ch) + else: + att_source_feat = feat_source + att_reference_feat = feat_ref + # Cross-MFE + input_feat = torch.cat([att_source_feat, feat_ref], 1) + offsets_att = self.Cross_MFEs[i](input_feat) + cross_att_maps = F.softmax( + offsets_att[:, self.k * 2:, :, :], dim=1) + offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape( + -1, 2, H, W)) + if last_multi_cross_offsets is not None: + offsets = F.grid_sample( + last_multi_cross_offsets, + offsets, + mode='bilinear', + padding_mode='border') + else: + offsets = offsets.permute(0, 3, 1, 2) + last_multi_cross_offsets = offsets + att_source_feat = DAWarp(feat_source, last_multi_cross_offsets, + cross_att_maps, self.k, self.out_ch) + + # Self-MFE + input_feat = torch.cat([att_source_feat, att_reference_feat], 1) + offsets_att = self.Self_MFEs[i](input_feat) + self_att_maps = F.softmax(offsets_att[:, self.k * 2:, :, :], dim=1) + offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape( + -1, 2, H, W)) + if last_multi_self_offsets is not None: + offsets = F.grid_sample( + last_multi_self_offsets, + offsets, + mode='bilinear', + padding_mode='border') + else: + offsets = offsets.permute(0, 3, 1, 2) + last_multi_self_offsets = offsets + att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets, + self_att_maps, self.k, self.out_ch) + + # Refine-MFE + input_feat = torch.cat([att_source_feat, att_reference_feat], 1) + offsets_att = self.Refine_MFEs[i](input_feat) + att_maps = F.softmax(offsets_att[:, self.k * 4:, :, :], dim=1) + cross_offsets = apply_offset( + offsets_att[:, :self.k * 2, :, :].reshape(-1, 2, H, W)) + self_offsets = apply_offset( + offsets_att[:, + self.k * 2:self.k * 4, :, :].reshape(-1, 2, H, W)) + last_multi_cross_offsets = F.grid_sample( + last_multi_cross_offsets, + cross_offsets, + mode='bilinear', + padding_mode='border') + last_multi_self_offsets = F.grid_sample( + last_multi_self_offsets, + self_offsets, + mode='bilinear', + padding_mode='border') + + # Upsampling + last_multi_cross_offsets = F.interpolate( + last_multi_cross_offsets, scale_factor=2, mode='bilinear') + last_multi_self_offsets = F.interpolate( + last_multi_self_offsets, scale_factor=2, mode='bilinear') + self_att_maps = F.interpolate( + att_maps[:, :self.k, :, :], scale_factor=2, mode='bilinear') + cross_att_maps = F.interpolate( + att_maps[:, self.k:, :, :], scale_factor=2, mode='bilinear') + + # Post-DAWarp for source and reference images + if return_all: + cur_source_image = F.interpolate( + source_image, (H * 2, W * 2), mode='bilinear') + cur_reference_image = F.interpolate( + reference_image, (H * 2, W * 2), mode='bilinear') + if use_light_en_de: + cur_source_image = self.lights_encoder(cur_source_image) + cur_reference_image = self.lights_encoder( + cur_reference_image) + # the feat dim in light encoder is 64 + warp_att_source_image = DAWarp(cur_source_image, + last_multi_cross_offsets, + cross_att_maps, self.k, 64) + warp_att_reference_image = DAWarp(cur_reference_image, + last_multi_self_offsets, + self_att_maps, self.k, + 64) + result_tryon = self.lights_decoder( + warp_att_source_image + warp_att_reference_image) + else: + warp_att_source_image = DAWarp(cur_source_image, + last_multi_cross_offsets, + cross_att_maps, self.k, 3) + warp_att_reference_image = 
DAWarp(cur_reference_image,
+                                                       last_multi_self_offsets,
+                                                       self_att_maps, self.k, 3)
+                    result_tryon = warp_att_source_image + warp_att_reference_image
+                results_all.append(result_tryon)
+
+        last_multi_self_offsets = F.interpolate(
+            last_multi_self_offsets,
+            reference_image.size()[2:],
+            mode='bilinear')
+        last_multi_cross_offsets = F.interpolate(
+            last_multi_cross_offsets, source_image.size()[2:], mode='bilinear')
+        self_att_maps = F.interpolate(
+            self_att_maps, reference_image.size()[2:], mode='bilinear')
+        cross_att_maps = F.interpolate(
+            cross_att_maps, source_image.size()[2:], mode='bilinear')
+        if use_light_en_de:
+            source_image = self.lights_encoder(source_image)
+            reference_image = self.lights_encoder(reference_image)
+            warp_att_source_image = DAWarp(source_image,
+                                           last_multi_cross_offsets,
+                                           cross_att_maps, self.k, 64)
+            warp_att_reference_image = DAWarp(reference_image,
+                                              last_multi_self_offsets,
+                                              self_att_maps, self.k, 64)
+            result_tryon = self.lights_decoder(warp_att_source_image +
+                                               warp_att_reference_image)
+        else:
+            warp_att_source_image = DAWarp(source_image,
+                                           last_multi_cross_offsets,
+                                           cross_att_maps, self.k, 3)
+            warp_att_reference_image = DAWarp(reference_image,
+                                              last_multi_self_offsets,
+                                              self_att_maps, self.k, 3)
+            result_tryon = warp_att_source_image + warp_att_reference_image
+
+        if return_all:
+            # return the intermediate try-on results collected during training
+            return result_tryon, results_all
+        return result_tryon
+
+
+class SDAFNet_Tryon(nn.Module):
+
+    def __init__(self, ref_in_channel, source_in_channel=3, head_nums=6):
+        super(SDAFNet_Tryon, self).__init__()
+        num_filters = [64, 128, 256, 256, 256]
+        self.source_features = FeatureEncoder(source_in_channel, num_filters)
+        self.reference_features = FeatureEncoder(ref_in_channel, num_filters)
+        self.source_FPN = RefinePyramid(num_filters)
+        self.reference_FPN = RefinePyramid(num_filters)
+        self.dafnet = DAFlowNet(len(num_filters), head_nums=head_nums)
+
+    def forward(self,
+                ref_input,
+                source_image,
+                ref_image,
+                use_light_en_de=True,
+                return_all=False,
+                warp_feature=True):
+        reference_feats = self.reference_FPN(
+            self.reference_features(ref_input))
+        source_feats = self.source_FPN(self.source_features(source_image))
+        result = self.dafnet(
+            source_image,
+            ref_image,
+            source_feats,
+            reference_feats,
+            use_light_en_de=use_light_en_de,
+            return_all=return_all,
+            warp_feature=warp_feature)
+        return result
diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index 9794f53e..99463385 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -285,5 +285,10 @@ TASK_OUTPUTS = {
     # {
     #   "output_pcm": {"input_label" : np.ndarray with shape [D]}
     # }
-    Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM]
+    Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM],
+    # virtual_tryon result for a single sample
+    # {
+    #   "output_img": np.ndarray with shape [height, width, 3]
+    # }
+    Tasks.virtual_tryon: [OutputKeys.OUTPUT_IMG]
 }
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 6891ae08..3c23e15e 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -70,6 +70,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.text_to_image_synthesis: (Pipelines.text_to_image_synthesis,
                                     'damo/cv_imagen_text-to-image-synthesis_tiny'),
+    Tasks.virtual_tryon: (Pipelines.virtual_tryon,
+                          'damo/cv_daflow_virtual-tryon_base'),
     Tasks.image_colorization: (Pipelines.image_colorization,
                                'damo/cv_unet_image-colorization'),
     Tasks.style_transfer: (Pipelines.style_transfer,
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index 75a85da3..c1c1acdb 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -6,6 +6,7 @@ try:
     from .action_recognition_pipeline import ActionRecognitionPipeline
     from .animal_recog_pipeline import AnimalRecogPipeline
     from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
+    from .virtual_tryon_pipeline import VirtualTryonPipeline
     from .image_colorization_pipeline import ImageColorizationPipeline
     from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
     from .face_image_generation_pipeline import FaceImageGenerationPipeline
diff --git a/modelscope/pipelines/cv/virtual_tryon_pipeline.py b/modelscope/pipelines/cv/virtual_tryon_pipeline.py
new file mode 100644
index 00000000..5d849ba2
--- /dev/null
+++ b/modelscope/pipelines/cv/virtual_tryon_pipeline.py
@@ -0,0 +1,124 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os.path as osp
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generator, List, Union
+
+import cv2
+import numpy as np
+import PIL
+import torch
+from PIL import Image
+from torchvision import transforms
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.virtual_tryon.sdafnet import SDAFNet_Tryon
+from modelscope.outputs import TASK_OUTPUTS, OutputKeys
+from modelscope.pipelines.util import is_model, is_official_hub_path
+from modelscope.preprocessors import load_image
+from modelscope.utils.constant import ModelFile, Tasks
+from ..base import Pipeline
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module(
+    Tasks.virtual_tryon, module_name=Pipelines.virtual_tryon)
+class VirtualTryonPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        use `model` to create a virtual try-on pipeline for prediction
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model)
+        self.device = torch.device(
+            'cuda' if torch.cuda.is_available() else 'cpu')
+
+        def filter_param(src_params, own_state):
+            copied_keys = []
+            for name, param in src_params.items():
+                if 'module.' == name[0:7]:
+                    name = name[7:]
+                if '.module.' not in list(own_state.keys())[0]:
+                    name = name.replace('.module.', '.')
+                if (name in own_state) and (own_state[name].shape
+                                            == param.shape):
+                    own_state[name].copy_(param)
+                    copied_keys.append(name)
+
+        def load_pretrained(model, src_params):
+            if 'state_dict' in src_params:
+                src_params = src_params['state_dict']
+            own_state = model.state_dict()
+            filter_param(src_params, own_state)
+            model.load_state_dict(own_state)
+
+        self.model = SDAFNet_Tryon(ref_in_channel=6).to(self.device)
+        local_model_dir = model
+        if osp.exists(model):
+            local_model_dir = model
+        else:
+            local_model_dir = snapshot_download(model)
+        self.local_path = local_model_dir
+        src_params = torch.load(
+            osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE), 'cpu')
+        load_pretrained(self.model, src_params)
+        self.model = self.model.eval()
+        self.size = 192
+        self.test_transforms = transforms.Compose([
+            transforms.Resize(self.size, interpolation=2),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+        ])
+
+    def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        if isinstance(input['masked_model'], str):
+            img_agnostic = load_image(input['masked_model'])
+            pose = load_image(input['pose'])
+            cloth_img = load_image(input['cloth'])
+        elif isinstance(input['masked_model'], PIL.Image.Image):
+            img_agnostic = input['masked_model'].convert('RGB')
+            pose = input['pose'].convert('RGB')
+            cloth_img = input['cloth'].convert('RGB')
+        elif isinstance(input['masked_model'], np.ndarray):
+            img_agnostic = input['masked_model']
+            pose = input['pose']
+            cloth_img = input['cloth']
+            if len(img_agnostic.shape) == 2:
+                img_agnostic = cv2.cvtColor(img_agnostic, cv2.COLOR_GRAY2BGR)
+                pose = cv2.cvtColor(pose, cv2.COLOR_GRAY2BGR)
+                cloth_img = cv2.cvtColor(cloth_img, cv2.COLOR_GRAY2BGR)
+            img_agnostic = Image.fromarray(
+                img_agnostic[:, :, ::-1].astype('uint8')).convert('RGB')
+            pose = Image.fromarray(
+                pose[:, :, ::-1].astype('uint8')).convert('RGB')
+            cloth_img = Image.fromarray(
+                cloth_img[:, :, ::-1].astype('uint8')).convert('RGB')
+        else:
+            raise TypeError(f'input should be either str, PIL.Image or'
+                            f' np.ndarray, but got '
+                            f'{type(input["masked_model"])}')
+
+        img_agnostic = self.test_transforms(img_agnostic)
+        pose = self.test_transforms(pose)
+        cloth_img = self.test_transforms(cloth_img)
+        inputs = {
+            'masked_model': img_agnostic.unsqueeze(0),
+            'pose': pose.unsqueeze(0),
+            'cloth': cloth_img.unsqueeze(0)
+        }
+        return inputs
+
+    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+
+        img_agnostic = inputs['masked_model'].to(self.device)
+        pose = inputs['pose'].to(self.device)
+        cloth_img = inputs['cloth'].to(self.device)
+        ref_input = torch.cat((pose, img_agnostic), dim=1)
+        tryon_result = self.model(ref_input, cloth_img, img_agnostic)
+        return {OutputKeys.OUTPUT_IMG: tryon_result}
+
+    def postprocess(self, outputs: Dict[str, Any]) -> Dict[str, Any]:
+        tryon_result = outputs[OutputKeys.OUTPUT_IMG].permute(0, 2, 3,
+                                                              1).squeeze(0)
+        tryon_result = tryon_result.add(1.).div(2.).mul(255).data.cpu().numpy()
+        outputs[OutputKeys.OUTPUT_IMG] = tryon_result
+        return outputs
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 4d43c3b8..977160d9 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -27,6 +27,7 @@ class CVTasks(object):
     ocr_detection = 'ocr-detection'
     action_recognition = 'action-recognition'
     video_embedding = 'video-embedding'
+    virtual_tryon = 'virtual-tryon'
     image_colorization = 'image-colorization'
     face_image_generation = 'face-image-generation'
     image_super_resolution = 'image-super-resolution'
diff --git a/tests/pipelines/test_virtual_tryon.py 
b/tests/pipelines/test_virtual_tryon.py new file mode 100644 index 00000000..324dc070 --- /dev/null +++ b/tests/pipelines/test_virtual_tryon.py @@ -0,0 +1,36 @@ +import sys +import unittest + +import cv2 +import numpy as np + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class VirtualTryonTest(unittest.TestCase): + model_id = 'damo/cv_daflow_virtual-tryon_base' + input_imgs = { + 'masked_model': 'data/test/images/virtual_tryon_model.jpg', + 'pose': 'data/test/images/virtual_tryon_pose.jpg', + 'cloth': 'data/test/images/virtual_tryon_cloth.jpg' + } + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_virtual_tryon = pipeline( + task=Tasks.virtual_tryon, model=self.model_id) + img = pipeline_virtual_tryon(self.input_imgs)[OutputKeys.OUTPUT_IMG] + cv2.imwrite('demo.jpg', img[:, :, ::-1]) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_name_default_model(self): + pipeline_virtual_tryon = pipeline(task=Tasks.virtual_tryon) + img = pipeline_virtual_tryon(self.input_imgs)[OutputKeys.OUTPUT_IMG] + cv2.imwrite('demo.jpg', img[:, :, ::-1]) + + +if __name__ == '__main__': + unittest.main()
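
The test above drives the pipeline with image file paths; `VirtualTryonPipeline.preprocess` also accepts PIL images and numpy arrays. A minimal sketch of the array path, reusing the demo images and the default model id from this patch (the output filename is arbitrary; the BGR/RGB flips follow the OpenCV convention that `preprocess` assumes for array inputs):

    import cv2

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    tryon = pipeline(
        task=Tasks.virtual_tryon, model='damo/cv_daflow_virtual-tryon_base')

    # cv2.imread returns BGR uint8 arrays, which is what preprocess expects
    inputs = {
        'masked_model': cv2.imread('data/test/images/virtual_tryon_model.jpg'),
        'pose': cv2.imread('data/test/images/virtual_tryon_pose.jpg'),
        'cloth': cv2.imread('data/test/images/virtual_tryon_cloth.jpg')
    }
    img = tryon(inputs)[OutputKeys.OUTPUT_IMG]  # RGB, HxWx3, roughly in [0, 255]
    cv2.imwrite('tryon_result.jpg', img[:, :, ::-1])  # flip back to BGR for OpenCV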
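
For a quick look at the network's input contract without downloading weights, a rough tensor-level sketch; the 256x192 resolution and random inputs are illustrative assumptions only, while the 6-channel reference input mirrors the pose + masked-model concatenation done in `VirtualTryonPipeline.forward`:

    import torch

    from modelscope.models.cv.virtual_tryon.sdafnet import SDAFNet_Tryon

    # ref_in_channel=6: 3 channels of pose rendering + 3 of the masked model image
    net = SDAFNet_Tryon(ref_in_channel=6).eval()

    pose = torch.randn(1, 3, 256, 192)          # pose / skeleton image
    img_agnostic = torch.randn(1, 3, 256, 192)  # clothing-agnostic model image
    cloth = torch.randn(1, 3, 256, 192)         # in-shop garment image

    ref_input = torch.cat((pose, img_agnostic), dim=1)
    with torch.no_grad():
        result = net(ref_input, cloth, img_agnostic)  # source=cloth, reference=model

    print(result.shape)  # torch.Size([1, 3, 256, 192]), same spatial size as the reference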