Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9526987master
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:7e4cbf844cd16a892a7d2f2764b1537c346675d3b0145016d6836441ba907366 | |||
| size 9195 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:33b3d3076e191fa92511bf69fa76e1222b3b3be0049e711c948a1218b587510c | |||
| size 4805 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:99c2b02a927b86ff194287ea4c5a05349dd800cff2b523212d1dad378c252feb | |||
| size 103334 | |||
| @@ -77,6 +77,7 @@ class Pipelines(object): | |||
| face_image_generation = 'gan-face-image-generation' | |||
| style_transfer = 'AAMS-style-transfer' | |||
| image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' | |||
| image2image_translation = 'image-to-image-translation' | |||
| live_category = 'live-category' | |||
| video_category = 'video-category' | |||
| @@ -0,0 +1 @@ | |||
| from .transforms import * # noqa F403 | |||
| @@ -0,0 +1,121 @@ | |||
| import math | |||
| import random | |||
| import torchvision.transforms.functional as TF | |||
| from PIL import Image, ImageFilter | |||
| __all__ = [ | |||
| 'Identity', 'PadToSquare', 'RandomScale', 'RandomRotate', | |||
| 'RandomGaussianBlur', 'RandomCrop' | |||
| ] | |||
| class Identity(object): | |||
| def __call__(self, *args): | |||
| if len(args) == 0: | |||
| return None | |||
| elif len(args) == 1: | |||
| return args[0] | |||
| else: | |||
| return args | |||
| class PadToSquare(object): | |||
| def __init__(self, fill=(255, 255, 255)): | |||
| self.fill = fill | |||
| def __call__(self, img): | |||
| w, h = img.size | |||
| if w != h: | |||
| if w > h: | |||
| t = (w - h) // 2 | |||
| b = w - h - t | |||
| padding = (0, t, 0, b) | |||
| else: | |||
| left = (h - w) // 2 | |||
| right = h - w - l | |||
| padding = (left, 0, right, 0) | |||
| img = TF.pad(img, padding, fill=self.fill) | |||
| return img | |||
| class RandomScale(object): | |||
| def __init__(self, | |||
| min_scale=0.5, | |||
| max_scale=2.0, | |||
| min_ratio=0.8, | |||
| max_ratio=1.25): | |||
| self.min_scale = min_scale | |||
| self.max_scale = max_scale | |||
| self.min_ratio = min_ratio | |||
| self.max_ratio = max_ratio | |||
| def __call__(self, img): | |||
| w, h = img.size | |||
| scale = 2**random.uniform( | |||
| math.log2(self.min_scale), math.log2(self.max_scale)) | |||
| ratio = 2**random.uniform( | |||
| math.log2(self.min_ratio), math.log2(self.max_ratio)) | |||
| ow = int(w * scale * math.sqrt(ratio)) | |||
| oh = int(h * scale / math.sqrt(ratio)) | |||
| img = img.resize((ow, oh), Image.BILINEAR) | |||
| return img | |||
| class RandomRotate(object): | |||
| def __init__(self, | |||
| min_angle=-10.0, | |||
| max_angle=10.0, | |||
| padding=(255, 255, 255), | |||
| p=0.5): | |||
| self.min_angle = min_angle | |||
| self.max_angle = max_angle | |||
| self.padding = padding | |||
| self.p = p | |||
| def __call__(self, img): | |||
| if random.random() < self.p: | |||
| angle = random.uniform(self.min_angle, self.max_angle) | |||
| img = img.rotate(angle, Image.BILINEAR, fillcolor=self.padding) | |||
| return img | |||
| class RandomGaussianBlur(object): | |||
| def __init__(self, radius=5, p=0.5): | |||
| self.radius = radius | |||
| self.p = p | |||
| def __call__(self, img): | |||
| if random.random() < self.p: | |||
| img = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) | |||
| return img | |||
| class RandomCrop(object): | |||
| def __init__(self, size, padding=(255, 255, 255)): | |||
| self.size = size | |||
| self.padding = padding | |||
| def __call__(self, img): | |||
| # pad | |||
| w, h = img.size | |||
| pad_w = max(0, self.size - w) | |||
| pad_h = max(0, self.size - h) | |||
| if pad_w > 0 or pad_h > 0: | |||
| half_w = pad_w // 2 | |||
| half_h = pad_h // 2 | |||
| pad = (half_w, half_h, pad_w - half_w, pad_h - half_h) | |||
| img = TF.pad(img, pad, fill=self.padding) | |||
| # crop | |||
| w, h = img.size | |||
| x1 = random.randint(0, w - self.size) | |||
| y1 = random.randint(0, h - self.size) | |||
| img = img.crop((x1, y1, x1 + self.size, y1 + self.size)) | |||
| return img | |||
| @@ -0,0 +1,323 @@ | |||
| import math | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| __all__ = ['UNet'] | |||
| def sinusoidal_embedding(timesteps, dim): | |||
| # check input | |||
| half = dim // 2 | |||
| timesteps = timesteps.float() | |||
| # compute sinusoidal embedding | |||
| sinusoid = torch.outer( | |||
| timesteps, torch.pow(10000, | |||
| -torch.arange(half).to(timesteps).div(half))) | |||
| x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) | |||
| if dim % 2 != 0: | |||
| x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) | |||
| return x | |||
| class Resample(nn.Module): | |||
| def __init__(self, scale_factor=1.0): | |||
| assert scale_factor in [0.5, 1.0, 2.0] | |||
| super(Resample, self).__init__() | |||
| self.scale_factor = scale_factor | |||
| def forward(self, x): | |||
| if self.scale_factor == 2.0: | |||
| x = F.interpolate(x, scale_factor=2, mode='nearest') | |||
| elif self.scale_factor == 0.5: | |||
| x = F.avg_pool2d(x, kernel_size=2, stride=2) | |||
| return x | |||
| class ResidualBlock(nn.Module): | |||
| def __init__(self, in_dim, embed_dim, out_dim, dropout=0.0): | |||
| super(ResidualBlock, self).__init__() | |||
| self.in_dim = in_dim | |||
| self.embed_dim = embed_dim | |||
| self.out_dim = out_dim | |||
| # layers | |||
| self.layer1 = nn.Sequential( | |||
| nn.GroupNorm(32, in_dim), nn.SiLU(), | |||
| nn.Conv2d(in_dim, out_dim, 3, padding=1)) | |||
| self.embedding = nn.Sequential(nn.SiLU(), | |||
| nn.Linear(embed_dim, out_dim)) | |||
| self.layer2 = nn.Sequential( | |||
| nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), | |||
| nn.Conv2d(out_dim, out_dim, 3, padding=1)) | |||
| self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( | |||
| in_dim, out_dim, 1) | |||
| # zero out the last layer params | |||
| nn.init.zeros_(self.layer2[-1].weight) | |||
| def forward(self, x, y): | |||
| identity = x | |||
| x = self.layer1(x) | |||
| x = x + self.embedding(y).unsqueeze(-1).unsqueeze(-1) | |||
| x = self.layer2(x) | |||
| x = x + self.shortcut(identity) | |||
| return x | |||
| class MultiHeadAttention(nn.Module): | |||
| def __init__(self, dim, context_dim=None, num_heads=8, dropout=0.0): | |||
| assert dim % num_heads == 0 | |||
| assert context_dim is None or context_dim % num_heads == 0 | |||
| context_dim = context_dim or dim | |||
| super(MultiHeadAttention, self).__init__() | |||
| self.dim = dim | |||
| self.context_dim = context_dim | |||
| self.num_heads = num_heads | |||
| self.head_dim = dim // num_heads | |||
| self.scale = math.pow(self.head_dim, -0.25) | |||
| # layers | |||
| self.q = nn.Linear(dim, dim, bias=False) | |||
| self.k = nn.Linear(context_dim, dim, bias=False) | |||
| self.v = nn.Linear(context_dim, dim, bias=False) | |||
| self.o = nn.Linear(dim, dim) | |||
| self.dropout = nn.Dropout(dropout) | |||
| def forward(self, x, context=None): | |||
| # check inputs | |||
| context = x if context is None else context | |||
| b, n, c = x.size(0), self.num_heads, self.head_dim | |||
| # compute query, key, value | |||
| q = self.q(x).view(b, -1, n, c) | |||
| k = self.k(context).view(b, -1, n, c) | |||
| v = self.v(context).view(b, -1, n, c) | |||
| # compute attention | |||
| attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale) | |||
| attn = F.softmax(attn, dim=-1) | |||
| attn = self.dropout(attn) | |||
| # gather context | |||
| x = torch.einsum('bnij,bjnc->binc', attn, v) | |||
| x = x.reshape(b, -1, n * c) | |||
| # output | |||
| x = self.o(x) | |||
| x = self.dropout(x) | |||
| return x | |||
| class GLU(nn.Module): | |||
| def __init__(self, in_dim, out_dim): | |||
| super(GLU, self).__init__() | |||
| self.in_dim = in_dim | |||
| self.out_dim = out_dim | |||
| self.proj = nn.Linear(in_dim, out_dim * 2) | |||
| def forward(self, x): | |||
| x, gate = self.proj(x).chunk(2, dim=-1) | |||
| return x * F.gelu(gate) | |||
| class TransformerBlock(nn.Module): | |||
| def __init__(self, dim, context_dim, num_heads, dropout=0.0): | |||
| super(TransformerBlock, self).__init__() | |||
| self.dim = dim | |||
| self.context_dim = context_dim | |||
| self.num_heads = num_heads | |||
| self.head_dim = dim // num_heads | |||
| # input | |||
| self.norm1 = nn.GroupNorm(32, dim, eps=1e-6, affine=True) | |||
| self.conv1 = nn.Conv2d(dim, dim, 1) | |||
| # self attention | |||
| self.norm2 = nn.LayerNorm(dim) | |||
| self.self_attn = MultiHeadAttention(dim, None, num_heads, dropout) | |||
| # cross attention | |||
| self.norm3 = nn.LayerNorm(dim) | |||
| self.cross_attn = MultiHeadAttention(dim, context_dim, num_heads, | |||
| dropout) | |||
| # ffn | |||
| self.norm4 = nn.LayerNorm(dim) | |||
| self.ffn = nn.Sequential( | |||
| GLU(dim, dim * 4), nn.Dropout(dropout), nn.Linear(dim * 4, dim)) | |||
| # output | |||
| self.conv2 = nn.Conv2d(dim, dim, 1) | |||
| # zero out the last layer params | |||
| nn.init.zeros_(self.conv2.weight) | |||
| def forward(self, x, context): | |||
| b, c, h, w = x.size() | |||
| identity = x | |||
| # input | |||
| x = self.norm1(x) | |||
| x = self.conv1(x).view(b, c, -1).transpose(1, 2) | |||
| # attention | |||
| x = x + self.self_attn(self.norm2(x)) | |||
| x = x + self.cross_attn(self.norm3(x), context) | |||
| x = x + self.ffn(self.norm4(x)) | |||
| # output | |||
| x = x.transpose(1, 2).view(b, c, h, w) | |||
| x = self.conv2(x) | |||
| return x + identity | |||
| class UNet(nn.Module): | |||
| def __init__(self, | |||
| resolution=64, | |||
| in_dim=3, | |||
| dim=192, | |||
| context_dim=512, | |||
| out_dim=3, | |||
| dim_mult=[1, 2, 3, 5], | |||
| num_heads=1, | |||
| head_dim=None, | |||
| num_res_blocks=2, | |||
| attn_scales=[1 / 2, 1 / 4, 1 / 8], | |||
| num_classes=1001, | |||
| dropout=0.0): | |||
| embed_dim = dim * 4 | |||
| super(UNet, self).__init__() | |||
| self.resolution = resolution | |||
| self.in_dim = in_dim | |||
| self.dim = dim | |||
| self.context_dim = context_dim | |||
| self.out_dim = out_dim | |||
| self.dim_mult = dim_mult | |||
| self.num_heads = num_heads | |||
| self.head_dim = head_dim | |||
| self.num_res_blocks = num_res_blocks | |||
| self.attn_scales = attn_scales | |||
| self.num_classes = num_classes | |||
| # params | |||
| enc_dims = [dim * u for u in [1] + dim_mult] | |||
| dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] | |||
| shortcut_dims = [] | |||
| scale = 1.0 | |||
| # embeddings | |||
| self.time_embedding = nn.Sequential( | |||
| nn.Linear(dim, embed_dim), nn.SiLU(), | |||
| nn.Linear(embed_dim, embed_dim)) | |||
| self.label_embedding = nn.Embedding(num_classes, context_dim) | |||
| # encoder | |||
| self.encoder = nn.ModuleList( | |||
| [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) | |||
| shortcut_dims.append(dim) | |||
| for i, (in_dim, | |||
| out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): | |||
| for j in range(num_res_blocks): | |||
| # residual (+attention) blocks | |||
| block = nn.ModuleList( | |||
| [ResidualBlock(in_dim, embed_dim, out_dim, dropout)]) | |||
| if scale in attn_scales: | |||
| block.append( | |||
| TransformerBlock(out_dim, context_dim, num_heads)) | |||
| in_dim = out_dim | |||
| self.encoder.append(block) | |||
| shortcut_dims.append(out_dim) | |||
| # downsample | |||
| if i != len(dim_mult) - 1 and j == num_res_blocks - 1: | |||
| self.encoder.append( | |||
| nn.Conv2d(out_dim, out_dim, 3, stride=2, padding=1)) | |||
| shortcut_dims.append(out_dim) | |||
| scale /= 2.0 | |||
| # middle | |||
| self.middle = nn.ModuleList([ | |||
| ResidualBlock(out_dim, embed_dim, out_dim, dropout), | |||
| TransformerBlock(out_dim, context_dim, num_heads), | |||
| ResidualBlock(out_dim, embed_dim, out_dim, dropout) | |||
| ]) | |||
| # decoder | |||
| self.decoder = nn.ModuleList() | |||
| for i, (in_dim, | |||
| out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): | |||
| for j in range(num_res_blocks + 1): | |||
| # residual (+attention) blocks | |||
| block = nn.ModuleList([ | |||
| ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, | |||
| out_dim, dropout) | |||
| ]) | |||
| if scale in attn_scales: | |||
| block.append( | |||
| TransformerBlock(out_dim, context_dim, num_heads, | |||
| dropout)) | |||
| in_dim = out_dim | |||
| # upsample | |||
| if i != len(dim_mult) - 1 and j == num_res_blocks: | |||
| block.append( | |||
| nn.Sequential( | |||
| Resample(scale_factor=2.0), | |||
| nn.Conv2d(out_dim, out_dim, 3, padding=1))) | |||
| scale *= 2.0 | |||
| self.decoder.append(block) | |||
| # head | |||
| self.head = nn.Sequential( | |||
| nn.GroupNorm(32, out_dim), nn.SiLU(), | |||
| nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) | |||
| # zero out the last layer params | |||
| nn.init.zeros_(self.head[-1].weight) | |||
| def forward(self, x, t, y, concat=None): | |||
| # embeddings | |||
| if concat is not None: | |||
| x = torch.cat([x, concat], dim=1) | |||
| t = self.time_embedding(sinusoidal_embedding(t, self.dim)) | |||
| y = self.label_embedding(y) | |||
| # encoder | |||
| xs = [] | |||
| for block in self.encoder: | |||
| x = self._forward_single(block, x, t, y) | |||
| xs.append(x) | |||
| # middle | |||
| for block in self.middle: | |||
| x = self._forward_single(block, x, t, y) | |||
| # decoder | |||
| for block in self.decoder: | |||
| x = torch.cat([x, xs.pop()], dim=1) | |||
| x = self._forward_single(block, x, t, y) | |||
| # head | |||
| x = self.head(x) | |||
| return x | |||
| def _forward_single(self, module, x, t, y): | |||
| if isinstance(module, ResidualBlock): | |||
| x = module(x, t) | |||
| elif isinstance(module, TransformerBlock): | |||
| x = module(x, y) | |||
| elif isinstance(module, nn.ModuleList): | |||
| for block in module: | |||
| x = self._forward_single(block, x, t, y) | |||
| else: | |||
| x = module(x) | |||
| return x | |||
| @@ -0,0 +1,2 @@ | |||
| from .autoencoder import * # noqa F403 | |||
| from .clip import * # noqa F403 | |||
| @@ -0,0 +1,412 @@ | |||
| import math | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| __all__ = ['VQAutoencoder', 'KLAutoencoder', 'PatchDiscriminator'] | |||
| def group_norm(dim): | |||
| return nn.GroupNorm(32, dim, eps=1e-6, affine=True) | |||
| class Resample(nn.Module): | |||
| def __init__(self, dim, scale_factor): | |||
| super(Resample, self).__init__() | |||
| self.dim = dim | |||
| self.scale_factor = scale_factor | |||
| # layers | |||
| if scale_factor == 2.0: | |||
| self.resample = nn.Sequential( | |||
| nn.Upsample(scale_factor=scale_factor, mode='nearest'), | |||
| nn.Conv2d(dim, dim, 3, padding=1)) | |||
| elif scale_factor == 0.5: | |||
| self.resample = nn.Sequential( | |||
| nn.ZeroPad2d((0, 1, 0, 1)), | |||
| nn.Conv2d(dim, dim, 3, stride=2, padding=0)) | |||
| else: | |||
| self.resample = nn.Identity() | |||
| def forward(self, x): | |||
| return self.resample(x) | |||
| class ResidualBlock(nn.Module): | |||
| def __init__(self, in_dim, out_dim, dropout=0.0): | |||
| super(ResidualBlock, self).__init__() | |||
| self.in_dim = in_dim | |||
| self.out_dim = out_dim | |||
| # layers | |||
| self.residual = nn.Sequential( | |||
| group_norm(in_dim), nn.SiLU(), | |||
| nn.Conv2d(in_dim, out_dim, 3, padding=1), group_norm(out_dim), | |||
| nn.SiLU(), nn.Dropout(dropout), | |||
| nn.Conv2d(out_dim, out_dim, 3, padding=1)) | |||
| self.shortcut = nn.Conv2d(in_dim, out_dim, | |||
| 1) if in_dim != out_dim else nn.Identity() | |||
| # zero out the last layer params | |||
| nn.init.zeros_(self.residual[-1].weight) | |||
| def forward(self, x): | |||
| return self.residual(x) + self.shortcut(x) | |||
| class AttentionBlock(nn.Module): | |||
| def __init__(self, dim): | |||
| super(AttentionBlock, self).__init__() | |||
| self.dim = dim | |||
| self.scale = math.pow(dim, -0.25) | |||
| # layers | |||
| self.norm = group_norm(dim) | |||
| self.to_qkv = nn.Conv2d(dim, dim * 3, 1) | |||
| self.proj = nn.Conv2d(dim, dim, 1) | |||
| # zero out the last layer params | |||
| nn.init.zeros_(self.proj.weight) | |||
| def forward(self, x): | |||
| identity = x | |||
| b, c, h, w = x.size() | |||
| # compute query, key, value | |||
| x = self.norm(x) | |||
| q, k, v = self.to_qkv(x).view(b, c * 3, -1).chunk(3, dim=1) | |||
| # compute attention | |||
| attn = torch.einsum('bci,bcj->bij', q * self.scale, k * self.scale) | |||
| attn = F.softmax(attn, dim=-1) | |||
| # gather context | |||
| x = torch.einsum('bij,bcj->bci', attn, v) | |||
| x = x.reshape(b, c, h, w) | |||
| # output | |||
| x = self.proj(x) | |||
| return x + identity | |||
| class Encoder(nn.Module): | |||
| def __init__(self, | |||
| dim=128, | |||
| z_dim=3, | |||
| dim_mult=[1, 2, 4], | |||
| num_res_blocks=2, | |||
| attn_scales=[], | |||
| dropout=0.0): | |||
| super(Encoder, self).__init__() | |||
| self.dim = dim | |||
| self.z_dim = z_dim | |||
| self.dim_mult = dim_mult | |||
| self.num_res_blocks = num_res_blocks | |||
| self.attn_scales = attn_scales | |||
| # params | |||
| dims = [dim * u for u in [1] + dim_mult] | |||
| scale = 1.0 | |||
| # init block | |||
| self.conv1 = nn.Conv2d(3, dims[0], 3, padding=1) | |||
| # downsample blocks | |||
| downsamples = [] | |||
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |||
| # residual (+attention) blocks | |||
| for _ in range(num_res_blocks): | |||
| downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) | |||
| if scale in attn_scales: | |||
| downsamples.append(AttentionBlock(out_dim)) | |||
| in_dim = out_dim | |||
| # downsample block | |||
| if i != len(dim_mult) - 1: | |||
| downsamples.append(Resample(out_dim, scale_factor=0.5)) | |||
| scale /= 2.0 | |||
| self.downsamples = nn.Sequential(*downsamples) | |||
| # middle blocks | |||
| self.middle = nn.Sequential( | |||
| ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), | |||
| ResidualBlock(out_dim, out_dim, dropout)) | |||
| # output blocks | |||
| self.head = nn.Sequential( | |||
| group_norm(out_dim), nn.SiLU(), | |||
| nn.Conv2d(out_dim, z_dim, 3, padding=1)) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = self.downsamples(x) | |||
| x = self.middle(x) | |||
| x = self.head(x) | |||
| return x | |||
| class Decoder(nn.Module): | |||
| def __init__(self, | |||
| dim=128, | |||
| z_dim=3, | |||
| dim_mult=[1, 2, 4], | |||
| num_res_blocks=2, | |||
| attn_scales=[], | |||
| dropout=0.0): | |||
| super(Decoder, self).__init__() | |||
| self.dim = dim | |||
| self.z_dim = z_dim | |||
| self.dim_mult = dim_mult | |||
| self.num_res_blocks = num_res_blocks | |||
| self.attn_scales = attn_scales | |||
| # params | |||
| dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] | |||
| scale = 1.0 / 2**(len(dim_mult) - 2) | |||
| # init block | |||
| self.conv1 = nn.Conv2d(z_dim, dims[0], 3, padding=1) | |||
| # middle blocks | |||
| self.middle = nn.Sequential( | |||
| ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), | |||
| ResidualBlock(dims[0], dims[0], dropout)) | |||
| # upsample blocks | |||
| upsamples = [] | |||
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |||
| # residual (+attention) blocks | |||
| for _ in range(num_res_blocks + 1): | |||
| upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) | |||
| if scale in attn_scales: | |||
| upsamples.append(AttentionBlock(out_dim)) | |||
| in_dim = out_dim | |||
| # upsample block | |||
| if i != len(dim_mult) - 1: | |||
| upsamples.append(Resample(out_dim, scale_factor=2.0)) | |||
| scale *= 2.0 | |||
| self.upsamples = nn.Sequential(*upsamples) | |||
| # output blocks | |||
| self.head = nn.Sequential( | |||
| group_norm(out_dim), nn.SiLU(), | |||
| nn.Conv2d(out_dim, 3, 3, padding=1)) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = self.middle(x) | |||
| x = self.upsamples(x) | |||
| x = self.head(x) | |||
| return x | |||
| class VectorQuantizer(nn.Module): | |||
| def __init__(self, codebook_size=8192, z_dim=3, beta=0.25): | |||
| super(VectorQuantizer, self).__init__() | |||
| self.codebook_size = codebook_size | |||
| self.z_dim = z_dim | |||
| self.beta = beta | |||
| # init codebook | |||
| eps = math.sqrt(1.0 / codebook_size) | |||
| self.codebook = nn.Parameter( | |||
| torch.empty(codebook_size, z_dim).uniform_(-eps, eps)) | |||
| def forward(self, z): | |||
| # preprocess | |||
| b, c, h, w = z.size() | |||
| flatten = z.permute(0, 2, 3, 1).reshape(-1, c) | |||
| # quantization | |||
| with torch.no_grad(): | |||
| tokens = torch.cdist(flatten, self.codebook).argmin(dim=1) | |||
| quantized = F.embedding(tokens, | |||
| self.codebook).view(b, h, w, | |||
| c).permute(0, 3, 1, 2) | |||
| # compute loss | |||
| codebook_loss = F.mse_loss(quantized, z.detach()) | |||
| commitment_loss = F.mse_loss(quantized.detach(), z) | |||
| loss = codebook_loss + self.beta * commitment_loss | |||
| # perplexity | |||
| counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype) | |||
| # dist.all_reduce(counts) | |||
| p = counts / counts.sum() | |||
| perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10))) | |||
| # postprocess | |||
| tokens = tokens.view(b, h, w) | |||
| quantized = z + (quantized - z).detach() | |||
| return quantized, tokens, loss, perplexity | |||
| class VQAutoencoder(nn.Module): | |||
| def __init__(self, | |||
| dim=128, | |||
| z_dim=3, | |||
| dim_mult=[1, 2, 4], | |||
| num_res_blocks=2, | |||
| attn_scales=[], | |||
| dropout=0.0, | |||
| codebook_size=8192, | |||
| beta=0.25): | |||
| super(VQAutoencoder, self).__init__() | |||
| self.dim = dim | |||
| self.z_dim = z_dim | |||
| self.dim_mult = dim_mult | |||
| self.num_res_blocks = num_res_blocks | |||
| self.attn_scales = attn_scales | |||
| self.codebook_size = codebook_size | |||
| self.beta = beta | |||
| # blocks | |||
| self.encoder = Encoder(dim, z_dim, dim_mult, num_res_blocks, | |||
| attn_scales, dropout) | |||
| self.conv1 = nn.Conv2d(z_dim, z_dim, 1) | |||
| self.quantizer = VectorQuantizer(codebook_size, z_dim, beta) | |||
| self.conv2 = nn.Conv2d(z_dim, z_dim, 1) | |||
| self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, | |||
| attn_scales, dropout) | |||
| def forward(self, x): | |||
| z = self.encoder(x) | |||
| z = self.conv1(z) | |||
| z, tokens, loss, perplexity = self.quantizer(z) | |||
| z = self.conv2(z) | |||
| x = self.decoder(z) | |||
| return x, tokens, loss, perplexity | |||
| def encode(self, imgs): | |||
| z = self.encoder(imgs) | |||
| z = self.conv1(z) | |||
| return z | |||
| def decode(self, z): | |||
| r"""Absort the quantizer in the decoder. | |||
| """ | |||
| z = self.quantizer(z)[0] | |||
| z = self.conv2(z) | |||
| imgs = self.decoder(z) | |||
| return imgs | |||
| @torch.no_grad() | |||
| def encode_to_tokens(self, imgs): | |||
| # preprocess | |||
| z = self.encoder(imgs) | |||
| z = self.conv1(z) | |||
| # quantization | |||
| b, c, h, w = z.size() | |||
| flatten = z.permute(0, 2, 3, 1).reshape(-1, c) | |||
| tokens = torch.cdist(flatten, self.quantizer.codebook).argmin(dim=1) | |||
| return tokens.view(b, -1) | |||
| @torch.no_grad() | |||
| def decode_from_tokens(self, tokens): | |||
| # dequantization | |||
| z = F.embedding(tokens, self.quantizer.codebook) | |||
| # postprocess | |||
| b, l, c = z.size() | |||
| h = w = int(math.sqrt(l)) | |||
| z = z.view(b, h, w, c).permute(0, 3, 1, 2) | |||
| z = self.conv2(z) | |||
| imgs = self.decoder(z) | |||
| return imgs | |||
| class KLAutoencoder(nn.Module): | |||
| def __init__(self, | |||
| dim=128, | |||
| z_dim=4, | |||
| dim_mult=[1, 2, 4, 4], | |||
| num_res_blocks=2, | |||
| attn_scales=[], | |||
| dropout=0.0): | |||
| super(KLAutoencoder, self).__init__() | |||
| self.dim = dim | |||
| self.z_dim = z_dim | |||
| self.dim_mult = dim_mult | |||
| self.num_res_blocks = num_res_blocks | |||
| self.attn_scales = attn_scales | |||
| # blocks | |||
| self.encoder = Encoder(dim, z_dim * 2, dim_mult, num_res_blocks, | |||
| attn_scales, dropout) | |||
| self.conv1 = nn.Conv2d(z_dim * 2, z_dim * 2, 1) | |||
| self.conv2 = nn.Conv2d(z_dim, z_dim, 1) | |||
| self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, | |||
| attn_scales, dropout) | |||
| def forward(self, x): | |||
| mu, log_var = self.encode(x) | |||
| z = self.reparameterize(mu, log_var) | |||
| x = self.decode(z) | |||
| return x, mu, log_var | |||
| def encode(self, x): | |||
| x = self.encoder(x) | |||
| mu, log_var = self.conv1(x).chunk(2, dim=1) | |||
| return mu, log_var | |||
| def decode(self, z): | |||
| x = self.conv2(z) | |||
| x = self.decoder(x) | |||
| return x | |||
| def reparameterize(self, mu, log_var): | |||
| std = torch.exp(0.5 * log_var) | |||
| eps = torch.randn_like(std) | |||
| return eps * std + mu | |||
| class PatchDiscriminator(nn.Module): | |||
| def __init__(self, in_dim=3, dim=64, num_layers=3): | |||
| super(PatchDiscriminator, self).__init__() | |||
| self.in_dim = in_dim | |||
| self.dim = dim | |||
| self.num_layers = num_layers | |||
| # params | |||
| dims = [dim * min(8, 2**u) for u in range(num_layers + 1)] | |||
| # layers | |||
| layers = [ | |||
| nn.Conv2d(in_dim, dim, 4, stride=2, padding=1), | |||
| nn.LeakyReLU(0.2) | |||
| ] | |||
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |||
| stride = 1 if i == num_layers - 1 else 2 | |||
| layers += [ | |||
| nn.Conv2d( | |||
| in_dim, out_dim, 4, stride=stride, padding=1, bias=False), | |||
| nn.BatchNorm2d(out_dim), | |||
| nn.LeakyReLU(0.2) | |||
| ] | |||
| layers += [nn.Conv2d(out_dim, 1, 4, stride=1, padding=1)] | |||
| self.layers = nn.Sequential(*layers) | |||
| # initialize weights | |||
| self.apply(self.init_weights) | |||
| def forward(self, x): | |||
| return self.layers(x) | |||
| def init_weights(self, m): | |||
| if isinstance(m, nn.Conv2d): | |||
| nn.init.normal_(m.weight, 0.0, 0.02) | |||
| elif isinstance(m, nn.BatchNorm2d): | |||
| nn.init.normal_(m.weight, 1.0, 0.02) | |||
| nn.init.zeros_(m.bias) | |||
| @@ -0,0 +1,418 @@ | |||
| import math | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import modelscope.models.cv.image_to_image_translation.ops as ops # for using differentiable all_gather | |||
| __all__ = [ | |||
| 'CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14', | |||
| 'clip_vit_l_14_336px', 'clip_vit_h_16' | |||
| ] | |||
| def to_fp16(m): | |||
| if isinstance(m, (nn.Linear, nn.Conv2d)): | |||
| m.weight.data = m.weight.data.half() | |||
| if m.bias is not None: | |||
| m.bias.data = m.bias.data.half() | |||
| elif hasattr(m, 'head'): | |||
| p = getattr(m, 'head') | |||
| p.data = p.data.half() | |||
| class QuickGELU(nn.Module): | |||
| def forward(self, x): | |||
| return x * torch.sigmoid(1.702 * x) | |||
| class LayerNorm(nn.LayerNorm): | |||
| r"""Subclass of nn.LayerNorm to handle fp16. | |||
| """ | |||
| def forward(self, x): | |||
| return super(LayerNorm, self).forward(x.float()).type_as(x) | |||
| class SelfAttention(nn.Module): | |||
| def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): | |||
| assert dim % num_heads == 0 | |||
| super(SelfAttention, self).__init__() | |||
| self.dim = dim | |||
| self.num_heads = num_heads | |||
| self.head_dim = dim // num_heads | |||
| self.scale = 1.0 / math.sqrt(self.head_dim) | |||
| # layers | |||
| self.to_qkv = nn.Linear(dim, dim * 3) | |||
| self.attn_dropout = nn.Dropout(attn_dropout) | |||
| self.proj = nn.Linear(dim, dim) | |||
| self.proj_dropout = nn.Dropout(proj_dropout) | |||
| def forward(self, x, mask=None): | |||
| r"""x: [B, L, C]. | |||
| mask: [*, L, L]. | |||
| """ | |||
| b, l, _, n = *x.size(), self.num_heads | |||
| # compute query, key, and value | |||
| q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1) | |||
| q = q.reshape(l, b * n, -1).transpose(0, 1) | |||
| k = k.reshape(l, b * n, -1).transpose(0, 1) | |||
| v = v.reshape(l, b * n, -1).transpose(0, 1) | |||
| # compute attention | |||
| attn = self.scale * torch.bmm(q, k.transpose(1, 2)) | |||
| if mask is not None: | |||
| attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf')) | |||
| attn = F.softmax(attn.float(), dim=-1).type_as(attn) | |||
| attn = self.attn_dropout(attn) | |||
| # gather context | |||
| x = torch.bmm(attn, v) | |||
| x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1) | |||
| # output | |||
| x = self.proj(x) | |||
| x = self.proj_dropout(x) | |||
| return x | |||
| class AttentionBlock(nn.Module): | |||
| def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): | |||
| super(AttentionBlock, self).__init__() | |||
| self.dim = dim | |||
| self.num_heads = num_heads | |||
| # layers | |||
| self.norm1 = LayerNorm(dim) | |||
| self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout) | |||
| self.norm2 = LayerNorm(dim) | |||
| self.mlp = nn.Sequential( | |||
| nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim), | |||
| nn.Dropout(proj_dropout)) | |||
| def forward(self, x, mask=None): | |||
| x = x + self.attn(self.norm1(x), mask) | |||
| x = x + self.mlp(self.norm2(x)) | |||
| return x | |||
| class VisionTransformer(nn.Module): | |||
| def __init__(self, | |||
| image_size=224, | |||
| patch_size=16, | |||
| dim=768, | |||
| out_dim=512, | |||
| num_heads=12, | |||
| num_layers=12, | |||
| attn_dropout=0.0, | |||
| proj_dropout=0.0, | |||
| embedding_dropout=0.0): | |||
| assert image_size % patch_size == 0 | |||
| super(VisionTransformer, self).__init__() | |||
| self.image_size = image_size | |||
| self.patch_size = patch_size | |||
| self.dim = dim | |||
| self.out_dim = out_dim | |||
| self.num_heads = num_heads | |||
| self.num_layers = num_layers | |||
| self.num_patches = (image_size // patch_size)**2 | |||
| # embeddings | |||
| gain = 1.0 / math.sqrt(dim) | |||
| self.patch_embedding = nn.Conv2d( | |||
| 3, dim, kernel_size=patch_size, stride=patch_size, bias=False) | |||
| self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) | |||
| self.pos_embedding = nn.Parameter( | |||
| gain * torch.randn(1, self.num_patches + 1, dim)) | |||
| self.dropout = nn.Dropout(embedding_dropout) | |||
| # transformer | |||
| self.pre_norm = LayerNorm(dim) | |||
| self.transformer = nn.Sequential(*[ | |||
| AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) | |||
| for _ in range(num_layers) | |||
| ]) | |||
| self.post_norm = LayerNorm(dim) | |||
| # head | |||
| self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) | |||
| def forward(self, x): | |||
| b, dtype = x.size(0), self.head.dtype | |||
| x = x.type(dtype) | |||
| # patch-embedding | |||
| x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c] | |||
| x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x], | |||
| dim=1) | |||
| x = self.dropout(x + self.pos_embedding.type(dtype)) | |||
| x = self.pre_norm(x) | |||
| # transformer | |||
| x = self.transformer(x) | |||
| # head | |||
| x = self.post_norm(x) | |||
| x = torch.mm(x[:, 0, :], self.head) | |||
| return x | |||
| def fp16(self): | |||
| return self.apply(to_fp16) | |||
| class TextTransformer(nn.Module): | |||
| def __init__(self, | |||
| vocab_size, | |||
| text_len, | |||
| dim=512, | |||
| out_dim=512, | |||
| num_heads=8, | |||
| num_layers=12, | |||
| attn_dropout=0.0, | |||
| proj_dropout=0.0, | |||
| embedding_dropout=0.0): | |||
| super(TextTransformer, self).__init__() | |||
| self.vocab_size = vocab_size | |||
| self.text_len = text_len | |||
| self.dim = dim | |||
| self.out_dim = out_dim | |||
| self.num_heads = num_heads | |||
| self.num_layers = num_layers | |||
| # embeddings | |||
| self.token_embedding = nn.Embedding(vocab_size, dim) | |||
| self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim)) | |||
| self.dropout = nn.Dropout(embedding_dropout) | |||
| # transformer | |||
| self.transformer = nn.ModuleList([ | |||
| AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) | |||
| for _ in range(num_layers) | |||
| ]) | |||
| self.norm = LayerNorm(dim) | |||
| # head | |||
| gain = 1.0 / math.sqrt(dim) | |||
| self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) | |||
| # causal attention mask | |||
| self.register_buffer('attn_mask', | |||
| torch.tril(torch.ones(1, text_len, text_len))) | |||
| def forward(self, x): | |||
| eot, dtype = x.argmax(dim=-1), self.head.dtype | |||
| # embeddings | |||
| x = self.dropout( | |||
| self.token_embedding(x).type(dtype) | |||
| + self.pos_embedding.type(dtype)) | |||
| # transformer | |||
| for block in self.transformer: | |||
| x = block(x, self.attn_mask) | |||
| # head | |||
| x = self.norm(x) | |||
| x = torch.mm(x[torch.arange(x.size(0)), eot], self.head) | |||
| return x | |||
| def fp16(self): | |||
| return self.apply(to_fp16) | |||
| class CLIP(nn.Module): | |||
| def __init__(self, | |||
| embed_dim=512, | |||
| image_size=224, | |||
| patch_size=16, | |||
| vision_dim=768, | |||
| vision_heads=12, | |||
| vision_layers=12, | |||
| vocab_size=49408, | |||
| text_len=77, | |||
| text_dim=512, | |||
| text_heads=8, | |||
| text_layers=12, | |||
| attn_dropout=0.0, | |||
| proj_dropout=0.0, | |||
| embedding_dropout=0.0): | |||
| super(CLIP, self).__init__() | |||
| self.embed_dim = embed_dim | |||
| self.image_size = image_size | |||
| self.patch_size = patch_size | |||
| self.vision_dim = vision_dim | |||
| self.vision_heads = vision_heads | |||
| self.vision_layers = vision_layers | |||
| self.vocab_size = vocab_size | |||
| self.text_len = text_len | |||
| self.text_dim = text_dim | |||
| self.text_heads = text_heads | |||
| self.text_layers = text_layers | |||
| # models | |||
| self.visual = VisionTransformer( | |||
| image_size=image_size, | |||
| patch_size=patch_size, | |||
| dim=vision_dim, | |||
| out_dim=embed_dim, | |||
| num_heads=vision_heads, | |||
| num_layers=vision_layers, | |||
| attn_dropout=attn_dropout, | |||
| proj_dropout=proj_dropout, | |||
| embedding_dropout=embedding_dropout) | |||
| self.textual = TextTransformer( | |||
| vocab_size=vocab_size, | |||
| text_len=text_len, | |||
| dim=text_dim, | |||
| out_dim=embed_dim, | |||
| num_heads=text_heads, | |||
| num_layers=text_layers, | |||
| attn_dropout=attn_dropout, | |||
| proj_dropout=proj_dropout, | |||
| embedding_dropout=embedding_dropout) | |||
| self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) | |||
| def forward(self, imgs, txt_tokens): | |||
| r"""imgs: [B, C, H, W] of torch.float32. | |||
| txt_tokens: [B, T] of torch.long. | |||
| """ | |||
| xi = self.visual(imgs) | |||
| xt = self.textual(txt_tokens) | |||
| # normalize features | |||
| xi = F.normalize(xi, p=2, dim=1) | |||
| xt = F.normalize(xt, p=2, dim=1) | |||
| # gather features from all ranks | |||
| full_xi = ops.diff_all_gather(xi) | |||
| full_xt = ops.diff_all_gather(xt) | |||
| # logits | |||
| scale = self.log_scale.exp() | |||
| logits_i2t = scale * torch.mm(xi, full_xt.t()) | |||
| logits_t2i = scale * torch.mm(xt, full_xi.t()) | |||
| # labels | |||
| labels = torch.arange( | |||
| len(xi) * ops.get_rank(), | |||
| len(xi) * (ops.get_rank() + 1), | |||
| dtype=torch.long, | |||
| device=xi.device) | |||
| return logits_i2t, logits_t2i, labels | |||
| def init_weights(self): | |||
| # embeddings | |||
| nn.init.normal_(self.textual.token_embedding.weight, std=0.02) | |||
| nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1) | |||
| # attentions | |||
| for modality in ['visual', 'textual']: | |||
| dim = self.vision_dim if modality == 'visual' else 'textual' | |||
| transformer = getattr(self, modality).transformer | |||
| proj_gain = (1.0 / math.sqrt(dim)) * ( | |||
| 1.0 / math.sqrt(2 * transformer.num_layers)) | |||
| attn_gain = 1.0 / math.sqrt(dim) | |||
| mlp_gain = 1.0 / math.sqrt(2.0 * dim) | |||
| for block in transformer.layers: | |||
| nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) | |||
| nn.init.normal_(block.attn.proj.weight, std=proj_gain) | |||
| nn.init.normal_(block.mlp[0].weight, std=mlp_gain) | |||
| nn.init.normal_(block.mlp[2].weight, std=proj_gain) | |||
| def param_groups(self): | |||
| groups = [{ | |||
| 'params': [ | |||
| p for n, p in self.named_parameters() | |||
| if 'norm' in n or n.endswith('bias') | |||
| ], | |||
| 'weight_decay': | |||
| 0.0 | |||
| }, { | |||
| 'params': [ | |||
| p for n, p in self.named_parameters() | |||
| if not ('norm' in n or n.endswith('bias')) | |||
| ] | |||
| }] | |||
| return groups | |||
| def fp16(self): | |||
| return self.apply(to_fp16) | |||
| def clip_vit_b_32(**kwargs): | |||
| return CLIP( | |||
| embed_dim=512, | |||
| image_size=224, | |||
| patch_size=32, | |||
| vision_dim=768, | |||
| vision_heads=12, | |||
| vision_layers=12, | |||
| text_dim=512, | |||
| text_heads=8, | |||
| text_layers=12, | |||
| **kwargs) | |||
| def clip_vit_b_16(**kwargs): | |||
| return CLIP( | |||
| embed_dim=512, | |||
| image_size=224, | |||
| patch_size=16, | |||
| vision_dim=768, | |||
| vision_heads=12, | |||
| vision_layers=12, | |||
| text_dim=512, | |||
| text_heads=8, | |||
| text_layers=12, | |||
| **kwargs) | |||
| def clip_vit_l_14(**kwargs): | |||
| return CLIP( | |||
| embed_dim=768, | |||
| image_size=224, | |||
| patch_size=14, | |||
| vision_dim=1024, | |||
| vision_heads=16, | |||
| vision_layers=24, | |||
| text_dim=768, | |||
| text_heads=12, | |||
| text_layers=12, | |||
| **kwargs) | |||
| def clip_vit_l_14_336px(**kwargs): | |||
| return CLIP( | |||
| embed_dim=768, | |||
| image_size=336, | |||
| patch_size=14, | |||
| vision_dim=1024, | |||
| vision_heads=16, | |||
| vision_layers=24, | |||
| text_dim=768, | |||
| text_heads=12, | |||
| text_layers=12, | |||
| **kwargs) | |||
| def clip_vit_h_16(**kwargs): | |||
| return CLIP( | |||
| embed_dim=1024, | |||
| image_size=256, | |||
| patch_size=16, | |||
| vision_dim=1280, | |||
| vision_heads=16, | |||
| vision_layers=32, | |||
| text_dim=1024, | |||
| text_heads=16, | |||
| text_layers=24, | |||
| **kwargs) | |||
| @@ -0,0 +1,8 @@ | |||
| from .degradation import * # noqa F403 | |||
| from .diffusion import * # noqa F403 | |||
| from .losses import * # noqa F403 | |||
| from .metrics import * # noqa F403 | |||
| from .random_color import * # noqa F403 | |||
| from .random_mask import * # noqa F403 | |||
| from .svd import * # noqa F403 | |||
| from .utils import * # noqa F403 | |||
| @@ -0,0 +1,663 @@ | |||
| # APPs that facilitate the use of pretrained neural networks. | |||
| import os.path as osp | |||
| import artist.data as data | |||
| import artist.models as models | |||
| import numpy as np | |||
| import torch | |||
| import torch.cuda.amp as amp | |||
| import torch.nn.functional as F | |||
| import torchvision.transforms as T | |||
| from artist import DOWNLOAD_TO_CACHE | |||
| from PIL import Image | |||
| from torch.utils.data import DataLoader, Dataset | |||
| from .utils import parallel, read_image | |||
| __all__ = [ | |||
| 'FeatureExtractor', 'Classifier', 'Text2Image', 'Sole2Shoe', 'ImageParser', | |||
| 'TextImageMatch', 'taobao_feature_extractor', 'singleton_classifier', | |||
| 'orientation_classifier', 'fashion_text2image', 'mindalle_text2image', | |||
| 'sole2shoe', 'sole_parser', 'sod_foreground_parser', | |||
| 'fashion_text_image_match' | |||
| ] | |||
| class ImageFolder(Dataset): | |||
| def __init__(self, paths, transforms=None): | |||
| self.paths = paths | |||
| self.transforms = transforms | |||
| def __getitem__(self, index): | |||
| img = read_image(self.paths[index]) | |||
| if img.mode != 'RGB': | |||
| img = img.convert('RGB') | |||
| if self.transforms is not None: | |||
| img = self.transforms(img) | |||
| return img | |||
| def __len__(self): | |||
| return len(self.paths) | |||
| class FeatureExtractor(object): | |||
| def __init__( | |||
| self, | |||
| model='InceptionV1', | |||
| checkpoint='models/inception-v1/1218shoes.v9_7.140.0.1520000', | |||
| resolution=224, | |||
| mean=[0.485, 0.456, 0.406], | |||
| std=[0.229, 0.224, 0.225], | |||
| batch_size=64, | |||
| device=torch.device( | |||
| 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 | |||
| self.resolution = resolution | |||
| self.batch_size = batch_size | |||
| self.device = device | |||
| # init model | |||
| self.net = getattr( | |||
| models, | |||
| model)(num_classes=None).eval().requires_grad_(False).to(device) | |||
| self.net.load_state_dict( | |||
| torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) | |||
| # data transforms | |||
| self.transforms = T.Compose([ | |||
| data.PadToSquare(), | |||
| T.Resize(resolution), | |||
| T.ToTensor(), | |||
| T.Normalize(mean, std) | |||
| ]) | |||
| def __call__(self, imgs, num_workers=0): | |||
| r"""imgs: Either a PIL.Image or a list of PIL.Image instances. | |||
| """ | |||
| # preprocess | |||
| if isinstance(imgs, Image.Image): | |||
| imgs = [imgs] | |||
| assert isinstance(imgs, | |||
| (tuple, list)) and isinstance(imgs[0], Image.Image) | |||
| imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) | |||
| # forward | |||
| feats = [] | |||
| for batch in imgs.split(self.batch_size, dim=0): | |||
| batch = batch.to(self.device, non_blocking=True) | |||
| feats.append(self.net(batch)) | |||
| return torch.cat(feats, dim=0) | |||
| def batch_process(self, paths): | |||
| # init dataloader | |||
| dataloader = DataLoader( | |||
| dataset=ImageFolder(paths, self.transforms), | |||
| batch_size=self.batch_size, | |||
| shuffle=False, | |||
| drop_last=False, | |||
| pin_memory=True, | |||
| num_workers=8, | |||
| prefetch_factor=2) | |||
| # forward | |||
| feats = [] | |||
| for step, batch in enumerate(dataloader, 1): | |||
| print(f'Step: {step}/{len(dataloader)}', flush=True) | |||
| batch = batch.to(self.device, non_blocking=True) | |||
| feats.append(self.net(batch)) | |||
| return torch.cat(feats) | |||
| class Classifier(object): | |||
| def __init__( | |||
| self, | |||
| model='InceptionV1', | |||
| checkpoint='models/classifier/shoes+apparel+bag-sgdetect-211230.pth', | |||
| num_classes=1, | |||
| resolution=224, | |||
| mean=[0.485, 0.456, 0.406], | |||
| std=[0.229, 0.224, 0.225], | |||
| batch_size=64, | |||
| device=torch.device( | |||
| 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 | |||
| self.num_classes = num_classes | |||
| self.resolution = resolution | |||
| self.batch_size = batch_size | |||
| self.device = device | |||
| # init model | |||
| self.net = getattr(models, model)( | |||
| num_classes=num_classes).eval().requires_grad_(False).to(device) | |||
| self.net.load_state_dict( | |||
| torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) | |||
| # data transforms | |||
| self.transforms = T.Compose([ | |||
| data.PadToSquare(), | |||
| T.Resize(resolution), | |||
| T.ToTensor(), | |||
| T.Normalize(mean, std) | |||
| ]) | |||
| def __call__(self, imgs, num_workers=0): | |||
| r"""imgs: Either a PIL.Image or a list of PIL.Image instances. | |||
| """ | |||
| # preprocess | |||
| if isinstance(imgs, Image.Image): | |||
| imgs = [imgs] | |||
| assert isinstance(imgs, | |||
| (tuple, list)) and isinstance(imgs[0], Image.Image) | |||
| imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) | |||
| # forward | |||
| scores = [] | |||
| for batch in imgs.split(self.batch_size, dim=0): | |||
| batch = batch.to(self.device, non_blocking=True) | |||
| logits = self.net(batch) | |||
| scores.append(logits.sigmoid() if self.num_classes == # noqa W504 | |||
| 1 else logits.softmax(dim=1)) | |||
| return torch.cat(scores, dim=0) | |||
| class Text2Image(object): | |||
| def __init__( | |||
| self, | |||
| vqgan_dim=128, | |||
| vqgan_z_dim=256, | |||
| vqgan_dim_mult=[1, 1, 2, 2, 4], | |||
| vqgan_num_res_blocks=2, | |||
| vqgan_attn_scales=[1.0 / 16], | |||
| vqgan_codebook_size=975, | |||
| vqgan_beta=0.25, | |||
| gpt_txt_vocab_size=21128, | |||
| gpt_txt_seq_len=64, | |||
| gpt_img_seq_len=1024, | |||
| gpt_dim=1024, | |||
| gpt_num_heads=16, | |||
| gpt_num_layers=24, | |||
| vqgan_checkpoint='models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', | |||
| gpt_checkpoint='models/seq2seq_gpt/text2image_shoes+apparels_step400k.pth', | |||
| tokenizer=data.BertTokenizer(name='bert-base-chinese', length=64), | |||
| batch_size=16, | |||
| device=torch.device( | |||
| 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 | |||
| self.tokenizer = tokenizer | |||
| self.batch_size = batch_size | |||
| self.device = device | |||
| # init VQGAN model | |||
| self.vqgan = models.VQGAN( | |||
| dim=vqgan_dim, | |||
| z_dim=vqgan_z_dim, | |||
| dim_mult=vqgan_dim_mult, | |||
| num_res_blocks=vqgan_num_res_blocks, | |||
| attn_scales=vqgan_attn_scales, | |||
| codebook_size=vqgan_codebook_size, | |||
| beta=vqgan_beta).eval().requires_grad_(False).to(device) | |||
| self.vqgan.load_state_dict( | |||
| torch.load( | |||
| DOWNLOAD_TO_CACHE(vqgan_checkpoint), map_location=device)) | |||
| # init GPT model | |||
| self.gpt = models.Seq2SeqGPT( | |||
| src_vocab_size=gpt_txt_vocab_size, | |||
| tar_vocab_size=vqgan_codebook_size, | |||
| src_seq_len=gpt_txt_seq_len, | |||
| tar_seq_len=gpt_img_seq_len, | |||
| dim=gpt_dim, | |||
| num_heads=gpt_num_heads, | |||
| num_layers=gpt_num_layers).eval().requires_grad_(False).to(device) | |||
| self.gpt.load_state_dict( | |||
| torch.load(DOWNLOAD_TO_CACHE(gpt_checkpoint), map_location=device)) | |||
| def __call__(self, | |||
| txts, | |||
| top_k=64, | |||
| top_p=None, | |||
| temperature=0.6, | |||
| use_fp16=True): | |||
| # preprocess | |||
| if isinstance(txts, str): | |||
| txts = [txts] | |||
| assert isinstance(txts, (tuple, list)) and isinstance(txts[0], str) | |||
| txt_tokens = torch.LongTensor([self.tokenizer(u) for u in txts]) | |||
| # forward | |||
| out_imgs = [] | |||
| for batch in txt_tokens.split(self.batch_size, dim=0): | |||
| # sample | |||
| batch = batch.to(self.device, non_blocking=True) | |||
| with amp.autocast(enabled=use_fp16): | |||
| img_tokens = self.gpt.sample(batch, top_k, top_p, temperature) | |||
| # decode | |||
| imgs = self.vqgan.decode_from_tokens(img_tokens) | |||
| imgs = self._whiten_borders(imgs) | |||
| imgs = imgs.clamp_(-1, 1).add_(1).mul_(125.0).permute( | |||
| 0, 2, 3, 1).cpu().numpy().astype(np.uint8) | |||
| imgs = [Image.fromarray(u) for u in imgs] | |||
| # append | |||
| out_imgs += imgs | |||
| return out_imgs | |||
| def _whiten_borders(self, imgs): | |||
| r"""Remove border artifacts. | |||
| """ | |||
| imgs[:, :, :18, :] = 1 | |||
| imgs[:, :, :, :18] = 1 | |||
| imgs[:, :, -18:, :] = 1 | |||
| imgs[:, :, :, -18:] = 1 | |||
| return imgs | |||
| class Sole2Shoe(object): | |||
| def __init__( | |||
| self, | |||
| vqgan_dim=128, | |||
| vqgan_z_dim=256, | |||
| vqgan_dim_mult=[1, 1, 2, 2, 4], | |||
| vqgan_num_res_blocks=2, | |||
| vqgan_attn_scales=[1.0 / 16], | |||
| vqgan_codebook_size=975, | |||
| vqgan_beta=0.25, | |||
| src_resolution=256, | |||
| tar_resolution=512, | |||
| gpt_dim=1024, | |||
| gpt_num_heads=16, | |||
| gpt_num_layers=24, | |||
| vqgan_checkpoint='models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', | |||
| gpt_checkpoint='models/seq2seq_gpt/sole2shoe-step300k-220104.pth', | |||
| batch_size=12, | |||
| device=torch.device( | |||
| 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 | |||
| self.batch_size = batch_size | |||
| self.device = device | |||
| src_seq_len = (src_resolution // 16)**2 | |||
| tar_seq_len = (tar_resolution // 16)**2 | |||
| # init VQGAN model | |||
| self.vqgan = models.VQGAN( | |||
| dim=vqgan_dim, | |||
| z_dim=vqgan_z_dim, | |||
| dim_mult=vqgan_dim_mult, | |||
| num_res_blocks=vqgan_num_res_blocks, | |||
| attn_scales=vqgan_attn_scales, | |||
| codebook_size=vqgan_codebook_size, | |||
| beta=vqgan_beta).eval().requires_grad_(False).to(device) | |||
| self.vqgan.load_state_dict( | |||
| torch.load( | |||
| DOWNLOAD_TO_CACHE(vqgan_checkpoint), map_location=device)) | |||
| # init GPT model | |||
| self.gpt = models.Seq2SeqGPT( | |||
| src_vocab_size=vqgan_codebook_size, | |||
| tar_vocab_size=vqgan_codebook_size, | |||
| src_seq_len=src_seq_len, | |||
| tar_seq_len=tar_seq_len, | |||
| dim=gpt_dim, | |||
| num_heads=gpt_num_heads, | |||
| num_layers=gpt_num_layers).eval().requires_grad_(False).to(device) | |||
| self.gpt.load_state_dict( | |||
| torch.load(DOWNLOAD_TO_CACHE(gpt_checkpoint), map_location=device)) | |||
| # data transforms | |||
| self.transforms = T.Compose([ | |||
| data.PadToSquare(), | |||
| T.Resize(src_resolution), | |||
| T.ToTensor(), | |||
| T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) | |||
| ]) | |||
| def __call__(self, | |||
| sole_imgs, | |||
| top_k=64, | |||
| top_p=None, | |||
| temperature=0.6, | |||
| use_fp16=True, | |||
| num_workers=0): | |||
| # preprocess | |||
| if isinstance(sole_imgs, Image.Image): | |||
| sole_imgs = [sole_imgs] | |||
| assert isinstance(sole_imgs, (tuple, list)) and isinstance( | |||
| sole_imgs[0], Image.Image) | |||
| sole_imgs = torch.stack( | |||
| parallel(self.transforms, sole_imgs, num_workers), dim=0) | |||
| # forward | |||
| out_imgs = [] | |||
| for batch in sole_imgs.split(self.batch_size, dim=0): | |||
| # sample | |||
| batch = batch.to(self.device) | |||
| with amp.autocast(enabled=use_fp16): | |||
| sole_tokens = self.vqgan.encode_to_tokens(batch) | |||
| shoe_tokens = self.gpt.sample(sole_tokens, top_k, top_p, | |||
| temperature) | |||
| # decode | |||
| shoe_imgs = self.vqgan.decode_from_tokens(shoe_tokens) | |||
| shoe_imgs = self._whiten_borders(shoe_imgs) | |||
| shoe_imgs = shoe_imgs.clamp_(-1, 1).add_(1).mul_(125.0).permute( | |||
| 0, 2, 3, 1).cpu().numpy().astype(np.uint8) | |||
| shoe_imgs = [Image.fromarray(u) for u in shoe_imgs] | |||
| # append | |||
| out_imgs += shoe_imgs | |||
| return out_imgs | |||
| def _whiten_borders(self, imgs): | |||
| r"""Remove border artifacts. | |||
| """ | |||
| imgs[:, :, :18, :] = 1 | |||
| imgs[:, :, :, :18] = 1 | |||
| imgs[:, :, -18:, :] = 1 | |||
| imgs[:, :, :, -18:] = 1 | |||
| return imgs | |||
| class ImageParser(object): | |||
| def __init__( | |||
| self, | |||
| model='SPNet', | |||
| num_classes=2, | |||
| resolution=800, | |||
| mean=[0.485, 0.456, 0.406], | |||
| std=[0.229, 0.224, 0.225], | |||
| model_with_softmax=False, | |||
| checkpoint='models/spnet/sole_segmentation_211219.pth', | |||
| batch_size=16, | |||
| device=torch.device( | |||
| 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 | |||
| self.batch_size = batch_size | |||
| self.device = device | |||
| # init model | |||
| if checkpoint.endswith('.pt'): | |||
| self.net = torch.jit.load( | |||
| DOWNLOAD_TO_CACHE(checkpoint)).eval().to(device) | |||
| [p.requires_grad_(False) for p in self.net.parameters()] | |||
| else: | |||
| self.net = getattr(models, model)( | |||
| num_classes=num_classes, | |||
| pretrained=False).eval().requires_grad_(False).to(device) | |||
| self.net.load_state_dict( | |||
| torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) | |||
| self.softmax = (lambda x, dim: x) if model_with_softmax else F.softmax | |||
| # data transforms | |||
| self.transforms = T.Compose([ | |||
| data.PadToSquare(), | |||
| T.Resize(resolution), | |||
| T.ToTensor(), | |||
| T.Normalize(mean, std) | |||
| ]) | |||
| def __call__(self, imgs, num_workers=0): | |||
| # preprocess | |||
| if isinstance(imgs, Image.Image): | |||
| imgs = [imgs] | |||
| assert isinstance(imgs, | |||
| (tuple, list)) and isinstance(imgs[0], Image.Image) | |||
| sizes = [u.size for u in imgs] | |||
| imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) | |||
| # forward | |||
| masks = [] | |||
| for batch in imgs.split(self.batch_size, dim=0): | |||
| batch = batch.to(self.device, non_blocking=True) | |||
| masks.append(self.softmax(self.net(batch), dim=1)) | |||
| # postprocess | |||
| masks = torch.cat(masks, dim=0).unsqueeze(1) | |||
| masks = [ | |||
| F.interpolate(u, v, mode='bilinear', align_corners=False) | |||
| for u, v in zip(masks, sizes) | |||
| ] | |||
| return masks | |||
| class TextImageMatch(object): | |||
| def __init__( | |||
| self, | |||
| embed_dim=512, | |||
| image_size=224, | |||
| patch_size=32, | |||
| vision_dim=768, | |||
| vision_heads=12, | |||
| vision_layers=12, | |||
| vocab_size=21128, | |||
| text_len=77, | |||
| text_dim=512, | |||
| text_heads=8, | |||
| text_layers=12, | |||
| mean=[0.48145466, 0.4578275, 0.40821073], | |||
| std=[0.26862954, 0.26130258, 0.27577711], | |||
| checkpoint='models/clip/clip_shoes+apparels_step84k_210105.pth', | |||
| tokenizer=data.BertTokenizer(name='bert-base-chinese', length=77), | |||
| batch_size=64, | |||
| device=torch.device( | |||
| 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 | |||
| self.tokenizer = tokenizer | |||
| self.batch_size = batch_size | |||
| self.device = device | |||
| # init model | |||
| self.clip = models.CLIP( | |||
| embed_dim=embed_dim, | |||
| image_size=image_size, | |||
| patch_size=patch_size, | |||
| vision_dim=vision_dim, | |||
| vision_heads=vision_heads, | |||
| vision_layers=vision_layers, | |||
| vocab_size=vocab_size, | |||
| text_len=text_len, | |||
| text_dim=text_dim, | |||
| text_heads=text_heads, | |||
| text_layers=text_layers).eval().requires_grad_(False).to(device) | |||
| self.clip.load_state_dict( | |||
| torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) | |||
| # transforms | |||
| scale_size = int(image_size * 8 / 7) | |||
| self.transforms = T.Compose([ | |||
| data.PadToSquare(), | |||
| T.Resize(scale_size), | |||
| T.CenterCrop(image_size), | |||
| T.ToTensor(), | |||
| T.Normalize(mean, std) | |||
| ]) | |||
| def __call__(self, imgs, txts, num_workers=0): | |||
| # preprocess | |||
| assert isinstance(imgs, | |||
| (tuple, list)) and isinstance(imgs[0], Image.Image) | |||
| assert isinstance(txts, (tuple, list)) and isinstance(txts[0], str) | |||
| txt_tokens = torch.LongTensor([self.tokenizer(u) for u in txts]) | |||
| imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) | |||
| # forward | |||
| scores = [] | |||
| for img_batch, txt_batch in zip( | |||
| imgs.split(self.batch_size, dim=0), | |||
| txt_tokens.split(self.batch_size, dim=0)): | |||
| img_batch = img_batch.to(self.device) | |||
| txt_batch = txt_batch.to(self.device) | |||
| xi = F.normalize(self.clip.visual(img_batch), p=2, dim=1) | |||
| xt = F.normalize(self.clip.textual(txt_batch), p=2, dim=1) | |||
| scores.append((xi * xt).sum(dim=1)) | |||
| return torch.cat(scores, dim=0) | |||
| def taobao_feature_extractor(category='shoes', **kwargs): | |||
| r"""Pretrained taobao-search feature extractors. | |||
| """ | |||
| assert category in ['softall', 'shoes', 'bag'] | |||
| checkpoint = osp.join( | |||
| 'models/inception-v1', { | |||
| 'softall': '1214softall_10.10.0.5000', | |||
| 'shoes': '1218shoes.v9_7.140.0.1520000', | |||
| 'bag': '0926bag.v9_6.29.0.140000' | |||
| }[category]) | |||
| app = FeatureExtractor( | |||
| model='InceptionV1', | |||
| checkpoint=checkpoint, | |||
| resolution=224, | |||
| mean=[0.485, 0.456, 0.406], | |||
| std=[0.229, 0.224, 0.225], | |||
| **kwargs) | |||
| return app | |||
| def singleton_classifier(**kwargs): | |||
| r"""Pretrained classifier that finds single-object images. | |||
| Supports shoes, apparel, and bag images. | |||
| """ | |||
| app = Classifier( | |||
| model='InceptionV1', | |||
| checkpoint='models/classifier/shoes+apparel+bag-sgdetect-211230.pth', | |||
| num_classes=1, | |||
| resolution=224, | |||
| mean=[0.485, 0.456, 0.406], | |||
| std=[0.229, 0.224, 0.225], | |||
| **kwargs) | |||
| return app | |||
| def orientation_classifier(**kwargs): | |||
| r"""Shoes orientation classifier. | |||
| """ | |||
| app = Classifier( | |||
| model='InceptionV1', | |||
| checkpoint='models/classifier/shoes-oriendetect-20211026.pth', | |||
| num_classes=1, | |||
| resolution=224, | |||
| mean=[0.485, 0.456, 0.406], | |||
| std=[0.229, 0.224, 0.225], | |||
| **kwargs) | |||
| return app | |||
| def fashion_text2image(**kwargs): | |||
| r"""Fashion text-to-image generator. | |||
| Supports shoe and apparel image generation. | |||
| """ | |||
| app = Text2Image( | |||
| vqgan_dim=128, | |||
| vqgan_z_dim=256, | |||
| vqgan_dim_mult=[1, 1, 2, 2, 4], | |||
| vqgan_num_res_blocks=2, | |||
| vqgan_attn_scales=[1.0 / 16], | |||
| vqgan_codebook_size=975, | |||
| vqgan_beta=0.25, | |||
| gpt_txt_vocab_size=21128, | |||
| gpt_txt_seq_len=64, | |||
| gpt_img_seq_len=1024, | |||
| gpt_dim=1024, | |||
| gpt_num_heads=16, | |||
| gpt_num_layers=24, | |||
| vqgan_checkpoint= # noqa E251 | |||
| 'models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', | |||
| gpt_checkpoint= # noqa E251 | |||
| 'models/seq2seq_gpt/text2image_shoes+apparels_step400k.pth', | |||
| tokenizer=data.BertTokenizer(name='bert-base-chinese', length=64), | |||
| **kwargs) | |||
| return app | |||
| def mindalle_text2image(**kwargs): | |||
| r"""Pretrained text2image generator with weights copied from minDALL-E. | |||
| """ | |||
| app = Text2Image( | |||
| vqgan_dim=128, | |||
| vqgan_z_dim=256, | |||
| vqgan_dim_mult=[1, 1, 2, 2, 4], | |||
| vqgan_num_res_blocks=2, | |||
| vqgan_attn_scales=[1.0 / 16], | |||
| vqgan_codebook_size=16384, | |||
| vqgan_beta=0.25, | |||
| gpt_txt_vocab_size=16384, | |||
| gpt_txt_seq_len=64, | |||
| gpt_img_seq_len=256, | |||
| gpt_dim=1536, | |||
| gpt_num_heads=24, | |||
| gpt_num_layers=42, | |||
| vqgan_checkpoint='models/minDALLE/1.3B_vqgan.pth', | |||
| gpt_checkpoint='models/minDALLE/1.3B_gpt.pth', | |||
| tokenizer=data.BPETokenizer(length=64), | |||
| **kwargs) | |||
| return app | |||
| def sole2shoe(**kwargs): | |||
| app = Sole2Shoe( | |||
| vqgan_dim=128, | |||
| vqgan_z_dim=256, | |||
| vqgan_dim_mult=[1, 1, 2, 2, 4], | |||
| vqgan_num_res_blocks=2, | |||
| vqgan_attn_scales=[1.0 / 16], | |||
| vqgan_codebook_size=975, | |||
| vqgan_beta=0.25, | |||
| src_resolution=256, | |||
| tar_resolution=512, | |||
| gpt_dim=1024, | |||
| gpt_num_heads=16, | |||
| gpt_num_layers=24, | |||
| vqgan_checkpoint= # noqa E251 | |||
| 'models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', | |||
| gpt_checkpoint='models/seq2seq_gpt/sole2shoe-step300k-220104.pth', | |||
| **kwargs) | |||
| return app | |||
| def sole_parser(**kwargs): | |||
| app = ImageParser( | |||
| model='SPNet', | |||
| num_classes=2, | |||
| resolution=800, | |||
| mean=[0.485, 0.456, 0.406], | |||
| std=[0.229, 0.224, 0.225], | |||
| model_with_softmax=False, | |||
| checkpoint='models/spnet/sole_segmentation_211219.pth', | |||
| **kwargs) | |||
| return app | |||
| def sod_foreground_parser(**kwargs): | |||
| app = ImageParser( | |||
| model=None, | |||
| num_classes=None, | |||
| resolution=448, | |||
| mean=[0.488431, 0.466275, 0.403686], | |||
| std=[0.222627, 0.21949, 0.22549], | |||
| model_with_softmax=True, | |||
| checkpoint='models/semseg/sod_model_20201228.pt', | |||
| **kwargs) | |||
| return app | |||
| def fashion_text_image_match(**kwargs): | |||
| app = TextImageMatch( | |||
| embed_dim=512, | |||
| image_size=224, | |||
| patch_size=32, | |||
| vision_dim=768, | |||
| vision_heads=12, | |||
| vision_layers=12, | |||
| vocab_size=21128, | |||
| text_len=77, | |||
| text_dim=512, | |||
| text_heads=8, | |||
| text_layers=12, | |||
| mean=[0.48145466, 0.4578275, 0.40821073], | |||
| std=[0.26862954, 0.26130258, 0.27577711], | |||
| checkpoint='models/clip/clip_shoes+apparels_step84k_210105.pth', | |||
| tokenizer=data.BertTokenizer(name='bert-base-chinese', length=77), | |||
| **kwargs) | |||
| return app | |||
| @@ -0,0 +1,598 @@ | |||
| import math | |||
| import torch | |||
| from .losses import discretized_gaussian_log_likelihood, kl_divergence | |||
| __all__ = ['GaussianDiffusion', 'beta_schedule'] | |||
| def _i(tensor, t, x): | |||
| r"""Index tensor using t and format the output according to x. | |||
| """ | |||
| shape = (x.size(0), ) + (1, ) * (x.ndim - 1) | |||
| return tensor[t].view(shape).to(x) | |||
| def beta_schedule(schedule, | |||
| num_timesteps=1000, | |||
| init_beta=None, | |||
| last_beta=None): | |||
| if schedule == 'linear': | |||
| scale = 1000.0 / num_timesteps | |||
| init_beta = init_beta or scale * 0.0001 | |||
| last_beta = last_beta or scale * 0.02 | |||
| return torch.linspace( | |||
| init_beta, last_beta, num_timesteps, dtype=torch.float64) | |||
| elif schedule == 'quadratic': | |||
| init_beta = init_beta or 0.0015 | |||
| last_beta = last_beta or 0.0195 | |||
| return torch.linspace( | |||
| init_beta**0.5, last_beta**0.5, num_timesteps, | |||
| dtype=torch.float64)**2 | |||
| elif schedule == 'cosine': | |||
| betas = [] | |||
| for step in range(num_timesteps): | |||
| t1 = step / num_timesteps | |||
| t2 = (step + 1) / num_timesteps | |||
| # fn = lambda u: math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 | |||
| def fn(u): | |||
| return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 | |||
| betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) | |||
| return torch.tensor(betas, dtype=torch.float64) | |||
| else: | |||
| raise ValueError(f'Unsupported schedule: {schedule}') | |||
| class GaussianDiffusion(object): | |||
| def __init__(self, | |||
| betas, | |||
| mean_type='eps', | |||
| var_type='learned_range', | |||
| loss_type='mse', | |||
| rescale_timesteps=False): | |||
| # check input | |||
| if not isinstance(betas, torch.DoubleTensor): | |||
| betas = torch.tensor(betas, dtype=torch.float64) | |||
| assert min(betas) > 0 and max(betas) <= 1 | |||
| assert mean_type in ['x0', 'x_{t-1}', 'eps'] | |||
| assert var_type in [ | |||
| 'learned', 'learned_range', 'fixed_large', 'fixed_small' | |||
| ] | |||
| assert loss_type in [ | |||
| 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' | |||
| ] | |||
| self.betas = betas | |||
| self.num_timesteps = len(betas) | |||
| self.mean_type = mean_type | |||
| self.var_type = var_type | |||
| self.loss_type = loss_type | |||
| self.rescale_timesteps = rescale_timesteps | |||
| # alphas | |||
| alphas = 1 - self.betas | |||
| self.alphas_cumprod = torch.cumprod(alphas, dim=0) | |||
| self.alphas_cumprod_prev = torch.cat( | |||
| [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) | |||
| self.alphas_cumprod_next = torch.cat( | |||
| [self.alphas_cumprod[1:], | |||
| alphas.new_zeros([1])]) | |||
| # q(x_t | x_{t-1}) | |||
| self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) | |||
| self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 | |||
| - self.alphas_cumprod) | |||
| self.log_one_minus_alphas_cumprod = torch.log(1.0 | |||
| - self.alphas_cumprod) | |||
| self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) | |||
| self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod | |||
| - 1) | |||
| # q(x_{t-1} | x_t, x_0) | |||
| self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( | |||
| 1.0 - self.alphas_cumprod) | |||
| self.posterior_log_variance_clipped = torch.log( | |||
| self.posterior_variance.clamp(1e-20)) | |||
| self.posterior_mean_coef1 = betas * torch.sqrt( | |||
| self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) | |||
| self.posterior_mean_coef2 = ( | |||
| 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( | |||
| 1.0 - self.alphas_cumprod) | |||
| def q_sample(self, x0, t, noise=None): | |||
| r"""Sample from q(x_t | x_0). | |||
| """ | |||
| noise = torch.randn_like(x0) if noise is None else noise | |||
| return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i( | |||
| self.sqrt_one_minus_alphas_cumprod, t, x0) * noise | |||
| def q_mean_variance(self, x0, t): | |||
| r"""Distribution of q(x_t | x_0). | |||
| """ | |||
| mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 | |||
| var = _i(1.0 - self.alphas_cumprod, t, x0) | |||
| log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) | |||
| return mu, var, log_var | |||
| def q_posterior_mean_variance(self, x0, xt, t): | |||
| r"""Distribution of q(x_{t-1} | x_t, x_0). | |||
| """ | |||
| mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( | |||
| self.posterior_mean_coef2, t, xt) * xt | |||
| var = _i(self.posterior_variance, t, xt) | |||
| log_var = _i(self.posterior_log_variance_clipped, t, xt) | |||
| return mu, var, log_var | |||
| @torch.no_grad() | |||
| def p_sample(self, | |||
| xt, | |||
| t, | |||
| model, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None, | |||
| condition_fn=None, | |||
| guide_scale=None): | |||
| r"""Sample from p(x_{t-1} | x_t). | |||
| - condition_fn: for classifier-based guidance (guided-diffusion). | |||
| - guide_scale: for classifier-free guidance (glide/dalle-2). | |||
| """ | |||
| # predict distribution of p(x_{t-1} | x_t) | |||
| mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, | |||
| clamp, percentile, | |||
| guide_scale) | |||
| # random sample (with optional conditional function) | |||
| noise = torch.randn_like(xt) | |||
| # no noise when t == 0 | |||
| mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) | |||
| if condition_fn is not None: | |||
| grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) | |||
| mu = mu.float() + var * grad.float() | |||
| xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise | |||
| return xt_1, x0 | |||
| @torch.no_grad() | |||
| def p_sample_loop(self, | |||
| noise, | |||
| model, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None, | |||
| condition_fn=None, | |||
| guide_scale=None): | |||
| r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). | |||
| """ | |||
| # prepare input | |||
| b, c, h, w = noise.size() | |||
| xt = noise | |||
| # diffusion process | |||
| for step in torch.arange(self.num_timesteps).flip(0): | |||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||
| xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, | |||
| percentile, condition_fn, guide_scale) | |||
| return xt | |||
| def p_mean_variance(self, | |||
| xt, | |||
| t, | |||
| model, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None, | |||
| guide_scale=None): | |||
| r"""Distribution of p(x_{t-1} | x_t). | |||
| """ | |||
| # predict distribution | |||
| if guide_scale is None: | |||
| out = model(xt, self._scale_timesteps(t), **model_kwargs) | |||
| else: | |||
| # classifier-free guidance | |||
| # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) | |||
| assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 | |||
| assert self.mean_type == 'eps' | |||
| y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) | |||
| u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) | |||
| out = torch.cat( | |||
| [ | |||
| u_out[:, :3] + guide_scale * # noqa W504 | |||
| (y_out[:, :3] - u_out[:, :3]), | |||
| y_out[:, 3:] | |||
| ], | |||
| dim=1) | |||
| # compute variance | |||
| if self.var_type == 'learned': | |||
| out, log_var = out.chunk(2, dim=1) | |||
| var = torch.exp(log_var) | |||
| elif self.var_type == 'learned_range': | |||
| out, fraction = out.chunk(2, dim=1) | |||
| min_log_var = _i(self.posterior_log_variance_clipped, t, xt) | |||
| max_log_var = _i(torch.log(self.betas), t, xt) | |||
| fraction = (fraction + 1) / 2.0 | |||
| log_var = fraction * max_log_var + (1 - fraction) * min_log_var | |||
| var = torch.exp(log_var) | |||
| elif self.var_type == 'fixed_large': | |||
| var = _i( | |||
| torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, | |||
| xt) | |||
| log_var = torch.log(var) | |||
| elif self.var_type == 'fixed_small': | |||
| var = _i(self.posterior_variance, t, xt) | |||
| log_var = _i(self.posterior_log_variance_clipped, t, xt) | |||
| # compute mean and x0 | |||
| if self.mean_type == 'x_{t-1}': | |||
| mu = out # x_{t-1} | |||
| x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i( | |||
| self.posterior_mean_coef2 / self.posterior_mean_coef1, t, | |||
| xt) * xt | |||
| elif self.mean_type == 'x0': | |||
| x0 = out | |||
| mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) | |||
| elif self.mean_type == 'eps': | |||
| x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( | |||
| self.sqrt_recipm1_alphas_cumprod, t, xt) * out | |||
| mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) | |||
| # restrict the range of x0 | |||
| if percentile is not None: | |||
| assert percentile > 0 and percentile <= 1 # e.g., 0.995 | |||
| s = torch.quantile( | |||
| x0.flatten(1).abs(), percentile, | |||
| dim=1).clamp_(1.0).view(-1, 1, 1, 1) | |||
| x0 = torch.min(s, torch.max(-s, x0)) / s | |||
| elif clamp is not None: | |||
| x0 = x0.clamp(-clamp, clamp) | |||
| return mu, var, log_var, x0 | |||
| @torch.no_grad() | |||
| def ddim_sample(self, | |||
| xt, | |||
| t, | |||
| model, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None, | |||
| condition_fn=None, | |||
| guide_scale=None, | |||
| ddim_timesteps=20, | |||
| eta=0.0): | |||
| stride = self.num_timesteps // ddim_timesteps | |||
| # predict distribution of p(x_{t-1} | x_t) | |||
| _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, | |||
| percentile, guide_scale) | |||
| if condition_fn is not None: | |||
| # x0 -> eps | |||
| alpha = _i(self.alphas_cumprod, t, xt) | |||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( | |||
| self.sqrt_recipm1_alphas_cumprod, t, xt) | |||
| eps = eps - (1 - alpha).sqrt() * condition_fn( | |||
| xt, self._scale_timesteps(t), **model_kwargs) | |||
| # eps -> x0 | |||
| x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( | |||
| self.sqrt_recipm1_alphas_cumprod, t, xt) * eps | |||
| # derive variables | |||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( | |||
| self.sqrt_recipm1_alphas_cumprod, t, xt) | |||
| alphas = _i(self.alphas_cumprod, t, xt) | |||
| alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) | |||
| sigmas = eta * torch.sqrt((1 - alphas_prev) / # noqa W504 | |||
| (1 - alphas) * # noqa W504 | |||
| (1 - alphas / alphas_prev)) | |||
| # random sample | |||
| noise = torch.randn_like(xt) | |||
| direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps | |||
| mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) | |||
| xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise | |||
| return xt_1, x0 | |||
| @torch.no_grad() | |||
| def ddim_sample_loop(self, | |||
| noise, | |||
| model, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None, | |||
| condition_fn=None, | |||
| guide_scale=None, | |||
| ddim_timesteps=20, | |||
| eta=0.0): | |||
| # prepare input | |||
| b, c, h, w = noise.size() | |||
| xt = noise | |||
| # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) | |||
| steps = (1 + torch.arange(0, self.num_timesteps, | |||
| self.num_timesteps // ddim_timesteps)).clamp( | |||
| 0, self.num_timesteps - 1).flip(0) | |||
| for step in steps: | |||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||
| xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, | |||
| percentile, condition_fn, guide_scale, | |||
| ddim_timesteps, eta) | |||
| return xt | |||
| @torch.no_grad() | |||
| def ddim_reverse_sample(self, | |||
| xt, | |||
| t, | |||
| model, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None, | |||
| guide_scale=None, | |||
| ddim_timesteps=20): | |||
| r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). | |||
| """ | |||
| stride = self.num_timesteps // ddim_timesteps | |||
| # predict distribution of p(x_{t-1} | x_t) | |||
| _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, | |||
| percentile, guide_scale) | |||
| # derive variables | |||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( | |||
| self.sqrt_recipm1_alphas_cumprod, t, xt) | |||
| alphas_next = _i( | |||
| torch.cat( | |||
| [self.alphas_cumprod, | |||
| self.alphas_cumprod.new_zeros([1])]), | |||
| (t + stride).clamp(0, self.num_timesteps), xt) | |||
| # reverse sample | |||
| mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps | |||
| return mu, x0 | |||
| @torch.no_grad() | |||
| def ddim_reverse_sample_loop(self, | |||
| x0, | |||
| model, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None, | |||
| guide_scale=None, | |||
| ddim_timesteps=20): | |||
| # prepare input | |||
| b, c, h, w = x0.size() | |||
| xt = x0 | |||
| # reconstruction steps | |||
| steps = torch.arange(0, self.num_timesteps, | |||
| self.num_timesteps // ddim_timesteps) | |||
| for step in steps: | |||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||
| xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, | |||
| percentile, guide_scale, | |||
| ddim_timesteps) | |||
| return xt | |||
| @torch.no_grad() | |||
| def plms_sample(self, | |||
| xt, | |||
| t, | |||
| model, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None, | |||
| condition_fn=None, | |||
| guide_scale=None, | |||
| plms_timesteps=20): | |||
| r"""Sample from p(x_{t-1} | x_t) using PLMS. | |||
| - condition_fn: for classifier-based guidance (guided-diffusion). | |||
| - guide_scale: for classifier-free guidance (glide/dalle-2). | |||
| """ | |||
| stride = self.num_timesteps // plms_timesteps | |||
| # function for compute eps | |||
| def compute_eps(xt, t): | |||
| # predict distribution of p(x_{t-1} | x_t) | |||
| _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, | |||
| clamp, percentile, guide_scale) | |||
| # condition | |||
| if condition_fn is not None: | |||
| # x0 -> eps | |||
| alpha = _i(self.alphas_cumprod, t, xt) | |||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt | |||
| - x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt) | |||
| eps = eps - (1 - alpha).sqrt() * condition_fn( | |||
| xt, self._scale_timesteps(t), **model_kwargs) | |||
| # eps -> x0 | |||
| x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( | |||
| self.sqrt_recipm1_alphas_cumprod, t, xt) * eps | |||
| # derive eps | |||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( | |||
| self.sqrt_recipm1_alphas_cumprod, t, xt) | |||
| return eps | |||
| # function for compute x_0 and x_{t-1} | |||
| def compute_x0(eps, t): | |||
| # eps -> x0 | |||
| x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( | |||
| self.sqrt_recipm1_alphas_cumprod, t, xt) * eps | |||
| # deterministic sample | |||
| alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) | |||
| direction = torch.sqrt(1 - alphas_prev) * eps | |||
| # mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) | |||
| xt_1 = torch.sqrt(alphas_prev) * x0 + direction | |||
| return xt_1, x0 | |||
| # PLMS sample | |||
| eps = compute_eps(xt, t) | |||
| if len(eps_cache) == 0: | |||
| # 2nd order pseudo improved Euler | |||
| xt_1, x0 = compute_x0(eps, t) | |||
| eps_next = compute_eps(xt_1, (t - stride).clamp(0)) | |||
| eps_prime = (eps + eps_next) / 2.0 | |||
| elif len(eps_cache) == 1: | |||
| # 2nd order pseudo linear multistep (Adams-Bashforth) | |||
| eps_prime = (3 * eps - eps_cache[-1]) / 2.0 | |||
| elif len(eps_cache) == 2: | |||
| # 3nd order pseudo linear multistep (Adams-Bashforth) | |||
| eps_prime = (23 * eps - 16 * eps_cache[-1] | |||
| + 5 * eps_cache[-2]) / 12.0 | |||
| elif len(eps_cache) >= 3: | |||
| # 4nd order pseudo linear multistep (Adams-Bashforth) | |||
| eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] | |||
| - 9 * eps_cache[-3]) / 24.0 | |||
| xt_1, x0 = compute_x0(eps_prime, t) | |||
| return xt_1, x0, eps | |||
| @torch.no_grad() | |||
| def plms_sample_loop(self, | |||
| noise, | |||
| model, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None, | |||
| condition_fn=None, | |||
| guide_scale=None, | |||
| plms_timesteps=20): | |||
| # prepare input | |||
| b, c, h, w = noise.size() | |||
| xt = noise | |||
| # diffusion process | |||
| steps = (1 + torch.arange(0, self.num_timesteps, | |||
| self.num_timesteps // plms_timesteps)).clamp( | |||
| 0, self.num_timesteps - 1).flip(0) | |||
| eps_cache = [] | |||
| for step in steps: | |||
| # PLMS sampling step | |||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||
| xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, | |||
| percentile, condition_fn, | |||
| guide_scale, plms_timesteps, | |||
| eps_cache) | |||
| # update eps cache | |||
| eps_cache.append(eps) | |||
| if len(eps_cache) >= 4: | |||
| eps_cache.pop(0) | |||
| return xt | |||
| def loss(self, x0, t, model, model_kwargs={}, noise=None): | |||
| noise = torch.randn_like(x0) if noise is None else noise | |||
| xt = self.q_sample(x0, t, noise=noise) | |||
| # compute loss | |||
| if self.loss_type in ['kl', 'rescaled_kl']: | |||
| loss, _ = self.variational_lower_bound(x0, xt, t, model, | |||
| model_kwargs) | |||
| if self.loss_type == 'rescaled_kl': | |||
| loss = loss * self.num_timesteps | |||
| elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: | |||
| out = model(xt, self._scale_timesteps(t), **model_kwargs) | |||
| # VLB for variation | |||
| loss_vlb = 0.0 | |||
| if self.var_type in ['learned', 'learned_range']: | |||
| out, var = out.chunk(2, dim=1) | |||
| frozen = torch.cat([ | |||
| out.detach(), var | |||
| ], dim=1) # learn var without affecting the prediction of mean | |||
| loss_vlb, _ = self.variational_lower_bound( | |||
| x0, xt, t, model=lambda *args, **kwargs: frozen) | |||
| if self.loss_type.startswith('rescaled_'): | |||
| loss_vlb = loss_vlb * self.num_timesteps / 1000.0 | |||
| # MSE/L1 for x0/eps | |||
| target = { | |||
| 'eps': noise, | |||
| 'x0': x0, | |||
| 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] | |||
| }[self.mean_type] | |||
| loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 | |||
| ).abs().flatten(1).mean(dim=1) | |||
| # total loss | |||
| loss = loss + loss_vlb | |||
| return loss | |||
| def variational_lower_bound(self, | |||
| x0, | |||
| xt, | |||
| t, | |||
| model, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None): | |||
| # compute groundtruth and predicted distributions | |||
| mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) | |||
| mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, | |||
| clamp, percentile) | |||
| # compute KL loss | |||
| kl = kl_divergence(mu1, log_var1, mu2, log_var2) | |||
| kl = kl.flatten(1).mean(dim=1) / math.log(2.0) | |||
| # compute discretized NLL loss (for p(x0 | x1) only) | |||
| nll = -discretized_gaussian_log_likelihood( | |||
| x0, mean=mu2, log_scale=0.5 * log_var2) | |||
| nll = nll.flatten(1).mean(dim=1) / math.log(2.0) | |||
| # NLL for p(x0 | x1) and KL otherwise | |||
| vlb = torch.where(t == 0, nll, kl) | |||
| return vlb, x0 | |||
| @torch.no_grad() | |||
| def variational_lower_bound_loop(self, | |||
| x0, | |||
| model, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None): | |||
| r"""Compute the entire variational lower bound, measured in bits-per-dim. | |||
| """ | |||
| # prepare input and output | |||
| b, c, h, w = x0.size() | |||
| metrics = {'vlb': [], 'mse': [], 'x0_mse': []} | |||
| # loop | |||
| for step in torch.arange(self.num_timesteps).flip(0): | |||
| # compute VLB | |||
| t = torch.full((b, ), step, dtype=torch.long, device=x0.device) | |||
| noise = torch.randn_like(x0) | |||
| xt = self.q_sample(x0, t, noise) | |||
| vlb, pred_x0 = self.variational_lower_bound( | |||
| x0, xt, t, model, model_kwargs, clamp, percentile) | |||
| # predict eps from x0 | |||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( | |||
| self.sqrt_recipm1_alphas_cumprod, t, xt) | |||
| # collect metrics | |||
| metrics['vlb'].append(vlb) | |||
| metrics['x0_mse'].append( | |||
| (pred_x0 - x0).square().flatten(1).mean(dim=1)) | |||
| metrics['mse'].append( | |||
| (eps - noise).square().flatten(1).mean(dim=1)) | |||
| metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} | |||
| # compute the prior KL term for VLB, measured in bits-per-dim | |||
| mu, _, log_var = self.q_mean_variance(x0, t) | |||
| kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), | |||
| torch.zeros_like(log_var)) | |||
| kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) | |||
| # update metrics | |||
| metrics['prior_bits_per_dim'] = kl_prior | |||
| metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior | |||
| return metrics | |||
| def _scale_timesteps(self, t): | |||
| if self.rescale_timesteps: | |||
| return t.float() * 1000.0 / self.num_timesteps | |||
| return t | |||
| @@ -0,0 +1,35 @@ | |||
| import math | |||
| import torch | |||
| __all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood'] | |||
| def kl_divergence(mu1, logvar1, mu2, logvar2): | |||
| return 0.5 * ( | |||
| -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + # noqa W504 | |||
| ((mu1 - mu2)**2) * torch.exp(-logvar2)) | |||
| def standard_normal_cdf(x): | |||
| r"""A fast approximation of the cumulative distribution function of the standard normal. | |||
| """ | |||
| return 0.5 * (1.0 + torch.tanh( | |||
| math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) | |||
| def discretized_gaussian_log_likelihood(x0, mean, log_scale): | |||
| assert x0.shape == mean.shape == log_scale.shape | |||
| cx = x0 - mean | |||
| inv_stdv = torch.exp(-log_scale) | |||
| cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) | |||
| cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) | |||
| log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) | |||
| log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) | |||
| cdf_delta = cdf_plus - cdf_min | |||
| log_probs = torch.where( | |||
| x0 < -0.999, log_cdf_plus, | |||
| torch.where(x0 > 0.999, log_one_minus_cdf_min, | |||
| torch.log(cdf_delta.clamp(min=1e-12)))) | |||
| assert log_probs.shape == x0.shape | |||
| return log_probs | |||
| @@ -0,0 +1,126 @@ | |||
| import numpy as np | |||
| import scipy.linalg as linalg | |||
| import torch | |||
| __all__ = [ | |||
| 'get_fid_net', 'get_is_net', 'compute_fid', 'compute_prdc', 'compute_is' | |||
| ] | |||
| def get_fid_net(resize_input=True, normalize_input=True): | |||
| r"""InceptionV3 network for the evaluation of Fréchet Inception Distance (FID). | |||
| Args: | |||
| resize_input: whether or not to resize the input to (299, 299). | |||
| normalize_input: whether or not to normalize the input from range (0, 1) to range(-1, 1). | |||
| """ | |||
| from artist.models import InceptionV3 | |||
| return InceptionV3( | |||
| output_blocks=(3, ), | |||
| resize_input=resize_input, | |||
| normalize_input=normalize_input, | |||
| requires_grad=False, | |||
| use_fid_inception=True).eval().requires_grad_(False) | |||
| def get_is_net(resize_input=True, normalize_input=True): | |||
| r"""InceptionV3 network for the evaluation of Inception Score (IS). | |||
| Args: | |||
| resize_input: whether or not to resize the input to (299, 299). | |||
| normalize_input: whether or not to normalize the input from range (0, 1) to range(-1, 1). | |||
| """ | |||
| from artist.models import InceptionV3 | |||
| return InceptionV3( | |||
| output_blocks=(4, ), | |||
| resize_input=resize_input, | |||
| normalize_input=normalize_input, | |||
| requires_grad=False, | |||
| use_fid_inception=False).eval().requires_grad_(False) | |||
| @torch.no_grad() | |||
| def compute_fid(real_feats, fake_feats, eps=1e-6): | |||
| r"""Compute Fréchet Inception Distance (FID). | |||
| Args: | |||
| real_feats: [N, C]. | |||
| fake_feats: [N, C]. | |||
| """ | |||
| # check inputs | |||
| if isinstance(real_feats, torch.Tensor): | |||
| real_feats = real_feats.cpu().numpy().astype(np.float_) | |||
| if isinstance(fake_feats, torch.Tensor): | |||
| fake_feats = fake_feats.cpu().numpy().astype(np.float_) | |||
| # real statistics | |||
| mu1 = np.mean(real_feats, axis=0) | |||
| sigma1 = np.cov(real_feats, rowvar=False) | |||
| # fake statistics | |||
| mu2 = np.mean(fake_feats, axis=0) | |||
| sigma2 = np.cov(fake_feats, rowvar=False) | |||
| # compute covmean | |||
| covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) | |||
| if not np.isfinite(covmean).all(): | |||
| print( | |||
| f'FID calculation produces singular product; adding {eps} to diagonal of cov', | |||
| flush=True) | |||
| offset = np.eye(sigma1.shape[0]) * eps | |||
| covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) | |||
| # numerical error might give slight imaginary component | |||
| if np.iscomplexobj(covmean): | |||
| if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): | |||
| m = np.max(np.abs(covmean.imag)) | |||
| raise ValueError('Imaginary component {}'.format(m)) | |||
| covmean = covmean.real | |||
| # compute Fréchet distance | |||
| diff = mu1 - mu2 | |||
| fid = diff.dot(diff) + np.trace(sigma1) + np.trace( | |||
| sigma2) - 2 * np.trace(covmean) | |||
| return fid.item() | |||
| @torch.no_grad() | |||
| def compute_prdc(real_feats, fake_feats, knn=5): | |||
| r"""Compute precision, recall, density, and coverage given two manifolds. | |||
| Args: | |||
| real_feats: [N, C]. | |||
| fake_feats: [N, C]. | |||
| knn: the number of nearest neighbors to consider. | |||
| """ | |||
| # distances | |||
| real_kth = -(-torch.cdist(real_feats, real_feats)).topk( | |||
| k=knn, dim=1)[0][:, -1] | |||
| fake_kth = -(-torch.cdist(fake_feats, fake_feats)).topk( | |||
| k=knn, dim=1)[0][:, -1] | |||
| dists = torch.cdist(real_feats, fake_feats) | |||
| # metrics | |||
| precision = (dists < real_kth.unsqueeze(1)).any( | |||
| dim=0).float().mean().item() | |||
| recall = (dists < fake_kth.unsqueeze(0)).any(dim=1).float().mean().item() | |||
| density = (dists < real_kth.unsqueeze(1)).float().sum( | |||
| dim=0).mean().item() / knn | |||
| coverage = (dists.min(dim=1)[0] < real_kth).float().mean().item() | |||
| return precision, recall, density, coverage | |||
| @torch.no_grad() | |||
| def compute_is(logits, num_splits=10): | |||
| preds = logits.softmax(dim=1).cpu().numpy() | |||
| split_scores = [] | |||
| for k in range(num_splits): | |||
| part = preds[k * (len(logits) // num_splits):(k + 1) | |||
| * (len(logits) // num_splits), :] | |||
| py = np.mean(part, axis=0) | |||
| scores = [] | |||
| for i in range(part.shape[0]): | |||
| pyx = part[i, :] | |||
| scores.append(entropy(pyx, py)) | |||
| split_scores.append(np.exp(np.mean(scores))) | |||
| return np.mean(split_scores), np.std(split_scores) | |||
| @@ -0,0 +1,220 @@ | |||
| import colorsys | |||
| import random | |||
| __all__ = ['RandomColor', 'rand_color'] | |||
| COLORMAP = { | |||
| 'blue': { | |||
| 'hue_range': [179, 257], | |||
| 'lower_bounds': [[20, 100], [30, 86], [40, 80], [50, 74], [60, 60], | |||
| [70, 52], [80, 44], [90, 39], [100, 35]] | |||
| }, | |||
| 'green': { | |||
| 'hue_range': [63, 178], | |||
| 'lower_bounds': [[30, 100], [40, 90], [50, 85], [60, 81], [70, 74], | |||
| [80, 64], [90, 50], [100, 40]] | |||
| }, | |||
| 'monochrome': { | |||
| 'hue_range': [0, 0], | |||
| 'lower_bounds': [[0, 0], [100, 0]] | |||
| }, | |||
| 'orange': { | |||
| 'hue_range': [19, 46], | |||
| 'lower_bounds': [[20, 100], [30, 93], [40, 88], [50, 86], [60, 85], | |||
| [70, 70], [100, 70]] | |||
| }, | |||
| 'pink': { | |||
| 'hue_range': [283, 334], | |||
| 'lower_bounds': [[20, 100], [30, 90], [40, 86], [60, 84], [80, 80], | |||
| [90, 75], [100, 73]] | |||
| }, | |||
| 'purple': { | |||
| 'hue_range': [258, 282], | |||
| 'lower_bounds': [[20, 100], [30, 87], [40, 79], [50, 70], [60, 65], | |||
| [70, 59], [80, 52], [90, 45], [100, 42]] | |||
| }, | |||
| 'red': { | |||
| 'hue_range': [-26, 18], | |||
| 'lower_bounds': [[20, 100], [30, 92], [40, 89], [50, 85], [60, 78], | |||
| [70, 70], [80, 60], [90, 55], [100, 50]] | |||
| }, | |||
| 'yellow': { | |||
| 'hue_range': [47, 62], | |||
| 'lower_bounds': [[25, 100], [40, 94], [50, 89], [60, 86], [70, 84], | |||
| [80, 82], [90, 80], [100, 75]] | |||
| } | |||
| } | |||
| class RandomColor(object): | |||
| def __init__(self, seed=None): | |||
| self.colormap = COLORMAP | |||
| self.random = random.Random(seed) | |||
| for color_name, color_attrs in self.colormap.items(): | |||
| lower_bounds = color_attrs['lower_bounds'] | |||
| s_min = lower_bounds[0][0] | |||
| s_max = lower_bounds[len(lower_bounds) - 1][0] | |||
| b_min = lower_bounds[len(lower_bounds) - 1][1] | |||
| b_max = lower_bounds[0][1] | |||
| self.colormap[color_name]['saturation_range'] = [s_min, s_max] | |||
| self.colormap[color_name]['brightness_range'] = [b_min, b_max] | |||
| def generate(self, hue=None, luminosity=None, count=1, format_='hex'): | |||
| colors = [] | |||
| for _ in range(count): | |||
| # First we pick a hue (H) | |||
| H = self.pick_hue(hue) | |||
| # Then use H to determine saturation (S) | |||
| S = self.pick_saturation(H, hue, luminosity) | |||
| # Then use S and H to determine brightness (B). | |||
| B = self.pick_brightness(H, S, luminosity) | |||
| # Then we return the HSB color in the desired format | |||
| colors.append(self.set_format([H, S, B], format_)) | |||
| return colors | |||
| def pick_hue(self, hue): | |||
| hue_range = self.get_hue_range(hue) | |||
| hue = self.random_within(hue_range) | |||
| # Instead of storing red as two seperate ranges, | |||
| # we group them, using negative numbers | |||
| if (hue < 0): | |||
| hue += 360 | |||
| return hue | |||
| def pick_saturation(self, hue, hue_name, luminosity): | |||
| if luminosity == 'random': | |||
| return self.random_within([0, 100]) | |||
| if hue_name == 'monochrome': | |||
| return 0 | |||
| saturation_range = self.get_saturation_range(hue) | |||
| s_min = saturation_range[0] | |||
| s_max = saturation_range[1] | |||
| if luminosity == 'bright': | |||
| s_min = 55 | |||
| elif luminosity == 'dark': | |||
| s_min = s_max - 10 | |||
| elif luminosity == 'light': | |||
| s_max = 55 | |||
| return self.random_within([s_min, s_max]) | |||
| def pick_brightness(self, H, S, luminosity): | |||
| b_min = self.get_minimum_brightness(H, S) | |||
| b_max = 100 | |||
| if luminosity == 'dark': | |||
| b_max = b_min + 20 | |||
| elif luminosity == 'light': | |||
| b_min = (b_max + b_min) / 2 | |||
| elif luminosity == 'random': | |||
| b_min = 0 | |||
| b_max = 100 | |||
| return self.random_within([b_min, b_max]) | |||
| def set_format(self, hsv, format_): | |||
| if 'hsv' in format_: | |||
| color = hsv | |||
| elif 'rgb' in format_: | |||
| color = self.hsv_to_rgb(hsv) | |||
| elif 'hex' in format_: | |||
| r, g, b = self.hsv_to_rgb(hsv) | |||
| return '#%02x%02x%02x' % (r, g, b) | |||
| else: | |||
| return 'unrecognized format' | |||
| if 'Array' in format_ or format_ == 'hex': | |||
| return color | |||
| else: | |||
| prefix = format_[:3] | |||
| color_values = [str(x) for x in color] | |||
| return '%s(%s)' % (prefix, ', '.join(color_values)) | |||
| def get_minimum_brightness(self, H, S): | |||
| lower_bounds = self.get_color_info(H)['lower_bounds'] | |||
| for i in range(len(lower_bounds) - 1): | |||
| s1 = lower_bounds[i][0] | |||
| v1 = lower_bounds[i][1] | |||
| s2 = lower_bounds[i + 1][0] | |||
| v2 = lower_bounds[i + 1][1] | |||
| if s1 <= S <= s2: | |||
| m = (v2 - v1) / (s2 - s1) | |||
| b = v1 - m * s1 | |||
| return m * S + b | |||
| return 0 | |||
| def get_hue_range(self, color_input): | |||
| if color_input and color_input.isdigit(): | |||
| number = int(color_input) | |||
| if 0 < number < 360: | |||
| return [number, number] | |||
| elif color_input and color_input in self.colormap: | |||
| color = self.colormap[color_input] | |||
| if 'hue_range' in color: | |||
| return color['hue_range'] | |||
| else: | |||
| return [0, 360] | |||
| def get_saturation_range(self, hue): | |||
| return self.get_color_info(hue)['saturation_range'] | |||
| def get_color_info(self, hue): | |||
| # Maps red colors to make picking hue easier | |||
| if 334 <= hue <= 360: | |||
| hue -= 360 | |||
| for color_name, color in self.colormap.items(): | |||
| if color['hue_range'] and color['hue_range'][0] <= hue <= color[ | |||
| 'hue_range'][1]: | |||
| return self.colormap[color_name] | |||
| # this should probably raise an exception | |||
| return 'Color not found' | |||
| def random_within(self, r): | |||
| return self.random.randint(int(r[0]), int(r[1])) | |||
| @classmethod | |||
| def hsv_to_rgb(cls, hsv): | |||
| h, s, v = hsv | |||
| h = 1 if h == 0 else h | |||
| h = 359 if h == 360 else h | |||
| h = float(h) / 360 | |||
| s = float(s) / 100 | |||
| v = float(v) / 100 | |||
| rgb = colorsys.hsv_to_rgb(h, s, v) | |||
| return [int(c * 255) for c in rgb] | |||
| def rand_color(): | |||
| generator = RandomColor() | |||
| hue = random.choice(list(COLORMAP.keys())) | |||
| color = generator.generate(hue=hue, count=1, format_='rgb')[0] | |||
| color = color[color.find('(') + 1:color.find(')')] | |||
| color = tuple([int(u) for u in color.split(',')]) | |||
| return color | |||
| @@ -0,0 +1,79 @@ | |||
| import cv2 | |||
| import numpy as np | |||
| __all__ = ['make_irregular_mask', 'make_rectangle_mask', 'make_uncrop'] | |||
| def make_irregular_mask(w, | |||
| h, | |||
| max_angle=4, | |||
| max_length=200, | |||
| max_width=100, | |||
| min_strokes=1, | |||
| max_strokes=5, | |||
| mode='line'): | |||
| # initialize mask | |||
| assert mode in ['line', 'circle', 'square'] | |||
| mask = np.zeros((h, w), np.float32) | |||
| # draw strokes | |||
| num_strokes = np.random.randint(min_strokes, max_strokes + 1) | |||
| for i in range(num_strokes): | |||
| x1 = np.random.randint(w) | |||
| y1 = np.random.randint(h) | |||
| for j in range(1 + np.random.randint(5)): | |||
| angle = 0.01 + np.random.randint(max_angle) | |||
| if i % 2 == 0: | |||
| angle = 2 * 3.1415926 - angle | |||
| length = 10 + np.random.randint(max_length) | |||
| radius = 5 + np.random.randint(max_width) | |||
| x2 = np.clip((x1 + length * np.sin(angle)).astype(np.int32), 0, w) | |||
| y2 = np.clip((y1 + length * np.cos(angle)).astype(np.int32), 0, h) | |||
| if mode == 'line': | |||
| cv2.line(mask, (x1, y1), (x2, y2), 1.0, radius) | |||
| elif mode == 'circle': | |||
| cv2.circle( | |||
| mask, (x1, y1), radius=radius, color=1.0, thickness=-1) | |||
| elif mode == 'square': | |||
| radius = radius // 2 | |||
| mask[y1 - radius:y1 + radius, x1 - radius:x1 + radius] = 1 | |||
| x1, y1 = x2, y2 | |||
| return mask | |||
| def make_rectangle_mask(w, | |||
| h, | |||
| margin=10, | |||
| min_size=30, | |||
| max_size=150, | |||
| min_strokes=1, | |||
| max_strokes=4): | |||
| # initialize mask | |||
| mask = np.zeros((h, w), np.float32) | |||
| # draw rectangles | |||
| num_strokes = np.random.randint(min_strokes, max_strokes + 1) | |||
| for i in range(num_strokes): | |||
| box_w = np.random.randint(min_size, max_size) | |||
| box_h = np.random.randint(min_size, max_size) | |||
| x1 = np.random.randint(margin, w - margin - box_w + 1) | |||
| y1 = np.random.randint(margin, h - margin - box_h + 1) | |||
| mask[y1:y1 + box_h, x1:x1 + box_w] = 1 | |||
| return mask | |||
| def make_uncrop(w, h): | |||
| # initialize mask | |||
| mask = np.zeros((h, w), np.float32) | |||
| # randomly halve the image | |||
| side = np.random.choice([0, 1, 2, 3]) | |||
| if side == 0: | |||
| mask[:h // 2, :] = 1 | |||
| elif side == 1: | |||
| mask[h // 2:, :] = 1 | |||
| elif side == 2: | |||
| mask[:, :w // 2] = 1 | |||
| elif side == 2: | |||
| mask[:, w // 2:] = 1 | |||
| return mask | |||
| @@ -0,0 +1,152 @@ | |||
| r"""SVD of linear degradation matrices described in the paper | |||
| ``Denoising Diffusion Restoration Models.'' | |||
| @article{kawar2022denoising, | |||
| title={Denoising Diffusion Restoration Models}, | |||
| author={Bahjat Kawar and Michael Elad and Stefano Ermon and Jiaming Song}, | |||
| year={2022}, | |||
| journal={arXiv preprint arXiv:2201.11793}, | |||
| } | |||
| """ | |||
| import torch | |||
| __all__ = ['SVD', 'IdentitySVD', 'DenoiseSVD', 'ColorizationSVD'] | |||
| class SVD(object): | |||
| r"""SVD decomposition of a matrix, i.e., H = UDV^T. | |||
| NOTE: assume that all inputs (i.e., h, x) are of shape [B, CHW]. | |||
| """ | |||
| def __init__(self, h): | |||
| self.u, self.d, self.v = torch.svd(h, some=False) | |||
| self.ut = self.u.t() | |||
| self.vt = self.v.t() | |||
| self.d[self.d < 1e-3] = 0 | |||
| def U(self, x): | |||
| return torch.matmul(self.u, x) | |||
| def Ut(self, x): | |||
| return torch.matmul(self.ut, x) | |||
| def V(self, x): | |||
| return torch.matmul(self.v, x) | |||
| def Vt(self, x): | |||
| return torch.matmul(self.vt, x) | |||
| @property | |||
| def D(self): | |||
| return self.d | |||
| def H(self, x): | |||
| return self.U(self.D * self.Vt(x)[:, :self.D.size(0)]) | |||
| def Ht(self, x): | |||
| return self.V(self._pad(self.D * self.Ut(x)[:, :self.D.size(0)])) | |||
| def Hinv(self, x): | |||
| r"""Multiplies x by the pseudo inverse of H. | |||
| """ | |||
| x = self.Ut(x) | |||
| x[:, :self.D.size(0)] = x[:, :self.D.size(0)] / self.D | |||
| return self.V(self._pad(x)) | |||
| def _pad(self, x): | |||
| o = x.new_zeros(x.size(0), self.v.size(0)) | |||
| o[:, :self.u.size(0)] = x.view(x.size(0), -1) | |||
| return o | |||
| def to(self, *args, **kwargs): | |||
| r"""Update the data type and device of UDV matrices. | |||
| """ | |||
| for k, v in self.__dict__.items(): | |||
| if isinstance(v, torch.Tensor): | |||
| setattr(self, k, v.to(*args, **kwargs)) | |||
| return self | |||
| class IdentitySVD(SVD): | |||
| def __init__(self, c, h, w): | |||
| self.d = torch.ones(c * h * w) | |||
| def U(self, x): | |||
| return x.clone() | |||
| def Ut(self, x): | |||
| return x.clone() | |||
| def V(self, x): | |||
| return x.clone() | |||
| def Vt(self, x): | |||
| return x.clone() | |||
| def H(self, x): | |||
| return x.clone() | |||
| def Ht(self, x): | |||
| return x.clone() | |||
| def Hinv(self, x): | |||
| return x.clone() | |||
| def _pad(self, x): | |||
| return x.clone() | |||
| class DenoiseSVD(SVD): | |||
| def __init__(self, c, h, w): | |||
| self.num_entries = c * h * w | |||
| self.d = torch.ones(self.num_entries) | |||
| def U(self, x): | |||
| return x.clone() | |||
| def Ut(self, x): | |||
| return x.clone() | |||
| def V(self, x): | |||
| return x.clone() | |||
| def Vt(self, x): | |||
| return x.clone() | |||
| def _pad(self, x): | |||
| return x.clone() | |||
| class ColorizationSVD(SVD): | |||
| def __init__(self, c, h, w): | |||
| self.color_dim = c | |||
| self.num_pixels = h * w | |||
| self.u, self.d, self.v = torch.svd(torch.ones(1, c) / c, some=False) | |||
| self.vt = self.v.t() | |||
| def U(self, x): | |||
| return self.u[0, 0] * x | |||
| def Ut(self, x): | |||
| return self.u[0, 0] * x | |||
| def V(self, x): | |||
| return torch.einsum('ij,bjn->bin', self.v, | |||
| x.view(x.size(0), self.color_dim, | |||
| self.num_pixels)).flatten(1) | |||
| def Vt(self, x): | |||
| return torch.einsum('ij,bjn->bin', self.vt, | |||
| x.view(x.size(0), self.color_dim, | |||
| self.num_pixels)).flatten(1) | |||
| @property | |||
| def D(self): | |||
| return self.d.repeat(self.num_pixels) | |||
| def _pad(self, x): | |||
| o = x.new_zeros(x.size(0), self.color_dim * self.num_pixels) | |||
| o[:, :self.num_pixels] = x | |||
| return o | |||
| @@ -0,0 +1,224 @@ | |||
| import base64 | |||
| import binascii | |||
| import hashlib | |||
| import math | |||
| import os | |||
| import os.path as osp | |||
| import zipfile | |||
| from io import BytesIO | |||
| from multiprocessing.pool import ThreadPool as Pool | |||
| import cv2 | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from PIL import Image | |||
| from .random_color import rand_color | |||
| __all__ = [ | |||
| 'ceil_divide', 'to_device', 'rand_name', 'ema', 'parallel', 'unzip', | |||
| 'load_state_dict', 'inverse_indices', 'detect_duplicates', 'md5', 'rope', | |||
| 'format_state', 'breakup_grid', 'viz_anno_geometry', 'image_to_base64' | |||
| ] | |||
| TFS_CLIENT = None | |||
| def ceil_divide(a, b): | |||
| return int(math.ceil(a / b)) | |||
| def to_device(batch, device, non_blocking=False): | |||
| if isinstance(batch, (list, tuple)): | |||
| return type(batch)([to_device(u, device, non_blocking) for u in batch]) | |||
| elif isinstance(batch, dict): | |||
| return type(batch)([(k, to_device(v, device, non_blocking)) | |||
| for k, v in batch.items()]) | |||
| elif isinstance(batch, torch.Tensor): | |||
| return batch.to(device, non_blocking=non_blocking) | |||
| return batch | |||
| def rand_name(length=8, suffix=''): | |||
| name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') | |||
| if suffix: | |||
| if not suffix.startswith('.'): | |||
| suffix = '.' + suffix | |||
| name += suffix | |||
| return name | |||
| @torch.no_grad() | |||
| def ema(net_ema, net, beta, copy_buffer=False): | |||
| assert 0.0 <= beta <= 1.0 | |||
| for p_ema, p in zip(net_ema.parameters(), net.parameters()): | |||
| p_ema.copy_(p.lerp(p_ema, beta)) | |||
| if copy_buffer: | |||
| for b_ema, b in zip(net_ema.buffers(), net.buffers()): | |||
| b_ema.copy_(b) | |||
| def parallel(func, args_list, num_workers=32, timeout=None): | |||
| assert isinstance(args_list, list) | |||
| if not isinstance(args_list[0], tuple): | |||
| args_list = [(args, ) for args in args_list] | |||
| if num_workers == 0: | |||
| return [func(*args) for args in args_list] | |||
| with Pool(processes=num_workers) as pool: | |||
| results = [pool.apply_async(func, args) for args in args_list] | |||
| results = [res.get(timeout=timeout) for res in results] | |||
| return results | |||
| def unzip(filename, dst_dir=None): | |||
| if dst_dir is None: | |||
| dst_dir = osp.dirname(filename) | |||
| with zipfile.ZipFile(filename, 'r') as zip_ref: | |||
| zip_ref.extractall(dst_dir) | |||
| def load_state_dict(module, state_dict, drop_prefix=''): | |||
| # find incompatible key-vals | |||
| src, dst = state_dict, module.state_dict() | |||
| if drop_prefix: | |||
| src = type(src)([ | |||
| (k[len(drop_prefix):] if k.startswith(drop_prefix) else k, v) | |||
| for k, v in src.items() | |||
| ]) | |||
| missing = [k for k in dst if k not in src] | |||
| unexpected = [k for k in src if k not in dst] | |||
| unmatched = [ | |||
| k for k in src.keys() & dst.keys() if src[k].shape != dst[k].shape | |||
| ] | |||
| # keep only compatible key-vals | |||
| incompatible = set(unexpected + unmatched) | |||
| src = type(src)([(k, v) for k, v in src.items() if k not in incompatible]) | |||
| module.load_state_dict(src, strict=False) | |||
| # report incompatible key-vals | |||
| if len(missing) != 0: | |||
| print(' Missing: ' + ', '.join(missing), flush=True) | |||
| if len(unexpected) != 0: | |||
| print(' Unexpected: ' + ', '.join(unexpected), flush=True) | |||
| if len(unmatched) != 0: | |||
| print(' Shape unmatched: ' + ', '.join(unmatched), flush=True) | |||
| def inverse_indices(indices): | |||
| r"""Inverse map of indices. | |||
| E.g., if A[indices] == B, then B[inv_indices] == A. | |||
| """ | |||
| inv_indices = torch.empty_like(indices) | |||
| inv_indices[indices] = torch.arange(len(indices)).to(indices) | |||
| return inv_indices | |||
| def detect_duplicates(feats, thr=0.9): | |||
| assert feats.ndim == 2 | |||
| # compute simmat | |||
| feats = F.normalize(feats, p=2, dim=1) | |||
| simmat = torch.mm(feats, feats.T) | |||
| simmat.triu_(1) | |||
| torch.cuda.synchronize() | |||
| # detect duplicates | |||
| mask = ~simmat.gt(thr).any(dim=0) | |||
| return torch.where(mask)[0] | |||
| def md5(filename): | |||
| with open(filename, 'rb') as f: | |||
| return hashlib.md5(f.read()).hexdigest() | |||
| def rope(x): | |||
| r"""Apply rotary position embedding on x of shape [B, *(spatial dimensions), C]. | |||
| """ | |||
| # reshape | |||
| shape = x.shape | |||
| x = x.view(x.size(0), -1, x.size(-1)) | |||
| l, c = x.shape[-2:] | |||
| assert c % 2 == 0 | |||
| half = c // 2 | |||
| # apply rotary position embedding on x | |||
| sinusoid = torch.outer( | |||
| torch.arange(l).to(x), | |||
| torch.pow(10000, -torch.arange(half).to(x).div(half))) | |||
| sin, cos = torch.sin(sinusoid), torch.cos(sinusoid) | |||
| x1, x2 = x.chunk(2, dim=-1) | |||
| x = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) | |||
| # reshape back | |||
| return x.view(shape) | |||
| def format_state(state, filename=None): | |||
| r"""For comparing/aligning state_dict. | |||
| """ | |||
| content = '\n'.join([f'{k}\t{tuple(v.shape)}' for k, v in state.items()]) | |||
| if filename: | |||
| with open(filename, 'w') as f: | |||
| f.write(content) | |||
| def breakup_grid(img, grid_size): | |||
| r"""The inverse operator of ``torchvision.utils.make_grid``. | |||
| """ | |||
| # params | |||
| nrow = img.height // grid_size | |||
| ncol = img.width // grid_size | |||
| wrow = wcol = 2 # NOTE: use default values here | |||
| # collect grids | |||
| grids = [] | |||
| for i in range(nrow): | |||
| for j in range(ncol): | |||
| x1 = j * grid_size + (j + 1) * wcol | |||
| y1 = i * grid_size + (i + 1) * wrow | |||
| grids.append(img.crop((x1, y1, x1 + grid_size, y1 + grid_size))) | |||
| return grids | |||
| def viz_anno_geometry(item): | |||
| r"""Visualize an annotation item from SmartLabel. | |||
| """ | |||
| if isinstance(item, str): | |||
| item = json.loads(item) | |||
| assert isinstance(item, dict) | |||
| # read image | |||
| orig_img = read_image(item['image_url'], retry=100) | |||
| img = cv2.cvtColor(np.asarray(orig_img), cv2.COLOR_BGR2RGB) | |||
| # loop over geometries | |||
| for geometry in item['sd_result']['items']: | |||
| # params | |||
| poly_img = img.copy() | |||
| color = rand_color() | |||
| points = np.array(geometry['meta']['geometry']).round().astype(int) | |||
| line_color = tuple([int(u * 0.55) for u in color]) | |||
| # draw polygons | |||
| poly_img = cv2.fillPoly(poly_img, pts=[points], color=color) | |||
| poly_img = cv2.polylines( | |||
| poly_img, | |||
| pts=[points], | |||
| isClosed=True, | |||
| color=line_color, | |||
| thickness=2) | |||
| # mixing | |||
| img = np.clip(0.25 * img + 0.75 * poly_img, 0, 255).astype(np.uint8) | |||
| return orig_img, Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) | |||
| def image_to_base64(img, format='JPEG'): | |||
| buffer = BytesIO() | |||
| img.save(buffer, format=format) | |||
| code = base64.b64encode(buffer.getvalue()).decode('utf-8') | |||
| return code | |||
| @@ -15,6 +15,7 @@ if TYPE_CHECKING: | |||
| from .image_color_enhance_pipeline import ImageColorEnhancePipeline | |||
| from .image_colorization_pipeline import ImageColorizationPipeline | |||
| from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | |||
| from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | |||
| from .video_category_pipeline import VideoCategoryPipeline | |||
| from .image_matting_pipeline import ImageMattingPipeline | |||
| from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | |||
| @@ -0,0 +1,325 @@ | |||
| import io | |||
| import os.path as osp | |||
| import sys | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import torch | |||
| import torchvision.transforms as T | |||
| from PIL import Image | |||
| from torchvision.utils import save_image | |||
| import modelscope.models.cv.image_to_image_translation.data as data | |||
| import modelscope.models.cv.image_to_image_translation.models as models | |||
| import modelscope.models.cv.image_to_image_translation.ops as ops | |||
| from modelscope.fileio import File | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.image_to_image_translation.model_translation import \ | |||
| UNet | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| def save_grid(imgs, filename, nrow=5): | |||
| save_image( | |||
| imgs.clamp(-1, 1), filename, range=(-1, 1), normalize=True, nrow=nrow) | |||
| @PIPELINES.register_module( | |||
| Tasks.image_generation, module_name=Pipelines.image2image_translation) | |||
| class Image2ImageTranslationPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` to create a kws pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model) | |||
| config_path = osp.join(self.model, ModelFile.CONFIGURATION) | |||
| logger.info(f'loading config from {config_path}') | |||
| self.cfg = Config.from_file(config_path) | |||
| if torch.cuda.is_available(): | |||
| self._device = torch.device('cuda') | |||
| else: | |||
| self._device = torch.device('cpu') | |||
| self.repetition = 4 | |||
| # load autoencoder model | |||
| ae_model_path = osp.join(self.model, self.cfg.ModelPath.ae_model_path) | |||
| logger.info(f'loading autoencoder model from {ae_model_path}') | |||
| self.autoencoder = models.VQAutoencoder( | |||
| dim=self.cfg.Params.ae.ae_dim, | |||
| z_dim=self.cfg.Params.ae.ae_z_dim, | |||
| dim_mult=self.cfg.Params.ae.ae_dim_mult, | |||
| attn_scales=self.cfg.Params.ae.ae_attn_scales, | |||
| codebook_size=self.cfg.Params.ae.ae_codebook_size).eval( | |||
| ).requires_grad_(False).to(self._device) # noqa E123 | |||
| self.autoencoder.load_state_dict( | |||
| torch.load(ae_model_path, map_location=self._device)) | |||
| logger.info('load autoencoder model done') | |||
| # load palette model | |||
| palette_model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) | |||
| logger.info(f'loading palette model from {palette_model_path}') | |||
| self.palette = UNet( | |||
| resolution=self.cfg.Params.unet.unet_resolution, | |||
| in_dim=self.cfg.Params.unet.unet_in_dim, | |||
| dim=self.cfg.Params.unet.unet_dim, | |||
| context_dim=self.cfg.Params.unet.unet_context_dim, | |||
| out_dim=self.cfg.Params.unet.unet_out_dim, | |||
| dim_mult=self.cfg.Params.unet.unet_dim_mult, | |||
| num_heads=self.cfg.Params.unet.unet_num_heads, | |||
| head_dim=None, | |||
| num_res_blocks=self.cfg.Params.unet.unet_res_blocks, | |||
| attn_scales=self.cfg.Params.unet.unet_attn_scales, | |||
| num_classes=self.cfg.Params.unet.unet_num_classes + 1, | |||
| dropout=self.cfg.Params.unet.unet_dropout).eval().requires_grad_( | |||
| False).to(self._device) | |||
| self.palette.load_state_dict( | |||
| torch.load(palette_model_path, map_location=self._device)) | |||
| logger.info('load palette model done') | |||
| # diffusion | |||
| logger.info('Initialization diffusion ...') | |||
| betas = ops.beta_schedule(self.cfg.Params.diffusion.schedule, | |||
| self.cfg.Params.diffusion.num_timesteps) | |||
| self.diffusion = ops.GaussianDiffusion( | |||
| betas=betas, | |||
| mean_type=self.cfg.Params.diffusion.mean_type, | |||
| var_type=self.cfg.Params.diffusion.var_type, | |||
| loss_type=self.cfg.Params.diffusion.loss_type, | |||
| rescale_timesteps=False) | |||
| self.transforms = T.Compose([ | |||
| data.PadToSquare(), | |||
| T.Resize( | |||
| self.cfg.DATA.scale_size, | |||
| interpolation=T.InterpolationMode.BICUBIC), | |||
| T.ToTensor(), | |||
| T.Normalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std) | |||
| ]) | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| if len(input) == 3: # colorization | |||
| _, input_type, save_path = input | |||
| elif len(input) == 4: # uncropping or in-painting | |||
| _, meta, input_type, save_path = input | |||
| if input_type == 0: # uncropping | |||
| assert meta in ['up', 'down', 'left', 'right'] | |||
| direction = meta | |||
| list_ = [] | |||
| for i in range(len(input) - 2): | |||
| input_img = input[i] | |||
| if input_img in ['up', 'down', 'left', 'right']: | |||
| continue | |||
| if isinstance(input_img, str): | |||
| if input_type == 2 and i == 0: | |||
| logger.info('Loading image by origin way ... ') | |||
| bytes = File.read(input_img) | |||
| img = Image.open(io.BytesIO(bytes)) | |||
| assert len(img.split()) == 4 | |||
| else: | |||
| img = load_image(input_img) | |||
| elif isinstance(input_img, PIL.Image.Image): | |||
| img = input_img.convert('RGB') | |||
| elif isinstance(input_img, np.ndarray): | |||
| if len(input_img.shape) == 2: | |||
| input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR) | |||
| img = input_img[:, :, ::-1] | |||
| img = Image.fromarray(img.astype('uint8')).convert('RGB') | |||
| else: | |||
| raise TypeError(f'input should be either str, PIL.Image,' | |||
| f' np.array, but got {type(input)}') | |||
| list_.append(img) | |||
| img_list = [] | |||
| if input_type != 2: | |||
| for img in list_: | |||
| img = self.transforms(img) | |||
| imgs = torch.unsqueeze(img, 0) | |||
| imgs = imgs.to(self._device) | |||
| img_list.append(imgs) | |||
| elif input_type == 2: | |||
| mask, masked_img = list_[0], list_[1] | |||
| img = self.transforms(masked_img.convert('RGB')) | |||
| mask = torch.from_numpy( | |||
| np.array( | |||
| mask.resize((img.shape[2], img.shape[1])), | |||
| dtype=np.float32)[:, :, -1] / 255.0).unsqueeze(0) | |||
| img = (1 - mask) * img + mask * torch.randn_like(img).clamp_(-1, 1) | |||
| imgs = img.unsqueeze(0).to(self._device) | |||
| b, c, h, w = imgs.shape | |||
| y = torch.LongTensor([self.cfg.Classes.class_id]).to(self._device) | |||
| if input_type == 0: | |||
| assert len(img_list) == 1 | |||
| result = { | |||
| 'image_data': img_list[0], | |||
| 'c': c, | |||
| 'h': h, | |||
| 'w': w, | |||
| 'direction': direction, | |||
| 'type': input_type, | |||
| 'y': y, | |||
| 'save_path': save_path | |||
| } | |||
| elif input_type == 1: | |||
| assert len(img_list) == 1 | |||
| result = { | |||
| 'image_data': img_list[0], | |||
| 'c': c, | |||
| 'h': h, | |||
| 'w': w, | |||
| 'type': input_type, | |||
| 'y': y, | |||
| 'save_path': save_path | |||
| } | |||
| elif input_type == 2: | |||
| result = { | |||
| 'image_data': imgs, | |||
| # 'image_mask': mask, | |||
| 'c': c, | |||
| 'h': h, | |||
| 'w': w, | |||
| 'type': input_type, | |||
| 'y': y, | |||
| 'save_path': save_path | |||
| } | |||
| return result | |||
| @torch.no_grad() | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| type_ = input['type'] | |||
| if type_ == 0: | |||
| # Uncropping | |||
| img = input['image_data'] | |||
| direction = input['direction'] | |||
| y = input['y'] | |||
| # fix seed | |||
| torch.manual_seed(1 * 8888) | |||
| torch.cuda.manual_seed(1 * 8888) | |||
| logger.info(f'Processing {direction} uncropping') | |||
| img = img.clone() | |||
| i_y = y.repeat(self.repetition, 1) | |||
| if direction == 'up': | |||
| img[:, :, input['h'] // 2:, :] = torch.randn_like( | |||
| img[:, :, input['h'] // 2:, :]) | |||
| elif direction == 'down': | |||
| img[:, :, :input['h'] // 2, :] = torch.randn_like( | |||
| img[:, :, :input['h'] // 2, :]) | |||
| elif direction == 'left': | |||
| img[:, :, :, | |||
| input['w'] // 2:] = torch.randn_like(img[:, :, :, | |||
| input['w'] // 2:]) | |||
| elif direction == 'right': | |||
| img[:, :, :, :input['w'] // 2] = torch.randn_like( | |||
| img[:, :, :, :input['w'] // 2]) | |||
| i_concat = self.autoencoder.encode(img).repeat( | |||
| self.repetition, 1, 1, 1) | |||
| # sample images | |||
| x0 = self.diffusion.ddim_sample_loop( | |||
| noise=torch.randn_like(i_concat), | |||
| model=self.palette, | |||
| model_kwargs=[{ | |||
| 'y': i_y, | |||
| 'concat': i_concat | |||
| }, { | |||
| 'y': | |||
| torch.full_like(i_y, | |||
| self.cfg.Params.unet.unet_num_classes), | |||
| 'concat': | |||
| i_concat | |||
| }], | |||
| guide_scale=1.0, | |||
| clamp=None, | |||
| ddim_timesteps=50, | |||
| eta=1.0) | |||
| i_gen_imgs = self.autoencoder.decode(x0) | |||
| save_grid(i_gen_imgs, input['save_path'], nrow=4) | |||
| return {OutputKeys.OUTPUT_IMG: i_gen_imgs} | |||
| elif type_ == 1: | |||
| # Colorization # | |||
| img = input['image_data'] | |||
| y = input['y'] | |||
| # fix seed | |||
| torch.manual_seed(1 * 8888) | |||
| torch.cuda.manual_seed(1 * 8888) | |||
| logger.info('Processing Colorization') | |||
| img = img.clone() | |||
| img = img.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1) | |||
| i_concat = self.autoencoder.encode(img).repeat( | |||
| self.repetition, 1, 1, 1) | |||
| i_y = y.repeat(self.repetition, 1) | |||
| # sample images | |||
| x0 = self.diffusion.ddim_sample_loop( | |||
| noise=torch.randn_like(i_concat), | |||
| model=self.palette, | |||
| model_kwargs=[{ | |||
| 'y': i_y, | |||
| 'concat': i_concat | |||
| }, { | |||
| 'y': | |||
| torch.full_like(i_y, | |||
| self.cfg.Params.unet.unet_num_classes), | |||
| 'concat': | |||
| i_concat | |||
| }], | |||
| guide_scale=1.0, | |||
| clamp=None, | |||
| ddim_timesteps=50, | |||
| eta=0.0) | |||
| i_gen_imgs = self.autoencoder.decode(x0) | |||
| save_grid(i_gen_imgs, input['save_path'], nrow=4) | |||
| return {OutputKeys.OUTPUT_IMG: i_gen_imgs} | |||
| elif type_ == 2: | |||
| # Combination # | |||
| logger.info('Processing Combination') | |||
| # prepare inputs | |||
| img = input['image_data'] | |||
| concat = self.autoencoder.encode(img).repeat( | |||
| self.repetition, 1, 1, 1) | |||
| y = torch.LongTensor([126]).unsqueeze(0).to(self._device).repeat( | |||
| self.repetition, 1) | |||
| # sample images | |||
| x0 = self.diffusion.ddim_sample_loop( | |||
| noise=torch.randn_like(concat), | |||
| model=self.palette, | |||
| model_kwargs=[{ | |||
| 'y': y, | |||
| 'concat': concat | |||
| }, { | |||
| 'y': | |||
| torch.full_like(y, self.cfg.Params.unet.unet_num_classes), | |||
| 'concat': | |||
| concat | |||
| }], | |||
| guide_scale=1.0, | |||
| clamp=None, | |||
| ddim_timesteps=50, | |||
| eta=1.0) | |||
| i_gen_imgs = self.autoencoder.decode(x0) | |||
| save_grid(i_gen_imgs, input['save_path'], nrow=4) | |||
| return {OutputKeys.OUTPUT_IMG: i_gen_imgs} | |||
| else: | |||
| raise TypeError( | |||
| f'input type should be 0 (Uncropping), 1 (Colorization), 2 (Combation)' | |||
| f' but got {type_}') | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| import shutil | |||
| import unittest | |||
| from modelscope.fileio import File | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class Image2ImageTranslationTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| r"""We provide three translation modes, i.e., uncropping, colorization and combination. | |||
| You can pass the following parameters for different mode. | |||
| 1. Uncropping Mode: | |||
| result = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', 'left', 0, 'result.jpg')) | |||
| 2. Colorization Mode: | |||
| result = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', 1, 'result.jpg')) | |||
| 3. Combination Mode: | |||
| just like the following code. | |||
| """ | |||
| img2img_gen_pipeline = pipeline( | |||
| Tasks.image_generation, | |||
| model='damo/cv_latent_diffusion_image2image_translation') | |||
| result = img2img_gen_pipeline( | |||
| ('data/test/images/img2img_input_mask.png', | |||
| 'data/test/images/img2img_input_masked_img.png', 2, | |||
| 'result.jpg')) # combination mode | |||
| print(f'output: {result}.') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||