diff --git a/data/test/images/img2img_input.jpg b/data/test/images/img2img_input.jpg new file mode 100644 index 00000000..2da79e75 --- /dev/null +++ b/data/test/images/img2img_input.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e4cbf844cd16a892a7d2f2764b1537c346675d3b0145016d6836441ba907366 +size 9195 diff --git a/data/test/images/img2img_input_mask.png b/data/test/images/img2img_input_mask.png new file mode 100644 index 00000000..131fc37a --- /dev/null +++ b/data/test/images/img2img_input_mask.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33b3d3076e191fa92511bf69fa76e1222b3b3be0049e711c948a1218b587510c +size 4805 diff --git a/data/test/images/img2img_input_masked_img.png b/data/test/images/img2img_input_masked_img.png new file mode 100644 index 00000000..7f7c256b --- /dev/null +++ b/data/test/images/img2img_input_masked_img.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99c2b02a927b86ff194287ea4c5a05349dd800cff2b523212d1dad378c252feb +size 103334 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index b57b5734..3e31f422 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -77,6 +77,7 @@ class Pipelines(object): face_image_generation = 'gan-face-image-generation' style_transfer = 'AAMS-style-transfer' image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' + image2image_translation = 'image-to-image-translation' live_category = 'live-category' video_category = 'video-category' diff --git a/modelscope/models/cv/image_to_image_translation/data/__init__.py b/modelscope/models/cv/image_to_image_translation/data/__init__.py new file mode 100644 index 00000000..72450016 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/data/__init__.py @@ -0,0 +1 @@ +from .transforms import * # noqa F403 diff --git a/modelscope/models/cv/image_to_image_translation/data/transforms.py b/modelscope/models/cv/image_to_image_translation/data/transforms.py new file mode 100644 index 00000000..5376d813 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/data/transforms.py @@ -0,0 +1,121 @@ +import math +import random + +import torchvision.transforms.functional as TF +from PIL import Image, ImageFilter + +__all__ = [ + 'Identity', 'PadToSquare', 'RandomScale', 'RandomRotate', + 'RandomGaussianBlur', 'RandomCrop' +] + + +class Identity(object): + + def __call__(self, *args): + if len(args) == 0: + return None + elif len(args) == 1: + return args[0] + else: + return args + + +class PadToSquare(object): + + def __init__(self, fill=(255, 255, 255)): + self.fill = fill + + def __call__(self, img): + w, h = img.size + if w != h: + if w > h: + t = (w - h) // 2 + b = w - h - t + padding = (0, t, 0, b) + else: + left = (h - w) // 2 + right = h - w - left + padding = (left, 0, right, 0) + img = TF.pad(img, padding, fill=self.fill) + return img + + +class RandomScale(object): + + def __init__(self, + min_scale=0.5, + max_scale=2.0, + min_ratio=0.8, + max_ratio=1.25): + self.min_scale = min_scale + self.max_scale = max_scale + self.min_ratio = min_ratio + self.max_ratio = max_ratio + + def __call__(self, img): + w, h = img.size + scale = 2**random.uniform( + math.log2(self.min_scale), math.log2(self.max_scale)) + ratio = 2**random.uniform( + math.log2(self.min_ratio), math.log2(self.max_ratio)) + ow = int(w * scale * math.sqrt(ratio)) + oh = int(h * scale / math.sqrt(ratio)) + img = img.resize((ow, oh), Image.BILINEAR) + return img + + +class 
RandomRotate(object): + + def __init__(self, + min_angle=-10.0, + max_angle=10.0, + padding=(255, 255, 255), + p=0.5): + self.min_angle = min_angle + self.max_angle = max_angle + self.padding = padding + self.p = p + + def __call__(self, img): + if random.random() < self.p: + angle = random.uniform(self.min_angle, self.max_angle) + img = img.rotate(angle, Image.BILINEAR, fillcolor=self.padding) + return img + + +class RandomGaussianBlur(object): + + def __init__(self, radius=5, p=0.5): + self.radius = radius + self.p = p + + def __call__(self, img): + if random.random() < self.p: + img = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) + return img + + +class RandomCrop(object): + + def __init__(self, size, padding=(255, 255, 255)): + self.size = size + self.padding = padding + + def __call__(self, img): + # pad + w, h = img.size + pad_w = max(0, self.size - w) + pad_h = max(0, self.size - h) + if pad_w > 0 or pad_h > 0: + half_w = pad_w // 2 + half_h = pad_h // 2 + pad = (half_w, half_h, pad_w - half_w, pad_h - half_h) + img = TF.pad(img, pad, fill=self.padding) + + # crop + w, h = img.size + x1 = random.randint(0, w - self.size) + y1 = random.randint(0, h - self.size) + img = img.crop((x1, y1, x1 + self.size, y1 + self.size)) + return img diff --git a/modelscope/models/cv/image_to_image_translation/model_translation.py b/modelscope/models/cv/image_to_image_translation/model_translation.py new file mode 100644 index 00000000..722b175d --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/model_translation.py @@ -0,0 +1,323 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['UNet'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, scale_factor=1.0): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.scale_factor = scale_factor + + def forward(self, x): + if self.scale_factor == 2.0: + x = F.interpolate(x, scale_factor=2, mode='nearest') + elif self.scale_factor == 0.5: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, embed_dim, out_dim, dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.embedding = nn.Sequential(nn.SiLU(), + nn.Linear(embed_dim, out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, y): + identity = x + x = self.layer1(x) + x = x + self.embedding(y).unsqueeze(-1).unsqueeze(-1) + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class MultiHeadAttention(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=8, dropout=0.0): + assert dim % num_heads == 0 + assert 
context_dim is None or context_dim % num_heads == 0 + context_dim = context_dim or dim + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = math.pow(self.head_dim, -0.25) + + # layers + self.q = nn.Linear(dim, dim, bias=False) + self.k = nn.Linear(context_dim, dim, bias=False) + self.v = nn.Linear(context_dim, dim, bias=False) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None): + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # compute attention + attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + attn = self.dropout(attn) + + # gather context + x = torch.einsum('bnij,bjnc->binc', attn, v) + x = x.reshape(b, -1, n * c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class GLU(nn.Module): + + def __init__(self, in_dim, out_dim): + super(GLU, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.proj = nn.Linear(in_dim, out_dim * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class TransformerBlock(nn.Module): + + def __init__(self, dim, context_dim, num_heads, dropout=0.0): + super(TransformerBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + # input + self.norm1 = nn.GroupNorm(32, dim, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(dim, dim, 1) + + # self attention + self.norm2 = nn.LayerNorm(dim) + self.self_attn = MultiHeadAttention(dim, None, num_heads, dropout) + + # cross attention + self.norm3 = nn.LayerNorm(dim) + self.cross_attn = MultiHeadAttention(dim, context_dim, num_heads, + dropout) + + # ffn + self.norm4 = nn.LayerNorm(dim) + self.ffn = nn.Sequential( + GLU(dim, dim * 4), nn.Dropout(dropout), nn.Linear(dim * 4, dim)) + + # output + self.conv2 = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.conv2.weight) + + def forward(self, x, context): + b, c, h, w = x.size() + identity = x + + # input + x = self.norm1(x) + x = self.conv1(x).view(b, c, -1).transpose(1, 2) + + # attention + x = x + self.self_attn(self.norm2(x)) + x = x + self.cross_attn(self.norm3(x), context) + x = x + self.ffn(self.norm4(x)) + + # output + x = x.transpose(1, 2).view(b, c, h, w) + x = self.conv2(x) + return x + identity + + +class UNet(nn.Module): + + def __init__(self, + resolution=64, + in_dim=3, + dim=192, + context_dim=512, + out_dim=3, + dim_mult=[1, 2, 3, 5], + num_heads=1, + head_dim=None, + num_res_blocks=2, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + num_classes=1001, + dropout=0.0): + embed_dim = dim * 4 + super(UNet, self).__init__() + self.resolution = resolution + self.in_dim = in_dim + self.dim = dim + self.context_dim = context_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.num_classes = num_classes + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = 
nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.label_embedding = nn.Embedding(num_classes, context_dim) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList( + [ResidualBlock(in_dim, embed_dim, out_dim, dropout)]) + if scale in attn_scales: + block.append( + TransformerBlock(out_dim, context_dim, num_heads)) + in_dim = out_dim + self.encoder.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + self.encoder.append( + nn.Conv2d(out_dim, out_dim, 3, stride=2, padding=1)) + shortcut_dims.append(out_dim) + scale /= 2.0 + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, dropout), + TransformerBlock(out_dim, context_dim, num_heads), + ResidualBlock(out_dim, embed_dim, out_dim, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, dropout) + ]) + if scale in attn_scales: + block.append( + TransformerBlock(out_dim, context_dim, num_heads, + dropout)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + block.append( + nn.Sequential( + Resample(scale_factor=2.0), + nn.Conv2d(out_dim, out_dim, 3, padding=1))) + scale *= 2.0 + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y, concat=None): + # embeddings + if concat is not None: + x = torch.cat([x, concat], dim=1) + t = self.time_embedding(sinusoidal_embedding(t, self.dim)) + y = self.label_embedding(y) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, t, y) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, t, y) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, t, y) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, t, y): + if isinstance(module, ResidualBlock): + x = module(x, t) + elif isinstance(module, TransformerBlock): + x = module(x, y) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, t, y) + else: + x = module(x) + return x diff --git a/modelscope/models/cv/image_to_image_translation/models/__init__.py b/modelscope/models/cv/image_to_image_translation/models/__init__.py new file mode 100644 index 00000000..322d78f2 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/models/__init__.py @@ -0,0 +1,2 @@ +from .autoencoder import * # noqa F403 +from .clip import * # noqa F403 diff --git a/modelscope/models/cv/image_to_image_translation/models/autoencoder.py b/modelscope/models/cv/image_to_image_translation/models/autoencoder.py new file mode 100644 index 00000000..181472de --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/models/autoencoder.py @@ -0,0 +1,412 @@ +import math + +import torch +import torch.nn 
as nn +import torch.nn.functional as F + +__all__ = ['VQAutoencoder', 'KLAutoencoder', 'PatchDiscriminator'] + + +def group_norm(dim): + return nn.GroupNorm(32, dim, eps=1e-6, affine=True) + + +class Resample(nn.Module): + + def __init__(self, dim, scale_factor): + super(Resample, self).__init__() + self.dim = dim + self.scale_factor = scale_factor + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(dim, dim, 3, padding=1)) + elif scale_factor == 0.5: + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=2, padding=0)) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + group_norm(in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1), group_norm(out_dim), + nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Conv2d(in_dim, out_dim, + 1) if in_dim != out_dim else nn.Identity() + + # zero out the last layer params + nn.init.zeros_(self.residual[-1].weight) + + def forward(self, x): + return self.residual(x) + self.shortcut(x) + + +class AttentionBlock(nn.Module): + + def __init__(self, dim): + super(AttentionBlock, self).__init__() + self.dim = dim + self.scale = math.pow(dim, -0.25) + + # layers + self.norm = group_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, h, w = x.size() + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, c * 3, -1).chunk(3, dim=1) + + # compute attention + attn = torch.einsum('bci,bcj->bij', q * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.einsum('bij,bcj->bci', attn, v) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class Encoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(Encoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # params + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = nn.Conv2d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + downsamples.append(Resample(out_dim, scale_factor=0.5)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + group_norm(out_dim), nn.SiLU(), + nn.Conv2d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x): + x = self.conv1(x) + x = self.downsamples(x) + x = self.middle(x) + x = 
self.head(x) + return x + + +class Decoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(Decoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # params + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = nn.Conv2d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + upsamples.append(Resample(out_dim, scale_factor=2.0)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + group_norm(out_dim), nn.SiLU(), + nn.Conv2d(out_dim, 3, 3, padding=1)) + + def forward(self, x): + x = self.conv1(x) + x = self.middle(x) + x = self.upsamples(x) + x = self.head(x) + return x + + +class VectorQuantizer(nn.Module): + + def __init__(self, codebook_size=8192, z_dim=3, beta=0.25): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size + self.z_dim = z_dim + self.beta = beta + + # init codebook + eps = math.sqrt(1.0 / codebook_size) + self.codebook = nn.Parameter( + torch.empty(codebook_size, z_dim).uniform_(-eps, eps)) + + def forward(self, z): + # preprocess + b, c, h, w = z.size() + flatten = z.permute(0, 2, 3, 1).reshape(-1, c) + + # quantization + with torch.no_grad(): + tokens = torch.cdist(flatten, self.codebook).argmin(dim=1) + quantized = F.embedding(tokens, + self.codebook).view(b, h, w, + c).permute(0, 3, 1, 2) + + # compute loss + codebook_loss = F.mse_loss(quantized, z.detach()) + commitment_loss = F.mse_loss(quantized.detach(), z) + loss = codebook_loss + self.beta * commitment_loss + + # perplexity + counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype) + # dist.all_reduce(counts) + p = counts / counts.sum() + perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10))) + + # postprocess + tokens = tokens.view(b, h, w) + quantized = z + (quantized - z).detach() + return quantized, tokens, loss, perplexity + + +class VQAutoencoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0, + codebook_size=8192, + beta=0.25): + super(VQAutoencoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.codebook_size = codebook_size + self.beta = beta + + # blocks + self.encoder = Encoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + self.conv1 = nn.Conv2d(z_dim, z_dim, 1) + self.quantizer = VectorQuantizer(codebook_size, z_dim, beta) + self.conv2 = nn.Conv2d(z_dim, z_dim, 1) + self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + + def forward(self, x): + z = self.encoder(x) + z = self.conv1(z) + z, tokens, loss, perplexity = self.quantizer(z) + z = self.conv2(z) + x = self.decoder(z) + 
return x, tokens, loss, perplexity + + def encode(self, imgs): + z = self.encoder(imgs) + z = self.conv1(z) + return z + + def decode(self, z): + r"""Absorb the quantizer into the decoder. + """ + z = self.quantizer(z)[0] + z = self.conv2(z) + imgs = self.decoder(z) + return imgs + + @torch.no_grad() + def encode_to_tokens(self, imgs): + # preprocess + z = self.encoder(imgs) + z = self.conv1(z) + + # quantization + b, c, h, w = z.size() + flatten = z.permute(0, 2, 3, 1).reshape(-1, c) + tokens = torch.cdist(flatten, self.quantizer.codebook).argmin(dim=1) + return tokens.view(b, -1) + + @torch.no_grad() + def decode_from_tokens(self, tokens): + # dequantization + z = F.embedding(tokens, self.quantizer.codebook) + + # postprocess + b, l, c = z.size() + h = w = int(math.sqrt(l)) + z = z.view(b, h, w, c).permute(0, 3, 1, 2) + z = self.conv2(z) + imgs = self.decoder(z) + return imgs + + +class KLAutoencoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(KLAutoencoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # blocks + self.encoder = Encoder(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, dropout) + self.conv1 = nn.Conv2d(z_dim * 2, z_dim * 2, 1) + self.conv2 = nn.Conv2d(z_dim, z_dim, 1) + self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x = self.decode(z) + return x, mu, log_var + + def encode(self, x): + x = self.encoder(x) + mu, log_var = self.conv1(x).chunk(2, dim=1) + return mu, log_var + + def decode(self, z): + x = self.conv2(z) + x = self.decoder(x) + return x + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + +class PatchDiscriminator(nn.Module): + + def __init__(self, in_dim=3, dim=64, num_layers=3): + super(PatchDiscriminator, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.num_layers = num_layers + + # params + dims = [dim * min(8, 2**u) for u in range(num_layers + 1)] + + # layers + layers = [ + nn.Conv2d(in_dim, dim, 4, stride=2, padding=1), + nn.LeakyReLU(0.2) + ] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + stride = 1 if i == num_layers - 1 else 2 + layers += [ + nn.Conv2d( + in_dim, out_dim, 4, stride=stride, padding=1, bias=False), + nn.BatchNorm2d(out_dim), + nn.LeakyReLU(0.2) + ] + layers += [nn.Conv2d(out_dim, 1, 4, stride=1, padding=1)] + self.layers = nn.Sequential(*layers) + + # initialize weights + self.apply(self.init_weights) + + def forward(self, x): + return self.layers(x) + + def init_weights(self, m): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0.0, 0.02) + elif isinstance(m, nn.BatchNorm2d): + nn.init.normal_(m.weight, 1.0, 0.02) + nn.init.zeros_(m.bias) diff --git a/modelscope/models/cv/image_to_image_translation/models/clip.py b/modelscope/models/cv/image_to_image_translation/models/clip.py new file mode 100644 index 00000000..35d9d882 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/models/clip.py @@ -0,0 +1,418 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import modelscope.models.cv.image_to_image_translation.ops as ops # for using differentiable all_gather + +__all__ = [ + 'CLIP', 'clip_vit_b_32', 
'clip_vit_b_16', 'clip_vit_l_14', + 'clip_vit_l_14_336px', 'clip_vit_h_16' +] + + +def to_fp16(m): + if isinstance(m, (nn.Linear, nn.Conv2d)): + m.weight.data = m.weight.data.half() + if m.bias is not None: + m.bias.data = m.bias.data.half() + elif hasattr(m, 'head'): + p = getattr(m, 'head') + p.data = p.data.half() + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + r"""Subclass of nn.LayerNorm to handle fp16. + """ + + def forward(self, x): + return super(LayerNorm, self).forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + assert dim % num_heads == 0 + super(SelfAttention, self).__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.attn_dropout = nn.Dropout(attn_dropout) + self.proj = nn.Linear(dim, dim) + self.proj_dropout = nn.Dropout(proj_dropout) + + def forward(self, x, mask=None): + r"""x: [B, L, C]. + mask: [*, L, L]. + """ + b, l, _, n = *x.size(), self.num_heads + + # compute query, key, and value + q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1) + q = q.reshape(l, b * n, -1).transpose(0, 1) + k = k.reshape(l, b * n, -1).transpose(0, 1) + v = v.reshape(l, b * n, -1).transpose(0, 1) + + # compute attention + attn = self.scale * torch.bmm(q, k.transpose(1, 2)) + if mask is not None: + attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf')) + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + attn = self.attn_dropout(attn) + + # gather context + x = torch.bmm(attn, v) + x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1) + + # output + x = self.proj(x) + x = self.proj_dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + super(AttentionBlock, self).__init__() + self.dim = dim + self.num_heads = num_heads + + # layers + self.norm1 = LayerNorm(dim) + self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout) + self.norm2 = LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim), + nn.Dropout(proj_dropout)) + + def forward(self, x, mask=None): + x = x + self.attn(self.norm1(x), mask) + x = x + self.mlp(self.norm2(x)) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + out_dim=512, + num_heads=12, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + assert image_size % patch_size == 0 + super(VisionTransformer, self).__init__() + self.image_size = image_size + self.patch_size = patch_size + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.num_patches = (image_size // patch_size)**2 + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter( + gain * torch.randn(1, self.num_patches + 1, dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim) + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim) + + # 
head + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + def forward(self, x): + b, dtype = x.size(0), self.head.dtype + x = x.type(dtype) + + # patch-embedding + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c] + x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x], + dim=1) + x = self.dropout(x + self.pos_embedding.type(dtype)) + x = self.pre_norm(x) + + # transformer + x = self.transformer(x) + + # head + x = self.post_norm(x) + x = torch.mm(x[:, 0, :], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class TextTransformer(nn.Module): + + def __init__(self, + vocab_size, + text_len, + dim=512, + out_dim=512, + num_heads=8, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(TextTransformer, self).__init__() + self.vocab_size = vocab_size + self.text_len = text_len + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim) + self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.transformer = nn.ModuleList([ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.norm = LayerNorm(dim) + + # head + gain = 1.0 / math.sqrt(dim) + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + # causal attention mask + self.register_buffer('attn_mask', + torch.tril(torch.ones(1, text_len, text_len))) + + def forward(self, x): + eot, dtype = x.argmax(dim=-1), self.head.dtype + + # embeddings + x = self.dropout( + self.token_embedding(x).type(dtype) + + self.pos_embedding.type(dtype)) + + # transformer + for block in self.transformer: + x = block(x, self.attn_mask) + + # head + x = self.norm(x) + x = torch.mm(x[torch.arange(x.size(0)), eot], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class CLIP(nn.Module): + + def __init__(self, + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_heads=12, + vision_layers=12, + vocab_size=49408, + text_len=77, + text_dim=512, + text_heads=8, + text_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(CLIP, self).__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vocab_size = vocab_size + self.text_len = text_len + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.textual = TextTransformer( + vocab_size=vocab_size, + text_len=text_len, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_tokens): + r"""imgs: [B, C, H, W] of torch.float32. + txt_tokens: [B, T] of torch.long. 
+ """ + xi = self.visual(imgs) + xt = self.textual(txt_tokens) + + # normalize features + xi = F.normalize(xi, p=2, dim=1) + xt = F.normalize(xt, p=2, dim=1) + + # gather features from all ranks + full_xi = ops.diff_all_gather(xi) + full_xt = ops.diff_all_gather(xt) + + # logits + scale = self.log_scale.exp() + logits_i2t = scale * torch.mm(xi, full_xt.t()) + logits_t2i = scale * torch.mm(xt, full_xi.t()) + + # labels + labels = torch.arange( + len(xi) * ops.get_rank(), + len(xi) * (ops.get_rank() + 1), + dtype=torch.long, + device=xi.device) + return logits_i2t, logits_t2i, labels + + def init_weights(self): + # embeddings + nn.init.normal_(self.textual.token_embedding.weight, std=0.02) + nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1) + + # attentions + for modality in ['visual', 'textual']: + dim = self.vision_dim if modality == 'visual' else 'textual' + transformer = getattr(self, modality).transformer + proj_gain = (1.0 / math.sqrt(dim)) * ( + 1.0 / math.sqrt(2 * transformer.num_layers)) + attn_gain = 1.0 / math.sqrt(dim) + mlp_gain = 1.0 / math.sqrt(2.0 * dim) + for block in transformer.layers: + nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) + nn.init.normal_(block.attn.proj.weight, std=proj_gain) + nn.init.normal_(block.mlp[0].weight, std=mlp_gain) + nn.init.normal_(block.mlp[2].weight, std=proj_gain) + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': + 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + def fp16(self): + return self.apply(to_fp16) + + +def clip_vit_b_32(**kwargs): + return CLIP( + embed_dim=512, + image_size=224, + patch_size=32, + vision_dim=768, + vision_heads=12, + vision_layers=12, + text_dim=512, + text_heads=8, + text_layers=12, + **kwargs) + + +def clip_vit_b_16(**kwargs): + return CLIP( + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_heads=12, + vision_layers=12, + text_dim=512, + text_heads=8, + text_layers=12, + **kwargs) + + +def clip_vit_l_14(**kwargs): + return CLIP( + embed_dim=768, + image_size=224, + patch_size=14, + vision_dim=1024, + vision_heads=16, + vision_layers=24, + text_dim=768, + text_heads=12, + text_layers=12, + **kwargs) + + +def clip_vit_l_14_336px(**kwargs): + return CLIP( + embed_dim=768, + image_size=336, + patch_size=14, + vision_dim=1024, + vision_heads=16, + vision_layers=24, + text_dim=768, + text_heads=12, + text_layers=12, + **kwargs) + + +def clip_vit_h_16(**kwargs): + return CLIP( + embed_dim=1024, + image_size=256, + patch_size=16, + vision_dim=1280, + vision_heads=16, + vision_layers=32, + text_dim=1024, + text_heads=16, + text_layers=24, + **kwargs) diff --git a/modelscope/models/cv/image_to_image_translation/ops/__init__.py b/modelscope/models/cv/image_to_image_translation/ops/__init__.py new file mode 100644 index 00000000..59082d72 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/__init__.py @@ -0,0 +1,8 @@ +from .degradation import * # noqa F403 +from .diffusion import * # noqa F403 +from .losses import * # noqa F403 +from .metrics import * # noqa F403 +from .random_color import * # noqa F403 +from .random_mask import * # noqa F403 +from .svd import * # noqa F403 +from .utils import * # noqa F403 diff --git a/modelscope/models/cv/image_to_image_translation/ops/apps.py b/modelscope/models/cv/image_to_image_translation/ops/apps.py new file mode 100644 
index 00000000..ee4be489 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/apps.py @@ -0,0 +1,663 @@ +# APPs that facilitate the use of pretrained neural networks. + +import os.path as osp + +import artist.data as data +import artist.models as models +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn.functional as F +import torchvision.transforms as T +from artist import DOWNLOAD_TO_CACHE +from PIL import Image +from torch.utils.data import DataLoader, Dataset + +from .utils import parallel, read_image + +__all__ = [ + 'FeatureExtractor', 'Classifier', 'Text2Image', 'Sole2Shoe', 'ImageParser', + 'TextImageMatch', 'taobao_feature_extractor', 'singleton_classifier', + 'orientation_classifier', 'fashion_text2image', 'mindalle_text2image', + 'sole2shoe', 'sole_parser', 'sod_foreground_parser', + 'fashion_text_image_match' +] + + +class ImageFolder(Dataset): + + def __init__(self, paths, transforms=None): + self.paths = paths + self.transforms = transforms + + def __getitem__(self, index): + img = read_image(self.paths[index]) + if img.mode != 'RGB': + img = img.convert('RGB') + if self.transforms is not None: + img = self.transforms(img) + return img + + def __len__(self): + return len(self.paths) + + +class FeatureExtractor(object): + + def __init__( + self, + model='InceptionV1', + checkpoint='models/inception-v1/1218shoes.v9_7.140.0.1520000', + resolution=224, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + batch_size=64, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.resolution = resolution + self.batch_size = batch_size + self.device = device + + # init model + self.net = getattr( + models, + model)(num_classes=None).eval().requires_grad_(False).to(device) + self.net.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) + + # data transforms + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize(resolution), + T.ToTensor(), + T.Normalize(mean, std) + ]) + + def __call__(self, imgs, num_workers=0): + r"""imgs: Either a PIL.Image or a list of PIL.Image instances. 
+ """ + # preprocess + if isinstance(imgs, Image.Image): + imgs = [imgs] + assert isinstance(imgs, + (tuple, list)) and isinstance(imgs[0], Image.Image) + imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) + + # forward + feats = [] + for batch in imgs.split(self.batch_size, dim=0): + batch = batch.to(self.device, non_blocking=True) + feats.append(self.net(batch)) + return torch.cat(feats, dim=0) + + def batch_process(self, paths): + # init dataloader + dataloader = DataLoader( + dataset=ImageFolder(paths, self.transforms), + batch_size=self.batch_size, + shuffle=False, + drop_last=False, + pin_memory=True, + num_workers=8, + prefetch_factor=2) + + # forward + feats = [] + for step, batch in enumerate(dataloader, 1): + print(f'Step: {step}/{len(dataloader)}', flush=True) + batch = batch.to(self.device, non_blocking=True) + feats.append(self.net(batch)) + return torch.cat(feats) + + +class Classifier(object): + + def __init__( + self, + model='InceptionV1', + checkpoint='models/classifier/shoes+apparel+bag-sgdetect-211230.pth', + num_classes=1, + resolution=224, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + batch_size=64, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.num_classes = num_classes + self.resolution = resolution + self.batch_size = batch_size + self.device = device + + # init model + self.net = getattr(models, model)( + num_classes=num_classes).eval().requires_grad_(False).to(device) + self.net.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) + + # data transforms + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize(resolution), + T.ToTensor(), + T.Normalize(mean, std) + ]) + + def __call__(self, imgs, num_workers=0): + r"""imgs: Either a PIL.Image or a list of PIL.Image instances. 
+ """ + # preprocess + if isinstance(imgs, Image.Image): + imgs = [imgs] + assert isinstance(imgs, + (tuple, list)) and isinstance(imgs[0], Image.Image) + imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) + + # forward + scores = [] + for batch in imgs.split(self.batch_size, dim=0): + batch = batch.to(self.device, non_blocking=True) + logits = self.net(batch) + scores.append(logits.sigmoid() if self.num_classes == # noqa W504 + 1 else logits.softmax(dim=1)) + return torch.cat(scores, dim=0) + + +class Text2Image(object): + + def __init__( + self, + vqgan_dim=128, + vqgan_z_dim=256, + vqgan_dim_mult=[1, 1, 2, 2, 4], + vqgan_num_res_blocks=2, + vqgan_attn_scales=[1.0 / 16], + vqgan_codebook_size=975, + vqgan_beta=0.25, + gpt_txt_vocab_size=21128, + gpt_txt_seq_len=64, + gpt_img_seq_len=1024, + gpt_dim=1024, + gpt_num_heads=16, + gpt_num_layers=24, + vqgan_checkpoint='models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', + gpt_checkpoint='models/seq2seq_gpt/text2image_shoes+apparels_step400k.pth', + tokenizer=data.BertTokenizer(name='bert-base-chinese', length=64), + batch_size=16, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.tokenizer = tokenizer + self.batch_size = batch_size + self.device = device + + # init VQGAN model + self.vqgan = models.VQGAN( + dim=vqgan_dim, + z_dim=vqgan_z_dim, + dim_mult=vqgan_dim_mult, + num_res_blocks=vqgan_num_res_blocks, + attn_scales=vqgan_attn_scales, + codebook_size=vqgan_codebook_size, + beta=vqgan_beta).eval().requires_grad_(False).to(device) + self.vqgan.load_state_dict( + torch.load( + DOWNLOAD_TO_CACHE(vqgan_checkpoint), map_location=device)) + + # init GPT model + self.gpt = models.Seq2SeqGPT( + src_vocab_size=gpt_txt_vocab_size, + tar_vocab_size=vqgan_codebook_size, + src_seq_len=gpt_txt_seq_len, + tar_seq_len=gpt_img_seq_len, + dim=gpt_dim, + num_heads=gpt_num_heads, + num_layers=gpt_num_layers).eval().requires_grad_(False).to(device) + self.gpt.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(gpt_checkpoint), map_location=device)) + + def __call__(self, + txts, + top_k=64, + top_p=None, + temperature=0.6, + use_fp16=True): + # preprocess + if isinstance(txts, str): + txts = [txts] + assert isinstance(txts, (tuple, list)) and isinstance(txts[0], str) + txt_tokens = torch.LongTensor([self.tokenizer(u) for u in txts]) + + # forward + out_imgs = [] + for batch in txt_tokens.split(self.batch_size, dim=0): + # sample + batch = batch.to(self.device, non_blocking=True) + with amp.autocast(enabled=use_fp16): + img_tokens = self.gpt.sample(batch, top_k, top_p, temperature) + + # decode + imgs = self.vqgan.decode_from_tokens(img_tokens) + imgs = self._whiten_borders(imgs) + imgs = imgs.clamp_(-1, 1).add_(1).mul_(125.0).permute( + 0, 2, 3, 1).cpu().numpy().astype(np.uint8) + imgs = [Image.fromarray(u) for u in imgs] + + # append + out_imgs += imgs + return out_imgs + + def _whiten_borders(self, imgs): + r"""Remove border artifacts. 
+ """ + imgs[:, :, :18, :] = 1 + imgs[:, :, :, :18] = 1 + imgs[:, :, -18:, :] = 1 + imgs[:, :, :, -18:] = 1 + return imgs + + +class Sole2Shoe(object): + + def __init__( + self, + vqgan_dim=128, + vqgan_z_dim=256, + vqgan_dim_mult=[1, 1, 2, 2, 4], + vqgan_num_res_blocks=2, + vqgan_attn_scales=[1.0 / 16], + vqgan_codebook_size=975, + vqgan_beta=0.25, + src_resolution=256, + tar_resolution=512, + gpt_dim=1024, + gpt_num_heads=16, + gpt_num_layers=24, + vqgan_checkpoint='models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', + gpt_checkpoint='models/seq2seq_gpt/sole2shoe-step300k-220104.pth', + batch_size=12, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.batch_size = batch_size + self.device = device + src_seq_len = (src_resolution // 16)**2 + tar_seq_len = (tar_resolution // 16)**2 + + # init VQGAN model + self.vqgan = models.VQGAN( + dim=vqgan_dim, + z_dim=vqgan_z_dim, + dim_mult=vqgan_dim_mult, + num_res_blocks=vqgan_num_res_blocks, + attn_scales=vqgan_attn_scales, + codebook_size=vqgan_codebook_size, + beta=vqgan_beta).eval().requires_grad_(False).to(device) + self.vqgan.load_state_dict( + torch.load( + DOWNLOAD_TO_CACHE(vqgan_checkpoint), map_location=device)) + + # init GPT model + self.gpt = models.Seq2SeqGPT( + src_vocab_size=vqgan_codebook_size, + tar_vocab_size=vqgan_codebook_size, + src_seq_len=src_seq_len, + tar_seq_len=tar_seq_len, + dim=gpt_dim, + num_heads=gpt_num_heads, + num_layers=gpt_num_layers).eval().requires_grad_(False).to(device) + self.gpt.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(gpt_checkpoint), map_location=device)) + + # data transforms + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize(src_resolution), + T.ToTensor(), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + ]) + + def __call__(self, + sole_imgs, + top_k=64, + top_p=None, + temperature=0.6, + use_fp16=True, + num_workers=0): + # preprocess + if isinstance(sole_imgs, Image.Image): + sole_imgs = [sole_imgs] + assert isinstance(sole_imgs, (tuple, list)) and isinstance( + sole_imgs[0], Image.Image) + sole_imgs = torch.stack( + parallel(self.transforms, sole_imgs, num_workers), dim=0) + + # forward + out_imgs = [] + for batch in sole_imgs.split(self.batch_size, dim=0): + # sample + batch = batch.to(self.device) + with amp.autocast(enabled=use_fp16): + sole_tokens = self.vqgan.encode_to_tokens(batch) + shoe_tokens = self.gpt.sample(sole_tokens, top_k, top_p, + temperature) + + # decode + shoe_imgs = self.vqgan.decode_from_tokens(shoe_tokens) + shoe_imgs = self._whiten_borders(shoe_imgs) + shoe_imgs = shoe_imgs.clamp_(-1, 1).add_(1).mul_(125.0).permute( + 0, 2, 3, 1).cpu().numpy().astype(np.uint8) + shoe_imgs = [Image.fromarray(u) for u in shoe_imgs] + + # append + out_imgs += shoe_imgs + return out_imgs + + def _whiten_borders(self, imgs): + r"""Remove border artifacts. 
+ """ + imgs[:, :, :18, :] = 1 + imgs[:, :, :, :18] = 1 + imgs[:, :, -18:, :] = 1 + imgs[:, :, :, -18:] = 1 + return imgs + + +class ImageParser(object): + + def __init__( + self, + model='SPNet', + num_classes=2, + resolution=800, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + model_with_softmax=False, + checkpoint='models/spnet/sole_segmentation_211219.pth', + batch_size=16, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.batch_size = batch_size + self.device = device + + # init model + if checkpoint.endswith('.pt'): + self.net = torch.jit.load( + DOWNLOAD_TO_CACHE(checkpoint)).eval().to(device) + [p.requires_grad_(False) for p in self.net.parameters()] + else: + self.net = getattr(models, model)( + num_classes=num_classes, + pretrained=False).eval().requires_grad_(False).to(device) + self.net.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) + self.softmax = (lambda x, dim: x) if model_with_softmax else F.softmax + + # data transforms + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize(resolution), + T.ToTensor(), + T.Normalize(mean, std) + ]) + + def __call__(self, imgs, num_workers=0): + # preprocess + if isinstance(imgs, Image.Image): + imgs = [imgs] + assert isinstance(imgs, + (tuple, list)) and isinstance(imgs[0], Image.Image) + sizes = [u.size for u in imgs] + imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) + + # forward + masks = [] + for batch in imgs.split(self.batch_size, dim=0): + batch = batch.to(self.device, non_blocking=True) + masks.append(self.softmax(self.net(batch), dim=1)) + + # postprocess + masks = torch.cat(masks, dim=0).unsqueeze(1) + masks = [ + F.interpolate(u, v, mode='bilinear', align_corners=False) + for u, v in zip(masks, sizes) + ] + return masks + + +class TextImageMatch(object): + + def __init__( + self, + embed_dim=512, + image_size=224, + patch_size=32, + vision_dim=768, + vision_heads=12, + vision_layers=12, + vocab_size=21128, + text_len=77, + text_dim=512, + text_heads=8, + text_layers=12, + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + checkpoint='models/clip/clip_shoes+apparels_step84k_210105.pth', + tokenizer=data.BertTokenizer(name='bert-base-chinese', length=77), + batch_size=64, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.tokenizer = tokenizer + self.batch_size = batch_size + self.device = device + + # init model + self.clip = models.CLIP( + embed_dim=embed_dim, + image_size=image_size, + patch_size=patch_size, + vision_dim=vision_dim, + vision_heads=vision_heads, + vision_layers=vision_layers, + vocab_size=vocab_size, + text_len=text_len, + text_dim=text_dim, + text_heads=text_heads, + text_layers=text_layers).eval().requires_grad_(False).to(device) + self.clip.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) + + # transforms + scale_size = int(image_size * 8 / 7) + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize(scale_size), + T.CenterCrop(image_size), + T.ToTensor(), + T.Normalize(mean, std) + ]) + + def __call__(self, imgs, txts, num_workers=0): + # preprocess + assert isinstance(imgs, + (tuple, list)) and isinstance(imgs[0], Image.Image) + assert isinstance(txts, (tuple, list)) and isinstance(txts[0], str) + txt_tokens = torch.LongTensor([self.tokenizer(u) for u in txts]) + imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) + + # forward + scores = [] + 
for img_batch, txt_batch in zip( + imgs.split(self.batch_size, dim=0), + txt_tokens.split(self.batch_size, dim=0)): + img_batch = img_batch.to(self.device) + txt_batch = txt_batch.to(self.device) + xi = F.normalize(self.clip.visual(img_batch), p=2, dim=1) + xt = F.normalize(self.clip.textual(txt_batch), p=2, dim=1) + scores.append((xi * xt).sum(dim=1)) + return torch.cat(scores, dim=0) + + +def taobao_feature_extractor(category='shoes', **kwargs): + r"""Pretrained taobao-search feature extractors. + """ + assert category in ['softall', 'shoes', 'bag'] + checkpoint = osp.join( + 'models/inception-v1', { + 'softall': '1214softall_10.10.0.5000', + 'shoes': '1218shoes.v9_7.140.0.1520000', + 'bag': '0926bag.v9_6.29.0.140000' + }[category]) + app = FeatureExtractor( + model='InceptionV1', + checkpoint=checkpoint, + resolution=224, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + **kwargs) + return app + + +def singleton_classifier(**kwargs): + r"""Pretrained classifier that finds single-object images. + Supports shoes, apparel, and bag images. + """ + app = Classifier( + model='InceptionV1', + checkpoint='models/classifier/shoes+apparel+bag-sgdetect-211230.pth', + num_classes=1, + resolution=224, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + **kwargs) + return app + + +def orientation_classifier(**kwargs): + r"""Shoes orientation classifier. + """ + app = Classifier( + model='InceptionV1', + checkpoint='models/classifier/shoes-oriendetect-20211026.pth', + num_classes=1, + resolution=224, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + **kwargs) + return app + + +def fashion_text2image(**kwargs): + r"""Fashion text-to-image generator. + Supports shoe and apparel image generation. + """ + app = Text2Image( + vqgan_dim=128, + vqgan_z_dim=256, + vqgan_dim_mult=[1, 1, 2, 2, 4], + vqgan_num_res_blocks=2, + vqgan_attn_scales=[1.0 / 16], + vqgan_codebook_size=975, + vqgan_beta=0.25, + gpt_txt_vocab_size=21128, + gpt_txt_seq_len=64, + gpt_img_seq_len=1024, + gpt_dim=1024, + gpt_num_heads=16, + gpt_num_layers=24, + vqgan_checkpoint= # noqa E251 + 'models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', + gpt_checkpoint= # noqa E251 + 'models/seq2seq_gpt/text2image_shoes+apparels_step400k.pth', + tokenizer=data.BertTokenizer(name='bert-base-chinese', length=64), + **kwargs) + return app + + +def mindalle_text2image(**kwargs): + r"""Pretrained text2image generator with weights copied from minDALL-E. 
+ """ + app = Text2Image( + vqgan_dim=128, + vqgan_z_dim=256, + vqgan_dim_mult=[1, 1, 2, 2, 4], + vqgan_num_res_blocks=2, + vqgan_attn_scales=[1.0 / 16], + vqgan_codebook_size=16384, + vqgan_beta=0.25, + gpt_txt_vocab_size=16384, + gpt_txt_seq_len=64, + gpt_img_seq_len=256, + gpt_dim=1536, + gpt_num_heads=24, + gpt_num_layers=42, + vqgan_checkpoint='models/minDALLE/1.3B_vqgan.pth', + gpt_checkpoint='models/minDALLE/1.3B_gpt.pth', + tokenizer=data.BPETokenizer(length=64), + **kwargs) + return app + + +def sole2shoe(**kwargs): + app = Sole2Shoe( + vqgan_dim=128, + vqgan_z_dim=256, + vqgan_dim_mult=[1, 1, 2, 2, 4], + vqgan_num_res_blocks=2, + vqgan_attn_scales=[1.0 / 16], + vqgan_codebook_size=975, + vqgan_beta=0.25, + src_resolution=256, + tar_resolution=512, + gpt_dim=1024, + gpt_num_heads=16, + gpt_num_layers=24, + vqgan_checkpoint= # noqa E251 + 'models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', + gpt_checkpoint='models/seq2seq_gpt/sole2shoe-step300k-220104.pth', + **kwargs) + return app + + +def sole_parser(**kwargs): + app = ImageParser( + model='SPNet', + num_classes=2, + resolution=800, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + model_with_softmax=False, + checkpoint='models/spnet/sole_segmentation_211219.pth', + **kwargs) + return app + + +def sod_foreground_parser(**kwargs): + app = ImageParser( + model=None, + num_classes=None, + resolution=448, + mean=[0.488431, 0.466275, 0.403686], + std=[0.222627, 0.21949, 0.22549], + model_with_softmax=True, + checkpoint='models/semseg/sod_model_20201228.pt', + **kwargs) + return app + + +def fashion_text_image_match(**kwargs): + app = TextImageMatch( + embed_dim=512, + image_size=224, + patch_size=32, + vision_dim=768, + vision_heads=12, + vision_layers=12, + vocab_size=21128, + text_len=77, + text_dim=512, + text_heads=8, + text_layers=12, + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + checkpoint='models/clip/clip_shoes+apparels_step84k_210105.pth', + tokenizer=data.BertTokenizer(name='bert-base-chinese', length=77), + **kwargs) + return app diff --git a/modelscope/models/cv/image_to_image_translation/ops/degradation.py b/modelscope/models/cv/image_to_image_translation/ops/degradation.py new file mode 100644 index 00000000..c3b3d1df --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/degradation.py @@ -0,0 +1,1074 @@ +import math +import os +import random + +import cv2 +import numpy as np +import scipy +import scipy.stats as stats +import torch +from scipy import ndimage +from scipy.interpolate import interp2d +from scipy.linalg import orth + +os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' + +__all__ = ['degradation_bsrgan_light', 'degradation_bsrgan'] + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + return np.float32(img / 255.) 
+ + +def single2uint(img): + return np.uint8((img.clip(0, 1) * 255.).round()) + + +def uint162single(img): + return np.float32(img / 65535.) + + +def single2uint16(img): + return np.uint16((img.clip(0, 1) * 65535.).round()) + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, + [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [ + -222.921, 135.576, -276.836 + ] # noqa E126 + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, + [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. 
+ return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + # img1 = img1.squeeze() + # img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h - border, border:w - border] + img2 = img2[border:h - border, border:w - border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + # img1 = img1.squeeze() + # img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h - border, border:w - border] + img2 = img2[border:h - border, border:w - border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:, :, i], img2[:, :, i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * # noqa W504 + (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * # noqa W504 + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + + 2) * (((absx > 1) * # noqa W504 + (absx <= 2)).type_as(absx)) + + +def 
calculate_weights_indices(in_length, out_length, scale, kernel, + kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace( + 0, P - 1, P).view(1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W + * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. 
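+    # The resize is separable: 1-D cubic weights/indices are computed once for
+    # the H axis and once for the W axis, the tensor is symmetrically padded
+    # at the borders, and each axis is filtered in turn below.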
+ + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose( + 0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, + idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W + * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. 
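+    # Same separable bicubic resize as imresize above, but operating on an
+    # HWC tensor converted from numpy; the H pass fills out_1 and the W pass
+    # fills out_2.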
+ + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, + j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, + j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] 
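+
+
+# Illustrative use of modcrop_np (hypothetical shapes): cropping to a multiple
+# of the scale factor keeps LR/HR pairs aligned, e.g.
+#   >>> img = np.random.rand(257, 254, 3)
+#   >>> modcrop_np(img, sf=4).shape
+#   (256, 252, 3)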
+ + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot( + np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = stats.multivariate_normal.pdf([cx, cy], + mean=mean, + cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d( + x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel( + k_size=np.array([15, 15]), + scale_factor=np.array([4, 4]), + min_var=0.6, + max_var=10., + noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and 
Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1 + ) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid( + np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve( + x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + 
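+
+# Note on ordering: srmd_degradation applies blur -> bicubic downsampling,
+# dpsr_degradation applies bicubic downsampling -> blur, and
+# classical_degradation below applies blur -> direct s-fold subsampling.
+# A minimal sketch (kernel and image names are illustrative):
+#   k = fspecial('gaussian', 15, 2.0)      # 15x15 Gaussian blur kernel
+#   lr = srmd_degradation(hr, k, sf=4)     # hr: HxWxC float image in [0, 1]
+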
+ +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur_1(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2 / 4 + wd = wd / 4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian( + ksize=random.randint(2, 11) + 3, + theta=random.random() * np.pi, + l1=l1, + l2=l2) + else: + k = fspecial('gaussian', + random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.filters.convolve( + img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize( + img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype( + np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, + (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs( + L**2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, + img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, + (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. 
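+        # build a random positive semi-definite covariance U^T D U so the
+        # multiplicative noise is correlated across the three color channels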
+ D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal( + [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10**(2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype( + np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode( + '.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, + rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan_light(image, sf=4, isp_model=None): + """ + This is the variant of the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = uint2single(image) + _, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + # sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + # hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[ + idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur_1(image, sf=sf) + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize( + image, (int(1 / sf1 * image.shape[1]), + int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum( + ) # blur with shifted kernel + image = ndimage.filters.convolve( + image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] 
# nearest downsampling + image = np.clip(image, 0.0, 1.0) + elif i == 3: + # downsample3 + image = cv2.resize( + image, (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = single2uint(image) + return image + + +def add_blur_2(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian( + ksize=2 * random.randint(2, 11) + 3, + theta=random.random() * np.pi, + l1=l1, + l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, + wd * random.random()) + img = ndimage.filters.convolve( + img, np.expand_dims(k, axis=2), mode='mirror') + return img + + +def degradation_bsrgan(image, sf=4, isp_model=None): + """ + This is the variant of the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = uint2single(image) + _, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + # sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + # hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[ + idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur_2(image, sf=sf) + elif i == 1: + image = add_blur_2(image, sf=sf) + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize( + image, (int(1 / sf1 * image.shape[1]), + int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum( + ) # blur with shifted kernel + image = ndimage.filters.convolve( + image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] 
# nearest downsampling + image = np.clip(image, 0.0, 1.0) + elif i == 3: + # downsample3 + image = cv2.resize( + image, (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = single2uint(image) + return image diff --git a/modelscope/models/cv/image_to_image_translation/ops/diffusion.py b/modelscope/models/cv/image_to_image_translation/ops/diffusion.py new file mode 100644 index 00000000..bcbb6402 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/diffusion.py @@ -0,0 +1,598 @@ +import math + +import torch + +from .losses import discretized_gaussian_log_likelihood, kl_divergence + +__all__ = ['GaussianDiffusion', 'beta_schedule'] + + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + if schedule == 'linear': + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + last_beta = last_beta or scale * 0.02 + return torch.linspace( + init_beta, last_beta, num_timesteps, dtype=torch.float64) + elif schedule == 'quadratic': + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'cosine': + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + + # fn = lambda u: math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + def fn(u): + return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + + betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + else: + raise ValueError(f'Unsupported schedule: {schedule}') + + +class GaussianDiffusion(object): + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) 
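+
+        # the buffers above give the closed form
+        #   q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
+        # which is what q_sample() and q_mean_variance() evaluate below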
+ + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + noise = torch.randn_like(x0) if noise is None else noise + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i( + self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + # no noise when t == 0 + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). 
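+        Returns (mu, var, log_var, x0), i.e. the mean/variance of
+        p(x_{t-1} | x_t) together with the model's estimate of x_0.
+        When guide_scale is not None, model_kwargs must be a 2-element list
+        [conditional kwargs, unconditional kwargs] for classifier-free guidance.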
+ """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + assert self.mean_type == 'eps' + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + out = torch.cat( + [ + u_out[:, :3] + guide_scale * # noqa W504 + (y_out[:, :3] - u_out[:, :3]), + y_out[:, 3:] + ], + dim=1) + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i( + torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, + xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / # noqa W504 + (1 - alphas) * # noqa W504 + (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + 
direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas_next = _i( + torch.cat( + [self.alphas_cumprod, + self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + # prepare input + b, c, h, w = x0.size() + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, + percentile, guide_scale, + ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
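+        The update order grows with the number of cached eps predictions:
+        a pseudo improved Euler step first, then 2nd/3rd/4th-order
+        Adams-Bashforth multistep updates.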
+ """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + - x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + # mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] + - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // plms_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, + guide_scale, plms_timesteps, + eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None): + noise = torch.randn_like(x0) if noise is None else noise + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, + model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([ + out.detach(), var + ], dim=1) # learn var without affecting the 
prediction of mean + loss_vlb, _ = self.variational_lower_bound( + x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] + }[self.mean_type] + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 + ).abs().flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, + x0, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood( + x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. + """ + # prepare input and output + b, c, h, w = x0.size() + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + noise = torch.randn_like(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound( + x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append( + (pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append( + (eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), + torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/cv/image_to_image_translation/ops/losses.py b/modelscope/models/cv/image_to_image_translation/ops/losses.py new file mode 100644 index 00000000..23e8d246 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/losses.py @@ -0,0 +1,35 @@ +import math + +import torch + +__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood'] + + +def kl_divergence(mu1, logvar1, mu2, logvar2): + return 0.5 * ( + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + # noqa W504 + ((mu1 - mu2)**2) * torch.exp(-logvar2)) + + +def standard_normal_cdf(x): + r"""A fast approximation of the cumulative distribution function of the standard normal. 
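+    Uses the tanh approximation 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).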
+ """ + return 0.5 * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x0, mean, log_scale): + assert x0.shape == mean.shape == log_scale.shape + cx = x0 - mean + inv_stdv = torch.exp(-log_scale) + cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) + cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x0 < -0.999, log_cdf_plus, + torch.where(x0 > 0.999, log_one_minus_cdf_min, + torch.log(cdf_delta.clamp(min=1e-12)))) + assert log_probs.shape == x0.shape + return log_probs diff --git a/modelscope/models/cv/image_to_image_translation/ops/metrics.py b/modelscope/models/cv/image_to_image_translation/ops/metrics.py new file mode 100644 index 00000000..4a63c51f --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/metrics.py @@ -0,0 +1,126 @@ +import numpy as np +import scipy.linalg as linalg +import torch + +__all__ = [ + 'get_fid_net', 'get_is_net', 'compute_fid', 'compute_prdc', 'compute_is' +] + + +def get_fid_net(resize_input=True, normalize_input=True): + r"""InceptionV3 network for the evaluation of Fréchet Inception Distance (FID). + + Args: + resize_input: whether or not to resize the input to (299, 299). + normalize_input: whether or not to normalize the input from range (0, 1) to range(-1, 1). + """ + from artist.models import InceptionV3 + return InceptionV3( + output_blocks=(3, ), + resize_input=resize_input, + normalize_input=normalize_input, + requires_grad=False, + use_fid_inception=True).eval().requires_grad_(False) + + +def get_is_net(resize_input=True, normalize_input=True): + r"""InceptionV3 network for the evaluation of Inception Score (IS). + + Args: + resize_input: whether or not to resize the input to (299, 299). + normalize_input: whether or not to normalize the input from range (0, 1) to range(-1, 1). + """ + from artist.models import InceptionV3 + return InceptionV3( + output_blocks=(4, ), + resize_input=resize_input, + normalize_input=normalize_input, + requires_grad=False, + use_fid_inception=False).eval().requires_grad_(False) + + +@torch.no_grad() +def compute_fid(real_feats, fake_feats, eps=1e-6): + r"""Compute Fréchet Inception Distance (FID). + + Args: + real_feats: [N, C]. + fake_feats: [N, C]. 
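+
+    Computes ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 (S_r S_f)^(1/2)) from the
+    Gaussian statistics (mean, covariance) of the two feature sets.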
+ """ + # check inputs + if isinstance(real_feats, torch.Tensor): + real_feats = real_feats.cpu().numpy().astype(np.float_) + if isinstance(fake_feats, torch.Tensor): + fake_feats = fake_feats.cpu().numpy().astype(np.float_) + + # real statistics + mu1 = np.mean(real_feats, axis=0) + sigma1 = np.cov(real_feats, rowvar=False) + + # fake statistics + mu2 = np.mean(fake_feats, axis=0) + sigma2 = np.cov(fake_feats, rowvar=False) + + # compute covmean + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + print( + f'FID calculation produces singular product; adding {eps} to diagonal of cov', + flush=True) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + # compute Fréchet distance + diff = mu1 - mu2 + fid = diff.dot(diff) + np.trace(sigma1) + np.trace( + sigma2) - 2 * np.trace(covmean) + return fid.item() + + +@torch.no_grad() +def compute_prdc(real_feats, fake_feats, knn=5): + r"""Compute precision, recall, density, and coverage given two manifolds. + + Args: + real_feats: [N, C]. + fake_feats: [N, C]. + knn: the number of nearest neighbors to consider. + """ + # distances + real_kth = -(-torch.cdist(real_feats, real_feats)).topk( + k=knn, dim=1)[0][:, -1] + fake_kth = -(-torch.cdist(fake_feats, fake_feats)).topk( + k=knn, dim=1)[0][:, -1] + dists = torch.cdist(real_feats, fake_feats) + + # metrics + precision = (dists < real_kth.unsqueeze(1)).any( + dim=0).float().mean().item() + recall = (dists < fake_kth.unsqueeze(0)).any(dim=1).float().mean().item() + density = (dists < real_kth.unsqueeze(1)).float().sum( + dim=0).mean().item() / knn + coverage = (dists.min(dim=1)[0] < real_kth).float().mean().item() + return precision, recall, density, coverage + + +@torch.no_grad() +def compute_is(logits, num_splits=10): + preds = logits.softmax(dim=1).cpu().numpy() + split_scores = [] + for k in range(num_splits): + part = preds[k * (len(logits) // num_splits):(k + 1) + * (len(logits) // num_splits), :] + py = np.mean(part, axis=0) + scores = [] + for i in range(part.shape[0]): + pyx = part[i, :] + scores.append(entropy(pyx, py)) + split_scores.append(np.exp(np.mean(scores))) + return np.mean(split_scores), np.std(split_scores) diff --git a/modelscope/models/cv/image_to_image_translation/ops/random_color.py b/modelscope/models/cv/image_to_image_translation/ops/random_color.py new file mode 100644 index 00000000..97e2f848 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/random_color.py @@ -0,0 +1,220 @@ +import colorsys +import random + +__all__ = ['RandomColor', 'rand_color'] + +COLORMAP = { + 'blue': { + 'hue_range': [179, 257], + 'lower_bounds': [[20, 100], [30, 86], [40, 80], [50, 74], [60, 60], + [70, 52], [80, 44], [90, 39], [100, 35]] + }, + 'green': { + 'hue_range': [63, 178], + 'lower_bounds': [[30, 100], [40, 90], [50, 85], [60, 81], [70, 74], + [80, 64], [90, 50], [100, 40]] + }, + 'monochrome': { + 'hue_range': [0, 0], + 'lower_bounds': [[0, 0], [100, 0]] + }, + 'orange': { + 'hue_range': [19, 46], + 'lower_bounds': [[20, 100], [30, 93], [40, 88], [50, 86], [60, 85], + [70, 70], [100, 70]] + }, + 'pink': { + 'hue_range': [283, 334], + 'lower_bounds': [[20, 100], [30, 90], [40, 86], [60, 84], [80, 80], + 
[90, 75], [100, 73]] + }, + 'purple': { + 'hue_range': [258, 282], + 'lower_bounds': [[20, 100], [30, 87], [40, 79], [50, 70], [60, 65], + [70, 59], [80, 52], [90, 45], [100, 42]] + }, + 'red': { + 'hue_range': [-26, 18], + 'lower_bounds': [[20, 100], [30, 92], [40, 89], [50, 85], [60, 78], + [70, 70], [80, 60], [90, 55], [100, 50]] + }, + 'yellow': { + 'hue_range': [47, 62], + 'lower_bounds': [[25, 100], [40, 94], [50, 89], [60, 86], [70, 84], + [80, 82], [90, 80], [100, 75]] + } +} + + +class RandomColor(object): + + def __init__(self, seed=None): + self.colormap = COLORMAP + self.random = random.Random(seed) + + for color_name, color_attrs in self.colormap.items(): + lower_bounds = color_attrs['lower_bounds'] + s_min = lower_bounds[0][0] + s_max = lower_bounds[len(lower_bounds) - 1][0] + + b_min = lower_bounds[len(lower_bounds) - 1][1] + b_max = lower_bounds[0][1] + + self.colormap[color_name]['saturation_range'] = [s_min, s_max] + self.colormap[color_name]['brightness_range'] = [b_min, b_max] + + def generate(self, hue=None, luminosity=None, count=1, format_='hex'): + colors = [] + for _ in range(count): + # First we pick a hue (H) + H = self.pick_hue(hue) + + # Then use H to determine saturation (S) + S = self.pick_saturation(H, hue, luminosity) + + # Then use S and H to determine brightness (B). + B = self.pick_brightness(H, S, luminosity) + + # Then we return the HSB color in the desired format + colors.append(self.set_format([H, S, B], format_)) + + return colors + + def pick_hue(self, hue): + hue_range = self.get_hue_range(hue) + hue = self.random_within(hue_range) + + # Instead of storing red as two seperate ranges, + # we group them, using negative numbers + if (hue < 0): + hue += 360 + + return hue + + def pick_saturation(self, hue, hue_name, luminosity): + + if luminosity == 'random': + return self.random_within([0, 100]) + + if hue_name == 'monochrome': + return 0 + + saturation_range = self.get_saturation_range(hue) + + s_min = saturation_range[0] + s_max = saturation_range[1] + + if luminosity == 'bright': + s_min = 55 + elif luminosity == 'dark': + s_min = s_max - 10 + elif luminosity == 'light': + s_max = 55 + + return self.random_within([s_min, s_max]) + + def pick_brightness(self, H, S, luminosity): + b_min = self.get_minimum_brightness(H, S) + b_max = 100 + + if luminosity == 'dark': + b_max = b_min + 20 + elif luminosity == 'light': + b_min = (b_max + b_min) / 2 + elif luminosity == 'random': + b_min = 0 + b_max = 100 + + return self.random_within([b_min, b_max]) + + def set_format(self, hsv, format_): + if 'hsv' in format_: + color = hsv + elif 'rgb' in format_: + color = self.hsv_to_rgb(hsv) + elif 'hex' in format_: + r, g, b = self.hsv_to_rgb(hsv) + return '#%02x%02x%02x' % (r, g, b) + else: + return 'unrecognized format' + + if 'Array' in format_ or format_ == 'hex': + return color + else: + prefix = format_[:3] + color_values = [str(x) for x in color] + return '%s(%s)' % (prefix, ', '.join(color_values)) + + def get_minimum_brightness(self, H, S): + lower_bounds = self.get_color_info(H)['lower_bounds'] + + for i in range(len(lower_bounds) - 1): + s1 = lower_bounds[i][0] + v1 = lower_bounds[i][1] + + s2 = lower_bounds[i + 1][0] + v2 = lower_bounds[i + 1][1] + + if s1 <= S <= s2: + m = (v2 - v1) / (s2 - s1) + b = v1 - m * s1 + + return m * S + b + + return 0 + + def get_hue_range(self, color_input): + if color_input and color_input.isdigit(): + number = int(color_input) + + if 0 < number < 360: + return [number, number] + + elif color_input and color_input in 
self.colormap: + color = self.colormap[color_input] + if 'hue_range' in color: + return color['hue_range'] + + else: + return [0, 360] + + def get_saturation_range(self, hue): + return self.get_color_info(hue)['saturation_range'] + + def get_color_info(self, hue): + # Maps red colors to make picking hue easier + if 334 <= hue <= 360: + hue -= 360 + + for color_name, color in self.colormap.items(): + if color['hue_range'] and color['hue_range'][0] <= hue <= color[ + 'hue_range'][1]: + return self.colormap[color_name] + + # this should probably raise an exception + return 'Color not found' + + def random_within(self, r): + return self.random.randint(int(r[0]), int(r[1])) + + @classmethod + def hsv_to_rgb(cls, hsv): + h, s, v = hsv + h = 1 if h == 0 else h + h = 359 if h == 360 else h + + h = float(h) / 360 + s = float(s) / 100 + v = float(v) / 100 + + rgb = colorsys.hsv_to_rgb(h, s, v) + return [int(c * 255) for c in rgb] + + +def rand_color(): + generator = RandomColor() + hue = random.choice(list(COLORMAP.keys())) + color = generator.generate(hue=hue, count=1, format_='rgb')[0] + color = color[color.find('(') + 1:color.find(')')] + color = tuple([int(u) for u in color.split(',')]) + return color diff --git a/modelscope/models/cv/image_to_image_translation/ops/random_mask.py b/modelscope/models/cv/image_to_image_translation/ops/random_mask.py new file mode 100644 index 00000000..a6b55916 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/random_mask.py @@ -0,0 +1,79 @@ +import cv2 +import numpy as np + +__all__ = ['make_irregular_mask', 'make_rectangle_mask', 'make_uncrop'] + + +def make_irregular_mask(w, + h, + max_angle=4, + max_length=200, + max_width=100, + min_strokes=1, + max_strokes=5, + mode='line'): + # initialize mask + assert mode in ['line', 'circle', 'square'] + mask = np.zeros((h, w), np.float32) + + # draw strokes + num_strokes = np.random.randint(min_strokes, max_strokes + 1) + for i in range(num_strokes): + x1 = np.random.randint(w) + y1 = np.random.randint(h) + for j in range(1 + np.random.randint(5)): + angle = 0.01 + np.random.randint(max_angle) + if i % 2 == 0: + angle = 2 * 3.1415926 - angle + length = 10 + np.random.randint(max_length) + radius = 5 + np.random.randint(max_width) + x2 = np.clip((x1 + length * np.sin(angle)).astype(np.int32), 0, w) + y2 = np.clip((y1 + length * np.cos(angle)).astype(np.int32), 0, h) + if mode == 'line': + cv2.line(mask, (x1, y1), (x2, y2), 1.0, radius) + elif mode == 'circle': + cv2.circle( + mask, (x1, y1), radius=radius, color=1.0, thickness=-1) + elif mode == 'square': + radius = radius // 2 + mask[y1 - radius:y1 + radius, x1 - radius:x1 + radius] = 1 + x1, y1 = x2, y2 + return mask + + +def make_rectangle_mask(w, + h, + margin=10, + min_size=30, + max_size=150, + min_strokes=1, + max_strokes=4): + # initialize mask + mask = np.zeros((h, w), np.float32) + + # draw rectangles + num_strokes = np.random.randint(min_strokes, max_strokes + 1) + for i in range(num_strokes): + box_w = np.random.randint(min_size, max_size) + box_h = np.random.randint(min_size, max_size) + x1 = np.random.randint(margin, w - margin - box_w + 1) + y1 = np.random.randint(margin, h - margin - box_h + 1) + mask[y1:y1 + box_h, x1:x1 + box_w] = 1 + return mask + + +def make_uncrop(w, h): + # initialize mask + mask = np.zeros((h, w), np.float32) + + # randomly halve the image + side = np.random.choice([0, 1, 2, 3]) + if side == 0: + mask[:h // 2, :] = 1 + elif side == 1: + mask[h // 2:, :] = 1 + elif side == 2: + mask[:, :w // 2] = 1 + elif side 
== 3: + mask[:, w // 2:] = 1 + return mask diff --git a/modelscope/models/cv/image_to_image_translation/ops/svd.py b/modelscope/models/cv/image_to_image_translation/ops/svd.py new file mode 100644 index 00000000..c5173de1 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/svd.py @@ -0,0 +1,152 @@ +r"""SVD of linear degradation matrices described in the paper + ``Denoising Diffusion Restoration Models.'' + @article{kawar2022denoising, + title={Denoising Diffusion Restoration Models}, + author={Bahjat Kawar and Michael Elad and Stefano Ermon and Jiaming Song}, + year={2022}, + journal={arXiv preprint arXiv:2201.11793}, + } +""" +import torch + +__all__ = ['SVD', 'IdentitySVD', 'DenoiseSVD', 'ColorizationSVD'] + + +class SVD(object): + r"""SVD decomposition of a matrix, i.e., H = UDV^T. + NOTE: assume that all inputs (i.e., h, x) are of shape [B, CHW]. + """ + + def __init__(self, h): + self.u, self.d, self.v = torch.svd(h, some=False) + self.ut = self.u.t() + self.vt = self.v.t() + self.d[self.d < 1e-3] = 0 + + def U(self, x): + return torch.matmul(self.u, x) + + def Ut(self, x): + return torch.matmul(self.ut, x) + + def V(self, x): + return torch.matmul(self.v, x) + + def Vt(self, x): + return torch.matmul(self.vt, x) + + @property + def D(self): + return self.d + + def H(self, x): + return self.U(self.D * self.Vt(x)[:, :self.D.size(0)]) + + def Ht(self, x): + return self.V(self._pad(self.D * self.Ut(x)[:, :self.D.size(0)])) + + def Hinv(self, x): + r"""Multiplies x by the pseudo inverse of H. + """ + x = self.Ut(x) + x[:, :self.D.size(0)] = x[:, :self.D.size(0)] / self.D + return self.V(self._pad(x)) + + def _pad(self, x): + o = x.new_zeros(x.size(0), self.v.size(0)) + o[:, :self.u.size(0)] = x.view(x.size(0), -1) + return o + + def to(self, *args, **kwargs): + r"""Update the data type and device of UDV matrices.
+ """ + for k, v in self.__dict__.items(): + if isinstance(v, torch.Tensor): + setattr(self, k, v.to(*args, **kwargs)) + return self + + +class IdentitySVD(SVD): + + def __init__(self, c, h, w): + self.d = torch.ones(c * h * w) + + def U(self, x): + return x.clone() + + def Ut(self, x): + return x.clone() + + def V(self, x): + return x.clone() + + def Vt(self, x): + return x.clone() + + def H(self, x): + return x.clone() + + def Ht(self, x): + return x.clone() + + def Hinv(self, x): + return x.clone() + + def _pad(self, x): + return x.clone() + + +class DenoiseSVD(SVD): + + def __init__(self, c, h, w): + self.num_entries = c * h * w + self.d = torch.ones(self.num_entries) + + def U(self, x): + return x.clone() + + def Ut(self, x): + return x.clone() + + def V(self, x): + return x.clone() + + def Vt(self, x): + return x.clone() + + def _pad(self, x): + return x.clone() + + +class ColorizationSVD(SVD): + + def __init__(self, c, h, w): + self.color_dim = c + self.num_pixels = h * w + self.u, self.d, self.v = torch.svd(torch.ones(1, c) / c, some=False) + self.vt = self.v.t() + + def U(self, x): + return self.u[0, 0] * x + + def Ut(self, x): + return self.u[0, 0] * x + + def V(self, x): + return torch.einsum('ij,bjn->bin', self.v, + x.view(x.size(0), self.color_dim, + self.num_pixels)).flatten(1) + + def Vt(self, x): + return torch.einsum('ij,bjn->bin', self.vt, + x.view(x.size(0), self.color_dim, + self.num_pixels)).flatten(1) + + @property + def D(self): + return self.d.repeat(self.num_pixels) + + def _pad(self, x): + o = x.new_zeros(x.size(0), self.color_dim * self.num_pixels) + o[:, :self.num_pixels] = x + return o diff --git a/modelscope/models/cv/image_to_image_translation/ops/utils.py b/modelscope/models/cv/image_to_image_translation/ops/utils.py new file mode 100644 index 00000000..3e523f4c --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/utils.py @@ -0,0 +1,224 @@ +import base64 +import binascii +import hashlib +import math +import os +import os.path as osp +import zipfile +from io import BytesIO +from multiprocessing.pool import ThreadPool as Pool + +import cv2 +import json +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + +from .random_color import rand_color + +__all__ = [ + 'ceil_divide', 'to_device', 'rand_name', 'ema', 'parallel', 'unzip', + 'load_state_dict', 'inverse_indices', 'detect_duplicates', 'md5', 'rope', + 'format_state', 'breakup_grid', 'viz_anno_geometry', 'image_to_base64' +] + +TFS_CLIENT = None + + +def ceil_divide(a, b): + return int(math.ceil(a / b)) + + +def to_device(batch, device, non_blocking=False): + if isinstance(batch, (list, tuple)): + return type(batch)([to_device(u, device, non_blocking) for u in batch]) + elif isinstance(batch, dict): + return type(batch)([(k, to_device(v, device, non_blocking)) + for k, v in batch.items()]) + elif isinstance(batch, torch.Tensor): + return batch.to(device, non_blocking=non_blocking) + return batch + + +def rand_name(length=8, suffix=''): + name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') + if suffix: + if not suffix.startswith('.'): + suffix = '.' 
+ suffix + name += suffix + return name + + +@torch.no_grad() +def ema(net_ema, net, beta, copy_buffer=False): + assert 0.0 <= beta <= 1.0 + for p_ema, p in zip(net_ema.parameters(), net.parameters()): + p_ema.copy_(p.lerp(p_ema, beta)) + if copy_buffer: + for b_ema, b in zip(net_ema.buffers(), net.buffers()): + b_ema.copy_(b) + + +def parallel(func, args_list, num_workers=32, timeout=None): + assert isinstance(args_list, list) + if not isinstance(args_list[0], tuple): + args_list = [(args, ) for args in args_list] + if num_workers == 0: + return [func(*args) for args in args_list] + with Pool(processes=num_workers) as pool: + results = [pool.apply_async(func, args) for args in args_list] + results = [res.get(timeout=timeout) for res in results] + return results + + +def unzip(filename, dst_dir=None): + if dst_dir is None: + dst_dir = osp.dirname(filename) + with zipfile.ZipFile(filename, 'r') as zip_ref: + zip_ref.extractall(dst_dir) + + +def load_state_dict(module, state_dict, drop_prefix=''): + # find incompatible key-vals + src, dst = state_dict, module.state_dict() + if drop_prefix: + src = type(src)([ + (k[len(drop_prefix):] if k.startswith(drop_prefix) else k, v) + for k, v in src.items() + ]) + missing = [k for k in dst if k not in src] + unexpected = [k for k in src if k not in dst] + unmatched = [ + k for k in src.keys() & dst.keys() if src[k].shape != dst[k].shape + ] + + # keep only compatible key-vals + incompatible = set(unexpected + unmatched) + src = type(src)([(k, v) for k, v in src.items() if k not in incompatible]) + module.load_state_dict(src, strict=False) + + # report incompatible key-vals + if len(missing) != 0: + print(' Missing: ' + ', '.join(missing), flush=True) + if len(unexpected) != 0: + print(' Unexpected: ' + ', '.join(unexpected), flush=True) + if len(unmatched) != 0: + print(' Shape unmatched: ' + ', '.join(unmatched), flush=True) + + +def inverse_indices(indices): + r"""Inverse map of indices. + E.g., if A[indices] == B, then B[inv_indices] == A. + """ + inv_indices = torch.empty_like(indices) + inv_indices[indices] = torch.arange(len(indices)).to(indices) + return inv_indices + + +def detect_duplicates(feats, thr=0.9): + assert feats.ndim == 2 + + # compute simmat + feats = F.normalize(feats, p=2, dim=1) + simmat = torch.mm(feats, feats.T) + simmat.triu_(1) + torch.cuda.synchronize() + + # detect duplicates + mask = ~simmat.gt(thr).any(dim=0) + return torch.where(mask)[0] + + +def md5(filename): + with open(filename, 'rb') as f: + return hashlib.md5(f.read()).hexdigest() + + +def rope(x): + r"""Apply rotary position embedding on x of shape [B, *(spatial dimensions), C]. + """ + # reshape + shape = x.shape + x = x.view(x.size(0), -1, x.size(-1)) + l, c = x.shape[-2:] + assert c % 2 == 0 + half = c // 2 + + # apply rotary position embedding on x + sinusoid = torch.outer( + torch.arange(l).to(x), + torch.pow(10000, -torch.arange(half).to(x).div(half))) + sin, cos = torch.sin(sinusoid), torch.cos(sinusoid) + x1, x2 = x.chunk(2, dim=-1) + x = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + + # reshape back + return x.view(shape) + + +def format_state(state, filename=None): + r"""For comparing/aligning state_dict. + """ + content = '\n'.join([f'{k}\t{tuple(v.shape)}' for k, v in state.items()]) + if filename: + with open(filename, 'w') as f: + f.write(content) + + +def breakup_grid(img, grid_size): + r"""The inverse operator of ``torchvision.utils.make_grid``. 
+ """ + # params + nrow = img.height // grid_size + ncol = img.width // grid_size + wrow = wcol = 2 # NOTE: use default values here + + # collect grids + grids = [] + for i in range(nrow): + for j in range(ncol): + x1 = j * grid_size + (j + 1) * wcol + y1 = i * grid_size + (i + 1) * wrow + grids.append(img.crop((x1, y1, x1 + grid_size, y1 + grid_size))) + return grids + + +def viz_anno_geometry(item): + r"""Visualize an annotation item from SmartLabel. + """ + if isinstance(item, str): + item = json.loads(item) + assert isinstance(item, dict) + + # read image + orig_img = read_image(item['image_url'], retry=100) + img = cv2.cvtColor(np.asarray(orig_img), cv2.COLOR_BGR2RGB) + + # loop over geometries + for geometry in item['sd_result']['items']: + # params + poly_img = img.copy() + color = rand_color() + points = np.array(geometry['meta']['geometry']).round().astype(int) + line_color = tuple([int(u * 0.55) for u in color]) + + # draw polygons + poly_img = cv2.fillPoly(poly_img, pts=[points], color=color) + poly_img = cv2.polylines( + poly_img, + pts=[points], + isClosed=True, + color=line_color, + thickness=2) + + # mixing + img = np.clip(0.25 * img + 0.75 * poly_img, 0, 255).astype(np.uint8) + return orig_img, Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + + +def image_to_base64(img, format='JPEG'): + buffer = BytesIO() + img.save(buffer, format=format) + code = base64.b64encode(buffer.getvalue()).decode('utf-8') + return code diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 6dd0b794..d183c889 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from .image_color_enhance_pipeline import ImageColorEnhancePipeline from .image_colorization_pipeline import ImageColorizationPipeline from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline + from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline from .video_category_pipeline import VideoCategoryPipeline from .image_matting_pipeline import ImageMattingPipeline from .image_super_resolution_pipeline import ImageSuperResolutionPipeline diff --git a/modelscope/pipelines/cv/image_to_image_translation_pipeline.py b/modelscope/pipelines/cv/image_to_image_translation_pipeline.py new file mode 100644 index 00000000..a9f83e02 --- /dev/null +++ b/modelscope/pipelines/cv/image_to_image_translation_pipeline.py @@ -0,0 +1,325 @@ +import io +import os.path as osp +import sys +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +import torchvision.transforms as T +from PIL import Image +from torchvision.utils import save_image + +import modelscope.models.cv.image_to_image_translation.data as data +import modelscope.models.cv.image_to_image_translation.models as models +import modelscope.models.cv.image_to_image_translation.ops as ops +from modelscope.fileio import File +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_to_image_translation.model_translation import \ + UNet +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def save_grid(imgs, filename, nrow=5): + save_image( + imgs.clamp(-1, 1), filename, range=(-1, 
1), normalize=True, nrow=nrow) + + +@PIPELINES.register_module( + Tasks.image_generation, module_name=Pipelines.image2image_translation) +class Image2ImageTranslationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.repetition = 4 + # load autoencoder model + ae_model_path = osp.join(self.model, self.cfg.ModelPath.ae_model_path) + logger.info(f'loading autoencoder model from {ae_model_path}') + self.autoencoder = models.VQAutoencoder( + dim=self.cfg.Params.ae.ae_dim, + z_dim=self.cfg.Params.ae.ae_z_dim, + dim_mult=self.cfg.Params.ae.ae_dim_mult, + attn_scales=self.cfg.Params.ae.ae_attn_scales, + codebook_size=self.cfg.Params.ae.ae_codebook_size).eval( + ).requires_grad_(False).to(self._device) # noqa E123 + self.autoencoder.load_state_dict( + torch.load(ae_model_path, map_location=self._device)) + logger.info('load autoencoder model done') + + # load palette model + palette_model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading palette model from {palette_model_path}') + self.palette = UNet( + resolution=self.cfg.Params.unet.unet_resolution, + in_dim=self.cfg.Params.unet.unet_in_dim, + dim=self.cfg.Params.unet.unet_dim, + context_dim=self.cfg.Params.unet.unet_context_dim, + out_dim=self.cfg.Params.unet.unet_out_dim, + dim_mult=self.cfg.Params.unet.unet_dim_mult, + num_heads=self.cfg.Params.unet.unet_num_heads, + head_dim=None, + num_res_blocks=self.cfg.Params.unet.unet_res_blocks, + attn_scales=self.cfg.Params.unet.unet_attn_scales, + num_classes=self.cfg.Params.unet.unet_num_classes + 1, + dropout=self.cfg.Params.unet.unet_dropout).eval().requires_grad_( + False).to(self._device) + self.palette.load_state_dict( + torch.load(palette_model_path, map_location=self._device)) + logger.info('load palette model done') + + # diffusion + logger.info('Initialization diffusion ...') + betas = ops.beta_schedule(self.cfg.Params.diffusion.schedule, + self.cfg.Params.diffusion.num_timesteps) + self.diffusion = ops.GaussianDiffusion( + betas=betas, + mean_type=self.cfg.Params.diffusion.mean_type, + var_type=self.cfg.Params.diffusion.var_type, + loss_type=self.cfg.Params.diffusion.loss_type, + rescale_timesteps=False) + + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize( + self.cfg.DATA.scale_size, + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std) + ]) + + def preprocess(self, input: Input) -> Dict[str, Any]: + if len(input) == 3: # colorization + _, input_type, save_path = input + elif len(input) == 4: # uncropping or in-painting + _, meta, input_type, save_path = input + if input_type == 0: # uncropping + assert meta in ['up', 'down', 'left', 'right'] + direction = meta + + list_ = [] + for i in range(len(input) - 2): + input_img = input[i] + if input_img in ['up', 'down', 'left', 'right']: + continue + if isinstance(input_img, str): + if input_type == 2 and i == 0: + logger.info('Loading image by origin way ... 
') + bytes = File.read(input_img) + img = Image.open(io.BytesIO(bytes)) + assert len(img.split()) == 4 + else: + img = load_image(input_img) + elif isinstance(input_img, PIL.Image.Image): + img = input_img.convert('RGB') + elif isinstance(input_img, np.ndarray): + if len(input_img.shape) == 2: + input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR) + img = input_img[:, :, ::-1] + img = Image.fromarray(img.astype('uint8')).convert('RGB') + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + list_.append(img) + img_list = [] + if input_type != 2: + for img in list_: + img = self.transforms(img) + imgs = torch.unsqueeze(img, 0) + imgs = imgs.to(self._device) + img_list.append(imgs) + elif input_type == 2: + mask, masked_img = list_[0], list_[1] + img = self.transforms(masked_img.convert('RGB')) + mask = torch.from_numpy( + np.array( + mask.resize((img.shape[2], img.shape[1])), + dtype=np.float32)[:, :, -1] / 255.0).unsqueeze(0) + img = (1 - mask) * img + mask * torch.randn_like(img).clamp_(-1, 1) + imgs = img.unsqueeze(0).to(self._device) + b, c, h, w = imgs.shape + y = torch.LongTensor([self.cfg.Classes.class_id]).to(self._device) + + if input_type == 0: + assert len(img_list) == 1 + result = { + 'image_data': img_list[0], + 'c': c, + 'h': h, + 'w': w, + 'direction': direction, + 'type': input_type, + 'y': y, + 'save_path': save_path + } + elif input_type == 1: + assert len(img_list) == 1 + result = { + 'image_data': img_list[0], + 'c': c, + 'h': h, + 'w': w, + 'type': input_type, + 'y': y, + 'save_path': save_path + } + elif input_type == 2: + result = { + 'image_data': imgs, + # 'image_mask': mask, + 'c': c, + 'h': h, + 'w': w, + 'type': input_type, + 'y': y, + 'save_path': save_path + } + return result + + @torch.no_grad() + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + type_ = input['type'] + if type_ == 0: + # Uncropping + img = input['image_data'] + direction = input['direction'] + y = input['y'] + + # fix seed + torch.manual_seed(1 * 8888) + torch.cuda.manual_seed(1 * 8888) + + logger.info(f'Processing {direction} uncropping') + img = img.clone() + i_y = y.repeat(self.repetition, 1) + if direction == 'up': + img[:, :, input['h'] // 2:, :] = torch.randn_like( + img[:, :, input['h'] // 2:, :]) + elif direction == 'down': + img[:, :, :input['h'] // 2, :] = torch.randn_like( + img[:, :, :input['h'] // 2, :]) + elif direction == 'left': + img[:, :, :, + input['w'] // 2:] = torch.randn_like(img[:, :, :, + input['w'] // 2:]) + elif direction == 'right': + img[:, :, :, :input['w'] // 2] = torch.randn_like( + img[:, :, :, :input['w'] // 2]) + i_concat = self.autoencoder.encode(img).repeat( + self.repetition, 1, 1, 1) + + # sample images + x0 = self.diffusion.ddim_sample_loop( + noise=torch.randn_like(i_concat), + model=self.palette, + model_kwargs=[{ + 'y': i_y, + 'concat': i_concat + }, { + 'y': + torch.full_like(i_y, + self.cfg.Params.unet.unet_num_classes), + 'concat': + i_concat + }], + guide_scale=1.0, + clamp=None, + ddim_timesteps=50, + eta=1.0) + i_gen_imgs = self.autoencoder.decode(x0) + save_grid(i_gen_imgs, input['save_path'], nrow=4) + return {OutputKeys.OUTPUT_IMG: i_gen_imgs} + + elif type_ == 1: + # Colorization # + img = input['image_data'] + y = input['y'] + # fix seed + torch.manual_seed(1 * 8888) + torch.cuda.manual_seed(1 * 8888) + + logger.info('Processing Colorization') + img = img.clone() + img = img.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1) + i_concat = self.autoencoder.encode(img).repeat( + 
self.repetition, 1, 1, 1) + i_y = y.repeat(self.repetition, 1) + + # sample images + x0 = self.diffusion.ddim_sample_loop( + noise=torch.randn_like(i_concat), + model=self.palette, + model_kwargs=[{ + 'y': i_y, + 'concat': i_concat + }, { + 'y': + torch.full_like(i_y, + self.cfg.Params.unet.unet_num_classes), + 'concat': + i_concat + }], + guide_scale=1.0, + clamp=None, + ddim_timesteps=50, + eta=0.0) + i_gen_imgs = self.autoencoder.decode(x0) + save_grid(i_gen_imgs, input['save_path'], nrow=4) + return {OutputKeys.OUTPUT_IMG: i_gen_imgs} + elif type_ == 2: + # Combination # + logger.info('Processing Combination') + + # prepare inputs + img = input['image_data'] + concat = self.autoencoder.encode(img).repeat( + self.repetition, 1, 1, 1) + y = torch.LongTensor([126]).unsqueeze(0).to(self._device).repeat( + self.repetition, 1) + + # sample images + x0 = self.diffusion.ddim_sample_loop( + noise=torch.randn_like(concat), + model=self.palette, + model_kwargs=[{ + 'y': y, + 'concat': concat + }, { + 'y': + torch.full_like(y, self.cfg.Params.unet.unet_num_classes), + 'concat': + concat + }], + guide_scale=1.0, + clamp=None, + ddim_timesteps=50, + eta=1.0) + i_gen_imgs = self.autoencoder.decode(x0) + save_grid(i_gen_imgs, input['save_path'], nrow=4) + return {OutputKeys.OUTPUT_IMG: i_gen_imgs} + else: + raise TypeError( + f'input type should be 0 (Uncropping), 1 (Colorization), 2 (Combination)' + f' but got {type_}') + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/tests/pipelines/test_image2image_translation.py b/tests/pipelines/test_image2image_translation.py new file mode 100644 index 00000000..24766d25 --- /dev/null +++ b/tests/pipelines/test_image2image_translation.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import shutil +import unittest + +from modelscope.fileio import File +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +class Image2ImageTranslationTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + r"""We provide three translation modes, i.e., uncropping, colorization and combination. + You can pass the following parameters for the different modes. + 1. Uncropping Mode: + result = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', 'left', 0, 'result.jpg')) + 2. Colorization Mode: + result = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', 1, 'result.jpg')) + 3. Combination Mode: + as shown in the code below. + """ + img2img_gen_pipeline = pipeline( + Tasks.image_generation, + model='damo/cv_latent_diffusion_image2image_translation') + result = img2img_gen_pipeline( + ('data/test/images/img2img_input_mask.png', + 'data/test/images/img2img_input_masked_img.png', 2, + 'result.jpg')) # combination mode + + print(f'output: {result}.') + + +if __name__ == '__main__': + unittest.main()
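
Usage note (not part of the diff): the mask generators added in ops/random_mask.py are easiest to sanity-check visually. The following is a minimal sketch of how they might be exercised; the import path mirrors the file locations introduced above, and the output filenames are arbitrary.

import numpy as np
from PIL import Image

from modelscope.models.cv.image_to_image_translation.ops.random_mask import (
    make_irregular_mask, make_rectangle_mask, make_uncrop)

# Each helper returns a float32 array of shape (h, w) with values in {0, 1}.
masks = {
    'irregular': make_irregular_mask(256, 256, mode='line'),
    'rectangle': make_rectangle_mask(256, 256),
    'uncrop': make_uncrop(256, 256),
}
for name, mask in masks.items():
    # Scale to 8-bit so the masks can be inspected as images.
    Image.fromarray((mask * 255).astype(np.uint8)).save(f'mask_{name}.png')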
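
Similarly, the DDRM-style SVD operators in ops/svd.py can be checked against the degradations they model. Below is a small sketch (the shapes and tolerances are illustrative, and the module is assumed importable from the path added in this diff) showing that ColorizationSVD.H averages the colour channels of a flattened [B, C*H*W] image and that Hinv acts as its pseudo-inverse.

import torch

from modelscope.models.cv.image_to_image_translation.ops.svd import ColorizationSVD

c, h, w = 3, 8, 8
x = torch.rand(2, c * h * w)  # two flattened RGB images, channel-major layout
svd = ColorizationSVD(c, h, w)

# H projects each pixel onto its channel mean, i.e. the colorization degradation.
gray = svd.H(x)  # shape [2, h * w]
assert torch.allclose(gray, x.view(2, c, h * w).mean(dim=1), atol=1e-5)

# Hinv is the pseudo-inverse of H, so applying H again recovers the observation.
recon = svd.Hinv(gray)  # shape [2, c * h * w]
assert torch.allclose(svd.H(recon), gray, atol=1e-5)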
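
Finally, breakup_grid in ops/utils.py is documented as the inverse of torchvision.utils.make_grid. A quick sketch of that relationship (assuming make_grid's default padding of 2, which matches the hard-coded wrow/wcol values):

import torch
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

from modelscope.models.cv.image_to_image_translation.ops.utils import breakup_grid

tiles = torch.rand(4, 3, 64, 64)               # four RGB tiles in [0, 1]
grid = to_pil_image(make_grid(tiles, nrow=2))  # 2x2 grid with default padding of 2
parts = breakup_grid(grid, 64)                 # recover the individual tiles
assert len(parts) == 4 and parts[0].size == (64, 64)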