tianxi.tl yingda.chen 3 years ago
parent
commit
0424f3c510
23 changed files with 4830 additions and 0 deletions
  1. +3 -0 data/test/images/img2img_input.jpg
  2. +3 -0 data/test/images/img2img_input_mask.png
  3. +3 -0 data/test/images/img2img_input_masked_img.png
  4. +1 -0 modelscope/metainfo.py
  5. +1 -0 modelscope/models/cv/image_to_image_translation/data/__init__.py
  6. +121 -0 modelscope/models/cv/image_to_image_translation/data/transforms.py
  7. +323 -0 modelscope/models/cv/image_to_image_translation/model_translation.py
  8. +2 -0 modelscope/models/cv/image_to_image_translation/models/__init__.py
  9. +412 -0 modelscope/models/cv/image_to_image_translation/models/autoencoder.py
  10. +418 -0 modelscope/models/cv/image_to_image_translation/models/clip.py
  11. +8 -0 modelscope/models/cv/image_to_image_translation/ops/__init__.py
  12. +663 -0 modelscope/models/cv/image_to_image_translation/ops/apps.py
  13. +1074 -0 modelscope/models/cv/image_to_image_translation/ops/degradation.py
  14. +598 -0 modelscope/models/cv/image_to_image_translation/ops/diffusion.py
  15. +35 -0 modelscope/models/cv/image_to_image_translation/ops/losses.py
  16. +126 -0 modelscope/models/cv/image_to_image_translation/ops/metrics.py
  17. +220 -0 modelscope/models/cv/image_to_image_translation/ops/random_color.py
  18. +79 -0 modelscope/models/cv/image_to_image_translation/ops/random_mask.py
  19. +152 -0 modelscope/models/cv/image_to_image_translation/ops/svd.py
  20. +224 -0 modelscope/models/cv/image_to_image_translation/ops/utils.py
  21. +1 -0 modelscope/pipelines/cv/__init__.py
  22. +325 -0 modelscope/pipelines/cv/image_to_image_translation_pipeline.py
  23. +38 -0 tests/pipelines/test_image2image_translation.py

+ 3
- 0
data/test/images/img2img_input.jpg View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e4cbf844cd16a892a7d2f2764b1537c346675d3b0145016d6836441ba907366
size 9195

+ 3
- 0
data/test/images/img2img_input_mask.png View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33b3d3076e191fa92511bf69fa76e1222b3b3be0049e711c948a1218b587510c
size 4805

+ 3
- 0
data/test/images/img2img_input_masked_img.png View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:99c2b02a927b86ff194287ea4c5a05349dd800cff2b523212d1dad378c252feb
size 103334

+ 1
- 0
modelscope/metainfo.py View File

@@ -77,6 +77,7 @@ class Pipelines(object):
face_image_generation = 'gan-face-image-generation'
style_transfer = 'AAMS-style-transfer'
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
image2image_translation = 'image-to-image-translation'
live_category = 'live-category'
video_category = 'video-category'



+ 1
- 0
modelscope/models/cv/image_to_image_translation/data/__init__.py View File

@@ -0,0 +1 @@
from .transforms import * # noqa F403

+ 121
- 0
modelscope/models/cv/image_to_image_translation/data/transforms.py View File

@@ -0,0 +1,121 @@
import math
import random

import torchvision.transforms.functional as TF
from PIL import Image, ImageFilter

__all__ = [
'Identity', 'PadToSquare', 'RandomScale', 'RandomRotate',
'RandomGaussianBlur', 'RandomCrop'
]


class Identity(object):

def __call__(self, *args):
if len(args) == 0:
return None
elif len(args) == 1:
return args[0]
else:
return args


class PadToSquare(object):

def __init__(self, fill=(255, 255, 255)):
self.fill = fill

def __call__(self, img):
w, h = img.size
if w != h:
if w > h:
t = (w - h) // 2
b = w - h - t
padding = (0, t, 0, b)
else:
left = (h - w) // 2
right = h - w - left
padding = (left, 0, right, 0)
img = TF.pad(img, padding, fill=self.fill)
return img


class RandomScale(object):

def __init__(self,
min_scale=0.5,
max_scale=2.0,
min_ratio=0.8,
max_ratio=1.25):
self.min_scale = min_scale
self.max_scale = max_scale
self.min_ratio = min_ratio
self.max_ratio = max_ratio

def __call__(self, img):
w, h = img.size
scale = 2**random.uniform(
math.log2(self.min_scale), math.log2(self.max_scale))
ratio = 2**random.uniform(
math.log2(self.min_ratio), math.log2(self.max_ratio))
ow = int(w * scale * math.sqrt(ratio))
oh = int(h * scale / math.sqrt(ratio))
img = img.resize((ow, oh), Image.BILINEAR)
return img


class RandomRotate(object):

def __init__(self,
min_angle=-10.0,
max_angle=10.0,
padding=(255, 255, 255),
p=0.5):
self.min_angle = min_angle
self.max_angle = max_angle
self.padding = padding
self.p = p

def __call__(self, img):
if random.random() < self.p:
angle = random.uniform(self.min_angle, self.max_angle)
img = img.rotate(angle, Image.BILINEAR, fillcolor=self.padding)
return img


class RandomGaussianBlur(object):

def __init__(self, radius=5, p=0.5):
self.radius = radius
self.p = p

def __call__(self, img):
if random.random() < self.p:
img = img.filter(ImageFilter.GaussianBlur(radius=self.radius))
return img


class RandomCrop(object):

def __init__(self, size, padding=(255, 255, 255)):
self.size = size
self.padding = padding

def __call__(self, img):
# pad
w, h = img.size
pad_w = max(0, self.size - w)
pad_h = max(0, self.size - h)
if pad_w > 0 or pad_h > 0:
half_w = pad_w // 2
half_h = pad_h // 2
pad = (half_w, half_h, pad_w - half_w, pad_h - half_h)
img = TF.pad(img, pad, fill=self.padding)

# crop
w, h = img.size
x1 = random.randint(0, w - self.size)
y1 = random.randint(0, h - self.size)
img = img.crop((x1, y1, x1 + self.size, y1 + self.size))
return img
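
A minimal usage sketch of the transforms above, assuming torchvision is available; the particular chain, parameter values, and the file name example.jpg are illustrative and not part of this commit.

import torchvision.transforms as T
from PIL import Image

# hypothetical augmentation chain: pad to square, jitter scale/rotation,
# occasionally blur, then crop to a fixed size
augment = T.Compose([
    PadToSquare(fill=(255, 255, 255)),
    RandomScale(min_scale=0.8, max_scale=1.25),
    RandomRotate(min_angle=-10.0, max_angle=10.0, p=0.5),
    RandomGaussianBlur(radius=5, p=0.3),
    RandomCrop(size=256),
])
img = augment(Image.open('example.jpg').convert('RGB'))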

+ 323
- 0
modelscope/models/cv/image_to_image_translation/model_translation.py View File

@@ -0,0 +1,323 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['UNet']


def sinusoidal_embedding(timesteps, dim):
# check input
half = dim // 2
timesteps = timesteps.float()

# compute sinusoidal embedding
sinusoid = torch.outer(
timesteps, torch.pow(10000,
-torch.arange(half).to(timesteps).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
if dim % 2 != 0:
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
return x


class Resample(nn.Module):

def __init__(self, scale_factor=1.0):
assert scale_factor in [0.5, 1.0, 2.0]
super(Resample, self).__init__()
self.scale_factor = scale_factor

def forward(self, x):
if self.scale_factor == 2.0:
x = F.interpolate(x, scale_factor=2, mode='nearest')
elif self.scale_factor == 0.5:
x = F.avg_pool2d(x, kernel_size=2, stride=2)
return x


class ResidualBlock(nn.Module):

def __init__(self, in_dim, embed_dim, out_dim, dropout=0.0):
super(ResidualBlock, self).__init__()
self.in_dim = in_dim
self.embed_dim = embed_dim
self.out_dim = out_dim

# layers
self.layer1 = nn.Sequential(
nn.GroupNorm(32, in_dim), nn.SiLU(),
nn.Conv2d(in_dim, out_dim, 3, padding=1))
self.embedding = nn.Sequential(nn.SiLU(),
nn.Linear(embed_dim, out_dim))
self.layer2 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv2d(out_dim, out_dim, 3, padding=1))
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
in_dim, out_dim, 1)

# zero out the last layer params
nn.init.zeros_(self.layer2[-1].weight)

def forward(self, x, y):
identity = x
x = self.layer1(x)
x = x + self.embedding(y).unsqueeze(-1).unsqueeze(-1)
x = self.layer2(x)
x = x + self.shortcut(identity)
return x


class MultiHeadAttention(nn.Module):

def __init__(self, dim, context_dim=None, num_heads=8, dropout=0.0):
assert dim % num_heads == 0
assert context_dim is None or context_dim % num_heads == 0
context_dim = context_dim or dim
super(MultiHeadAttention, self).__init__()
self.dim = dim
self.context_dim = context_dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = math.pow(self.head_dim, -0.25)

# layers
self.q = nn.Linear(dim, dim, bias=False)
self.k = nn.Linear(context_dim, dim, bias=False)
self.v = nn.Linear(context_dim, dim, bias=False)
self.o = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)

def forward(self, x, context=None):
# check inputs
context = x if context is None else context
b, n, c = x.size(0), self.num_heads, self.head_dim

# compute query, key, value
q = self.q(x).view(b, -1, n, c)
k = self.k(context).view(b, -1, n, c)
v = self.v(context).view(b, -1, n, c)

# compute attention
attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale)
attn = F.softmax(attn, dim=-1)
attn = self.dropout(attn)

# gather context
x = torch.einsum('bnij,bjnc->binc', attn, v)
x = x.reshape(b, -1, n * c)

# output
x = self.o(x)
x = self.dropout(x)
return x


class GLU(nn.Module):

def __init__(self, in_dim, out_dim):
super(GLU, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.proj = nn.Linear(in_dim, out_dim * 2)

def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)


class TransformerBlock(nn.Module):

def __init__(self, dim, context_dim, num_heads, dropout=0.0):
super(TransformerBlock, self).__init__()
self.dim = dim
self.context_dim = context_dim
self.num_heads = num_heads
self.head_dim = dim // num_heads

# input
self.norm1 = nn.GroupNorm(32, dim, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(dim, dim, 1)

# self attention
self.norm2 = nn.LayerNorm(dim)
self.self_attn = MultiHeadAttention(dim, None, num_heads, dropout)

# cross attention
self.norm3 = nn.LayerNorm(dim)
self.cross_attn = MultiHeadAttention(dim, context_dim, num_heads,
dropout)

# ffn
self.norm4 = nn.LayerNorm(dim)
self.ffn = nn.Sequential(
GLU(dim, dim * 4), nn.Dropout(dropout), nn.Linear(dim * 4, dim))

# output
self.conv2 = nn.Conv2d(dim, dim, 1)

# zero out the last layer params
nn.init.zeros_(self.conv2.weight)

def forward(self, x, context):
b, c, h, w = x.size()
identity = x

# input
x = self.norm1(x)
x = self.conv1(x).view(b, c, -1).transpose(1, 2)

# attention
x = x + self.self_attn(self.norm2(x))
x = x + self.cross_attn(self.norm3(x), context)
x = x + self.ffn(self.norm4(x))

# output
x = x.transpose(1, 2).view(b, c, h, w)
x = self.conv2(x)
return x + identity


class UNet(nn.Module):

def __init__(self,
resolution=64,
in_dim=3,
dim=192,
context_dim=512,
out_dim=3,
dim_mult=[1, 2, 3, 5],
num_heads=1,
head_dim=None,
num_res_blocks=2,
attn_scales=[1 / 2, 1 / 4, 1 / 8],
num_classes=1001,
dropout=0.0):
embed_dim = dim * 4
super(UNet, self).__init__()
self.resolution = resolution
self.in_dim = in_dim
self.dim = dim
self.context_dim = context_dim
self.out_dim = out_dim
self.dim_mult = dim_mult
self.num_heads = num_heads
self.head_dim = head_dim
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.num_classes = num_classes

# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0

# embeddings
self.time_embedding = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
self.label_embedding = nn.Embedding(num_classes, context_dim)

# encoder
self.encoder = nn.ModuleList(
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
shortcut_dims.append(dim)
for i, (in_dim,
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
# residual (+attention) blocks
block = nn.ModuleList(
[ResidualBlock(in_dim, embed_dim, out_dim, dropout)])
if scale in attn_scales:
block.append(
TransformerBlock(out_dim, context_dim, num_heads))
in_dim = out_dim
self.encoder.append(block)
shortcut_dims.append(out_dim)

# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
self.encoder.append(
nn.Conv2d(out_dim, out_dim, 3, stride=2, padding=1))
shortcut_dims.append(out_dim)
scale /= 2.0

# middle
self.middle = nn.ModuleList([
ResidualBlock(out_dim, embed_dim, out_dim, dropout),
TransformerBlock(out_dim, context_dim, num_heads),
ResidualBlock(out_dim, embed_dim, out_dim, dropout)
])

# decoder
self.decoder = nn.ModuleList()
for i, (in_dim,
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
# residual (+attention) blocks
block = nn.ModuleList([
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
out_dim, dropout)
])
if scale in attn_scales:
block.append(
TransformerBlock(out_dim, context_dim, num_heads,
dropout))
in_dim = out_dim

# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
block.append(
nn.Sequential(
Resample(scale_factor=2.0),
nn.Conv2d(out_dim, out_dim, 3, padding=1)))
scale *= 2.0
self.decoder.append(block)

# head
self.head = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))

# zero out the last layer params
nn.init.zeros_(self.head[-1].weight)

def forward(self, x, t, y, concat=None):
# embeddings
if concat is not None:
x = torch.cat([x, concat], dim=1)
t = self.time_embedding(sinusoidal_embedding(t, self.dim))
y = self.label_embedding(y)

# encoder
xs = []
for block in self.encoder:
x = self._forward_single(block, x, t, y)
xs.append(x)

# middle
for block in self.middle:
x = self._forward_single(block, x, t, y)

# decoder
for block in self.decoder:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(block, x, t, y)

# head
x = self.head(x)
return x

def _forward_single(self, module, x, t, y):
if isinstance(module, ResidualBlock):
x = module(x, t)
elif isinstance(module, TransformerBlock):
x = module(x, y)
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, t, y)
else:
x = module(x)
return x
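
A minimal forward-pass sketch for the UNet above under its default configuration (64x64 inputs, 1001 label classes); the batch size and shapes are illustrative and the weights are untrained.

import torch

unet = UNet()  # defaults: resolution=64, in_dim=3, dim=192, context_dim=512
x = torch.randn(2, 3, 64, 64)      # noisy images
t = torch.randint(0, 1000, (2, ))  # diffusion timesteps
y = torch.randint(0, 1001, (2, ))  # class labels fed to label_embedding
with torch.no_grad():
    eps = unet(x, t, y)            # predicted noise, shape [2, 3, 64, 64]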

+ 2
- 0
modelscope/models/cv/image_to_image_translation/models/__init__.py View File

@@ -0,0 +1,2 @@
from .autoencoder import * # noqa F403
from .clip import * # noqa F403

+ 412
- 0
modelscope/models/cv/image_to_image_translation/models/autoencoder.py View File

@@ -0,0 +1,412 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['VQAutoencoder', 'KLAutoencoder', 'PatchDiscriminator']


def group_norm(dim):
return nn.GroupNorm(32, dim, eps=1e-6, affine=True)


class Resample(nn.Module):

def __init__(self, dim, scale_factor):
super(Resample, self).__init__()
self.dim = dim
self.scale_factor = scale_factor

# layers
if scale_factor == 2.0:
self.resample = nn.Sequential(
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
nn.Conv2d(dim, dim, 3, padding=1))
elif scale_factor == 0.5:
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=2, padding=0))
else:
self.resample = nn.Identity()

def forward(self, x):
return self.resample(x)


class ResidualBlock(nn.Module):

def __init__(self, in_dim, out_dim, dropout=0.0):
super(ResidualBlock, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim

# layers
self.residual = nn.Sequential(
group_norm(in_dim), nn.SiLU(),
nn.Conv2d(in_dim, out_dim, 3, padding=1), group_norm(out_dim),
nn.SiLU(), nn.Dropout(dropout),
nn.Conv2d(out_dim, out_dim, 3, padding=1))
self.shortcut = nn.Conv2d(in_dim, out_dim,
1) if in_dim != out_dim else nn.Identity()

# zero out the last layer params
nn.init.zeros_(self.residual[-1].weight)

def forward(self, x):
return self.residual(x) + self.shortcut(x)


class AttentionBlock(nn.Module):

def __init__(self, dim):
super(AttentionBlock, self).__init__()
self.dim = dim
self.scale = math.pow(dim, -0.25)

# layers
self.norm = group_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)

# zero out the last layer params
nn.init.zeros_(self.proj.weight)

def forward(self, x):
identity = x
b, c, h, w = x.size()

# compute query, key, value
x = self.norm(x)
q, k, v = self.to_qkv(x).view(b, c * 3, -1).chunk(3, dim=1)

# compute attention
attn = torch.einsum('bci,bcj->bij', q * self.scale, k * self.scale)
attn = F.softmax(attn, dim=-1)

# gather context
x = torch.einsum('bij,bcj->bci', attn, v)
x = x.reshape(b, c, h, w)

# output
x = self.proj(x)
return x + identity


class Encoder(nn.Module):

def __init__(self,
dim=128,
z_dim=3,
dim_mult=[1, 2, 4],
num_res_blocks=2,
attn_scales=[],
dropout=0.0):
super(Encoder, self).__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales

# params
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0

# init block
self.conv1 = nn.Conv2d(3, dims[0], 3, padding=1)

# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim

# downsample block
if i != len(dim_mult) - 1:
downsamples.append(Resample(out_dim, scale_factor=0.5))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)

# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))

# output blocks
self.head = nn.Sequential(
group_norm(out_dim), nn.SiLU(),
nn.Conv2d(out_dim, z_dim, 3, padding=1))

def forward(self, x):
x = self.conv1(x)
x = self.downsamples(x)
x = self.middle(x)
x = self.head(x)
return x


class Decoder(nn.Module):

def __init__(self,
dim=128,
z_dim=3,
dim_mult=[1, 2, 4],
num_res_blocks=2,
attn_scales=[],
dropout=0.0):
super(Decoder, self).__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales

# params
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)

# init block
self.conv1 = nn.Conv2d(z_dim, dims[0], 3, padding=1)

# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))

# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim

# upsample block
if i != len(dim_mult) - 1:
upsamples.append(Resample(out_dim, scale_factor=2.0))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)

# output blocks
self.head = nn.Sequential(
group_norm(out_dim), nn.SiLU(),
nn.Conv2d(out_dim, 3, 3, padding=1))

def forward(self, x):
x = self.conv1(x)
x = self.middle(x)
x = self.upsamples(x)
x = self.head(x)
return x


class VectorQuantizer(nn.Module):

def __init__(self, codebook_size=8192, z_dim=3, beta=0.25):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size
self.z_dim = z_dim
self.beta = beta

# init codebook
eps = math.sqrt(1.0 / codebook_size)
self.codebook = nn.Parameter(
torch.empty(codebook_size, z_dim).uniform_(-eps, eps))

def forward(self, z):
# preprocess
b, c, h, w = z.size()
flatten = z.permute(0, 2, 3, 1).reshape(-1, c)

# quantization
with torch.no_grad():
tokens = torch.cdist(flatten, self.codebook).argmin(dim=1)
quantized = F.embedding(tokens,
self.codebook).view(b, h, w,
c).permute(0, 3, 1, 2)

# compute loss
codebook_loss = F.mse_loss(quantized, z.detach())
commitment_loss = F.mse_loss(quantized.detach(), z)
loss = codebook_loss + self.beta * commitment_loss

# perplexity
counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype)
# dist.all_reduce(counts)
p = counts / counts.sum()
perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10)))

# postprocess
tokens = tokens.view(b, h, w)
quantized = z + (quantized - z).detach()
return quantized, tokens, loss, perplexity


class VQAutoencoder(nn.Module):

def __init__(self,
dim=128,
z_dim=3,
dim_mult=[1, 2, 4],
num_res_blocks=2,
attn_scales=[],
dropout=0.0,
codebook_size=8192,
beta=0.25):
super(VQAutoencoder, self).__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.codebook_size = codebook_size
self.beta = beta

# blocks
self.encoder = Encoder(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, dropout)
self.conv1 = nn.Conv2d(z_dim, z_dim, 1)
self.quantizer = VectorQuantizer(codebook_size, z_dim, beta)
self.conv2 = nn.Conv2d(z_dim, z_dim, 1)
self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, dropout)

def forward(self, x):
z = self.encoder(x)
z = self.conv1(z)
z, tokens, loss, perplexity = self.quantizer(z)
z = self.conv2(z)
x = self.decoder(z)
return x, tokens, loss, perplexity

def encode(self, imgs):
z = self.encoder(imgs)
z = self.conv1(z)
return z

def decode(self, z):
r"""Absort the quantizer in the decoder.
"""
z = self.quantizer(z)[0]
z = self.conv2(z)
imgs = self.decoder(z)
return imgs

@torch.no_grad()
def encode_to_tokens(self, imgs):
# preprocess
z = self.encoder(imgs)
z = self.conv1(z)

# quantization
b, c, h, w = z.size()
flatten = z.permute(0, 2, 3, 1).reshape(-1, c)
tokens = torch.cdist(flatten, self.quantizer.codebook).argmin(dim=1)
return tokens.view(b, -1)

@torch.no_grad()
def decode_from_tokens(self, tokens):
# dequantization
z = F.embedding(tokens, self.quantizer.codebook)

# postprocess
b, l, c = z.size()
h = w = int(math.sqrt(l))
z = z.view(b, h, w, c).permute(0, 3, 1, 2)
z = self.conv2(z)
imgs = self.decoder(z)
return imgs


class KLAutoencoder(nn.Module):

def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
dropout=0.0):
super(KLAutoencoder, self).__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales

# blocks
self.encoder = Encoder(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, dropout)
self.conv1 = nn.Conv2d(z_dim * 2, z_dim * 2, 1)
self.conv2 = nn.Conv2d(z_dim, z_dim, 1)
self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, dropout)

def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x = self.decode(z)
return x, mu, log_var

def encode(self, x):
x = self.encoder(x)
mu, log_var = self.conv1(x).chunk(2, dim=1)
return mu, log_var

def decode(self, z):
x = self.conv2(z)
x = self.decoder(x)
return x

def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu


class PatchDiscriminator(nn.Module):

def __init__(self, in_dim=3, dim=64, num_layers=3):
super(PatchDiscriminator, self).__init__()
self.in_dim = in_dim
self.dim = dim
self.num_layers = num_layers

# params
dims = [dim * min(8, 2**u) for u in range(num_layers + 1)]

# layers
layers = [
nn.Conv2d(in_dim, dim, 4, stride=2, padding=1),
nn.LeakyReLU(0.2)
]
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
stride = 1 if i == num_layers - 1 else 2
layers += [
nn.Conv2d(
in_dim, out_dim, 4, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(out_dim),
nn.LeakyReLU(0.2)
]
layers += [nn.Conv2d(out_dim, 1, 4, stride=1, padding=1)]
self.layers = nn.Sequential(*layers)

# initialize weights
self.apply(self.init_weights)

def forward(self, x):
return self.layers(x)

def init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight, 0.0, 0.02)
elif isinstance(m, nn.BatchNorm2d):
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias)
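
A minimal round-trip sketch through VQAutoencoder with random weights; the shapes below follow from dim_mult=[1, 2, 4] (two 2x downsamples) and are illustrative only.

import torch

ae = VQAutoencoder(dim=128, z_dim=3, dim_mult=[1, 2, 4], codebook_size=8192).eval()
imgs = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    recon, tokens, vq_loss, perplexity = ae(imgs)  # tokens: [1, 16, 16]
    token_ids = ae.encode_to_tokens(imgs)          # [1, 256]
    rebuilt = ae.decode_from_tokens(token_ids)     # [1, 3, 64, 64]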

+ 418
- 0
modelscope/models/cv/image_to_image_translation/models/clip.py View File

@@ -0,0 +1,418 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import modelscope.models.cv.image_to_image_translation.ops as ops # for using differentiable all_gather

__all__ = [
'CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14',
'clip_vit_l_14_336px', 'clip_vit_h_16'
]


def to_fp16(m):
if isinstance(m, (nn.Linear, nn.Conv2d)):
m.weight.data = m.weight.data.half()
if m.bias is not None:
m.bias.data = m.bias.data.half()
elif hasattr(m, 'head'):
p = getattr(m, 'head')
p.data = p.data.half()


class QuickGELU(nn.Module):

def forward(self, x):
return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):
r"""Subclass of nn.LayerNorm to handle fp16.
"""

def forward(self, x):
return super(LayerNorm, self).forward(x.float()).type_as(x)


class SelfAttention(nn.Module):

def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
assert dim % num_heads == 0
super(SelfAttention, self).__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = 1.0 / math.sqrt(self.head_dim)

# layers
self.to_qkv = nn.Linear(dim, dim * 3)
self.attn_dropout = nn.Dropout(attn_dropout)
self.proj = nn.Linear(dim, dim)
self.proj_dropout = nn.Dropout(proj_dropout)

def forward(self, x, mask=None):
r"""x: [B, L, C].
mask: [*, L, L].
"""
b, l, _, n = *x.size(), self.num_heads

# compute query, key, and value
q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1)
q = q.reshape(l, b * n, -1).transpose(0, 1)
k = k.reshape(l, b * n, -1).transpose(0, 1)
v = v.reshape(l, b * n, -1).transpose(0, 1)

# compute attention
attn = self.scale * torch.bmm(q, k.transpose(1, 2))
if mask is not None:
attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf'))
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
attn = self.attn_dropout(attn)

# gather context
x = torch.bmm(attn, v)
x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1)

# output
x = self.proj(x)
x = self.proj_dropout(x)
return x


class AttentionBlock(nn.Module):

def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
super(AttentionBlock, self).__init__()
self.dim = dim
self.num_heads = num_heads

# layers
self.norm1 = LayerNorm(dim)
self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout)
self.norm2 = LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim),
nn.Dropout(proj_dropout))

def forward(self, x, mask=None):
x = x + self.attn(self.norm1(x), mask)
x = x + self.mlp(self.norm2(x))
return x


class VisionTransformer(nn.Module):

def __init__(self,
image_size=224,
patch_size=16,
dim=768,
out_dim=512,
num_heads=12,
num_layers=12,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0):
assert image_size % patch_size == 0
super(VisionTransformer, self).__init__()
self.image_size = image_size
self.patch_size = patch_size
self.dim = dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.num_patches = (image_size // patch_size)**2

# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(
3, dim, kernel_size=patch_size, stride=patch_size, bias=False)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(
gain * torch.randn(1, self.num_patches + 1, dim))
self.dropout = nn.Dropout(embedding_dropout)

# transformer
self.pre_norm = LayerNorm(dim)
self.transformer = nn.Sequential(*[
AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
for _ in range(num_layers)
])
self.post_norm = LayerNorm(dim)

# head
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))

def forward(self, x):
b, dtype = x.size(0), self.head.dtype
x = x.type(dtype)

# patch-embedding
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c]
x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x],
dim=1)
x = self.dropout(x + self.pos_embedding.type(dtype))
x = self.pre_norm(x)

# transformer
x = self.transformer(x)

# head
x = self.post_norm(x)
x = torch.mm(x[:, 0, :], self.head)
return x

def fp16(self):
return self.apply(to_fp16)


class TextTransformer(nn.Module):

def __init__(self,
vocab_size,
text_len,
dim=512,
out_dim=512,
num_heads=8,
num_layers=12,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0):
super(TextTransformer, self).__init__()
self.vocab_size = vocab_size
self.text_len = text_len
self.dim = dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers

# embeddings
self.token_embedding = nn.Embedding(vocab_size, dim)
self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim))
self.dropout = nn.Dropout(embedding_dropout)

# transformer
self.transformer = nn.ModuleList([
AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
for _ in range(num_layers)
])
self.norm = LayerNorm(dim)

# head
gain = 1.0 / math.sqrt(dim)
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))

# causal attention mask
self.register_buffer('attn_mask',
torch.tril(torch.ones(1, text_len, text_len)))

def forward(self, x):
eot, dtype = x.argmax(dim=-1), self.head.dtype

# embeddings
x = self.dropout(
self.token_embedding(x).type(dtype)
+ self.pos_embedding.type(dtype))

# transformer
for block in self.transformer:
x = block(x, self.attn_mask)

# head
x = self.norm(x)
x = torch.mm(x[torch.arange(x.size(0)), eot], self.head)
return x

def fp16(self):
return self.apply(to_fp16)


class CLIP(nn.Module):

def __init__(self,
embed_dim=512,
image_size=224,
patch_size=16,
vision_dim=768,
vision_heads=12,
vision_layers=12,
vocab_size=49408,
text_len=77,
text_dim=512,
text_heads=8,
text_layers=12,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0):
super(CLIP, self).__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.vision_dim = vision_dim
self.vision_heads = vision_heads
self.vision_layers = vision_layers
self.vocab_size = vocab_size
self.text_len = text_len
self.text_dim = text_dim
self.text_heads = text_heads
self.text_layers = text_layers

# models
self.visual = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
out_dim=embed_dim,
num_heads=vision_heads,
num_layers=vision_layers,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout)
self.textual = TextTransformer(
vocab_size=vocab_size,
text_len=text_len,
dim=text_dim,
out_dim=embed_dim,
num_heads=text_heads,
num_layers=text_layers,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout)
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

def forward(self, imgs, txt_tokens):
r"""imgs: [B, C, H, W] of torch.float32.
txt_tokens: [B, T] of torch.long.
"""
xi = self.visual(imgs)
xt = self.textual(txt_tokens)

# normalize features
xi = F.normalize(xi, p=2, dim=1)
xt = F.normalize(xt, p=2, dim=1)

# gather features from all ranks
full_xi = ops.diff_all_gather(xi)
full_xt = ops.diff_all_gather(xt)

# logits
scale = self.log_scale.exp()
logits_i2t = scale * torch.mm(xi, full_xt.t())
logits_t2i = scale * torch.mm(xt, full_xi.t())

# labels
labels = torch.arange(
len(xi) * ops.get_rank(),
len(xi) * (ops.get_rank() + 1),
dtype=torch.long,
device=xi.device)
return logits_i2t, logits_t2i, labels

def init_weights(self):
# embeddings
nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)

# attentions
for modality in ['visual', 'textual']:
dim = self.vision_dim if modality == 'visual' else self.text_dim
transformer = getattr(self, modality).transformer
proj_gain = (1.0 / math.sqrt(dim)) * (
1.0 / math.sqrt(2 * len(transformer)))
attn_gain = 1.0 / math.sqrt(dim)
mlp_gain = 1.0 / math.sqrt(2.0 * dim)
for block in transformer:
nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
nn.init.normal_(block.attn.proj.weight, std=proj_gain)
nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
nn.init.normal_(block.mlp[2].weight, std=proj_gain)

def param_groups(self):
groups = [{
'params': [
p for n, p in self.named_parameters()
if 'norm' in n or n.endswith('bias')
],
'weight_decay':
0.0
}, {
'params': [
p for n, p in self.named_parameters()
if not ('norm' in n or n.endswith('bias'))
]
}]
return groups

def fp16(self):
return self.apply(to_fp16)


def clip_vit_b_32(**kwargs):
return CLIP(
embed_dim=512,
image_size=224,
patch_size=32,
vision_dim=768,
vision_heads=12,
vision_layers=12,
text_dim=512,
text_heads=8,
text_layers=12,
**kwargs)


def clip_vit_b_16(**kwargs):
return CLIP(
embed_dim=512,
image_size=224,
patch_size=16,
vision_dim=768,
vision_heads=12,
vision_layers=12,
text_dim=512,
text_heads=8,
text_layers=12,
**kwargs)


def clip_vit_l_14(**kwargs):
return CLIP(
embed_dim=768,
image_size=224,
patch_size=14,
vision_dim=1024,
vision_heads=16,
vision_layers=24,
text_dim=768,
text_heads=12,
text_layers=12,
**kwargs)


def clip_vit_l_14_336px(**kwargs):
return CLIP(
embed_dim=768,
image_size=336,
patch_size=14,
vision_dim=1024,
vision_heads=16,
vision_layers=24,
text_dim=768,
text_heads=12,
text_layers=12,
**kwargs)


def clip_vit_h_16(**kwargs):
return CLIP(
embed_dim=1024,
image_size=256,
patch_size=16,
vision_dim=1280,
vision_heads=16,
vision_layers=32,
text_dim=1024,
text_heads=16,
text_layers=24,
**kwargs)
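
A minimal embedding sketch using one of the presets above. CLIP.forward assumes a distributed setup (it calls ops.diff_all_gather), so this sketch calls the visual and textual towers directly; the inputs are random stand-ins for preprocessed images and tokenized text.

import torch
import torch.nn.functional as F

clip = clip_vit_b_32().eval()
imgs = torch.randn(2, 3, 224, 224)             # preprocessed image batch
txt_tokens = torch.randint(0, 49408, (2, 77))  # token ids, sequence length 77
with torch.no_grad():
    xi = F.normalize(clip.visual(imgs), p=2, dim=1)         # [2, 512]
    xt = F.normalize(clip.textual(txt_tokens), p=2, dim=1)  # [2, 512]
    sims = xi @ xt.t()                                      # cosine similarity matrix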

+ 8
- 0
modelscope/models/cv/image_to_image_translation/ops/__init__.py View File

@@ -0,0 +1,8 @@
from .degradation import * # noqa F403
from .diffusion import * # noqa F403
from .losses import * # noqa F403
from .metrics import * # noqa F403
from .random_color import * # noqa F403
from .random_mask import * # noqa F403
from .svd import * # noqa F403
from .utils import * # noqa F403

+ 663
- 0
modelscope/models/cv/image_to_image_translation/ops/apps.py View File

@@ -0,0 +1,663 @@
# APPs that facilitate the use of pretrained neural networks.

import os.path as osp

import artist.data as data
import artist.models as models
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn.functional as F
import torchvision.transforms as T
from artist import DOWNLOAD_TO_CACHE
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from .utils import parallel, read_image

__all__ = [
'FeatureExtractor', 'Classifier', 'Text2Image', 'Sole2Shoe', 'ImageParser',
'TextImageMatch', 'taobao_feature_extractor', 'singleton_classifier',
'orientation_classifier', 'fashion_text2image', 'mindalle_text2image',
'sole2shoe', 'sole_parser', 'sod_foreground_parser',
'fashion_text_image_match'
]


class ImageFolder(Dataset):

def __init__(self, paths, transforms=None):
self.paths = paths
self.transforms = transforms

def __getitem__(self, index):
img = read_image(self.paths[index])
if img.mode != 'RGB':
img = img.convert('RGB')
if self.transforms is not None:
img = self.transforms(img)
return img

def __len__(self):
return len(self.paths)


class FeatureExtractor(object):

def __init__(
self,
model='InceptionV1',
checkpoint='models/inception-v1/1218shoes.v9_7.140.0.1520000',
resolution=224,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
batch_size=64,
device=torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
self.resolution = resolution
self.batch_size = batch_size
self.device = device

# init model
self.net = getattr(
models,
model)(num_classes=None).eval().requires_grad_(False).to(device)
self.net.load_state_dict(
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device))

# data transforms
self.transforms = T.Compose([
data.PadToSquare(),
T.Resize(resolution),
T.ToTensor(),
T.Normalize(mean, std)
])

def __call__(self, imgs, num_workers=0):
r"""imgs: Either a PIL.Image or a list of PIL.Image instances.
"""
# preprocess
if isinstance(imgs, Image.Image):
imgs = [imgs]
assert isinstance(imgs,
(tuple, list)) and isinstance(imgs[0], Image.Image)
imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0)

# forward
feats = []
for batch in imgs.split(self.batch_size, dim=0):
batch = batch.to(self.device, non_blocking=True)
feats.append(self.net(batch))
return torch.cat(feats, dim=0)

def batch_process(self, paths):
# init dataloader
dataloader = DataLoader(
dataset=ImageFolder(paths, self.transforms),
batch_size=self.batch_size,
shuffle=False,
drop_last=False,
pin_memory=True,
num_workers=8,
prefetch_factor=2)

# forward
feats = []
for step, batch in enumerate(dataloader, 1):
print(f'Step: {step}/{len(dataloader)}', flush=True)
batch = batch.to(self.device, non_blocking=True)
feats.append(self.net(batch))
return torch.cat(feats)


class Classifier(object):

def __init__(
self,
model='InceptionV1',
checkpoint='models/classifier/shoes+apparel+bag-sgdetect-211230.pth',
num_classes=1,
resolution=224,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
batch_size=64,
device=torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
self.num_classes = num_classes
self.resolution = resolution
self.batch_size = batch_size
self.device = device

# init model
self.net = getattr(models, model)(
num_classes=num_classes).eval().requires_grad_(False).to(device)
self.net.load_state_dict(
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device))

# data transforms
self.transforms = T.Compose([
data.PadToSquare(),
T.Resize(resolution),
T.ToTensor(),
T.Normalize(mean, std)
])

def __call__(self, imgs, num_workers=0):
r"""imgs: Either a PIL.Image or a list of PIL.Image instances.
"""
# preprocess
if isinstance(imgs, Image.Image):
imgs = [imgs]
assert isinstance(imgs,
(tuple, list)) and isinstance(imgs[0], Image.Image)
imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0)

# forward
scores = []
for batch in imgs.split(self.batch_size, dim=0):
batch = batch.to(self.device, non_blocking=True)
logits = self.net(batch)
scores.append(logits.sigmoid() if self.num_classes == # noqa W504
1 else logits.softmax(dim=1))
return torch.cat(scores, dim=0)


class Text2Image(object):

def __init__(
self,
vqgan_dim=128,
vqgan_z_dim=256,
vqgan_dim_mult=[1, 1, 2, 2, 4],
vqgan_num_res_blocks=2,
vqgan_attn_scales=[1.0 / 16],
vqgan_codebook_size=975,
vqgan_beta=0.25,
gpt_txt_vocab_size=21128,
gpt_txt_seq_len=64,
gpt_img_seq_len=1024,
gpt_dim=1024,
gpt_num_heads=16,
gpt_num_layers=24,
vqgan_checkpoint='models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth',
gpt_checkpoint='models/seq2seq_gpt/text2image_shoes+apparels_step400k.pth',
tokenizer=data.BertTokenizer(name='bert-base-chinese', length=64),
batch_size=16,
device=torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
self.tokenizer = tokenizer
self.batch_size = batch_size
self.device = device

# init VQGAN model
self.vqgan = models.VQGAN(
dim=vqgan_dim,
z_dim=vqgan_z_dim,
dim_mult=vqgan_dim_mult,
num_res_blocks=vqgan_num_res_blocks,
attn_scales=vqgan_attn_scales,
codebook_size=vqgan_codebook_size,
beta=vqgan_beta).eval().requires_grad_(False).to(device)
self.vqgan.load_state_dict(
torch.load(
DOWNLOAD_TO_CACHE(vqgan_checkpoint), map_location=device))

# init GPT model
self.gpt = models.Seq2SeqGPT(
src_vocab_size=gpt_txt_vocab_size,
tar_vocab_size=vqgan_codebook_size,
src_seq_len=gpt_txt_seq_len,
tar_seq_len=gpt_img_seq_len,
dim=gpt_dim,
num_heads=gpt_num_heads,
num_layers=gpt_num_layers).eval().requires_grad_(False).to(device)
self.gpt.load_state_dict(
torch.load(DOWNLOAD_TO_CACHE(gpt_checkpoint), map_location=device))

def __call__(self,
txts,
top_k=64,
top_p=None,
temperature=0.6,
use_fp16=True):
# preprocess
if isinstance(txts, str):
txts = [txts]
assert isinstance(txts, (tuple, list)) and isinstance(txts[0], str)
txt_tokens = torch.LongTensor([self.tokenizer(u) for u in txts])

# forward
out_imgs = []
for batch in txt_tokens.split(self.batch_size, dim=0):
# sample
batch = batch.to(self.device, non_blocking=True)
with amp.autocast(enabled=use_fp16):
img_tokens = self.gpt.sample(batch, top_k, top_p, temperature)

# decode
imgs = self.vqgan.decode_from_tokens(img_tokens)
imgs = self._whiten_borders(imgs)
imgs = imgs.clamp_(-1, 1).add_(1).mul_(125.0).permute(
0, 2, 3, 1).cpu().numpy().astype(np.uint8)
imgs = [Image.fromarray(u) for u in imgs]

# append
out_imgs += imgs
return out_imgs

def _whiten_borders(self, imgs):
r"""Remove border artifacts.
"""
imgs[:, :, :18, :] = 1
imgs[:, :, :, :18] = 1
imgs[:, :, -18:, :] = 1
imgs[:, :, :, -18:] = 1
return imgs


class Sole2Shoe(object):

def __init__(
self,
vqgan_dim=128,
vqgan_z_dim=256,
vqgan_dim_mult=[1, 1, 2, 2, 4],
vqgan_num_res_blocks=2,
vqgan_attn_scales=[1.0 / 16],
vqgan_codebook_size=975,
vqgan_beta=0.25,
src_resolution=256,
tar_resolution=512,
gpt_dim=1024,
gpt_num_heads=16,
gpt_num_layers=24,
vqgan_checkpoint='models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth',
gpt_checkpoint='models/seq2seq_gpt/sole2shoe-step300k-220104.pth',
batch_size=12,
device=torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
self.batch_size = batch_size
self.device = device
src_seq_len = (src_resolution // 16)**2
tar_seq_len = (tar_resolution // 16)**2

# init VQGAN model
self.vqgan = models.VQGAN(
dim=vqgan_dim,
z_dim=vqgan_z_dim,
dim_mult=vqgan_dim_mult,
num_res_blocks=vqgan_num_res_blocks,
attn_scales=vqgan_attn_scales,
codebook_size=vqgan_codebook_size,
beta=vqgan_beta).eval().requires_grad_(False).to(device)
self.vqgan.load_state_dict(
torch.load(
DOWNLOAD_TO_CACHE(vqgan_checkpoint), map_location=device))

# init GPT model
self.gpt = models.Seq2SeqGPT(
src_vocab_size=vqgan_codebook_size,
tar_vocab_size=vqgan_codebook_size,
src_seq_len=src_seq_len,
tar_seq_len=tar_seq_len,
dim=gpt_dim,
num_heads=gpt_num_heads,
num_layers=gpt_num_layers).eval().requires_grad_(False).to(device)
self.gpt.load_state_dict(
torch.load(DOWNLOAD_TO_CACHE(gpt_checkpoint), map_location=device))

# data transforms
self.transforms = T.Compose([
data.PadToSquare(),
T.Resize(src_resolution),
T.ToTensor(),
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def __call__(self,
sole_imgs,
top_k=64,
top_p=None,
temperature=0.6,
use_fp16=True,
num_workers=0):
# preprocess
if isinstance(sole_imgs, Image.Image):
sole_imgs = [sole_imgs]
assert isinstance(sole_imgs, (tuple, list)) and isinstance(
sole_imgs[0], Image.Image)
sole_imgs = torch.stack(
parallel(self.transforms, sole_imgs, num_workers), dim=0)

# forward
out_imgs = []
for batch in sole_imgs.split(self.batch_size, dim=0):
# sample
batch = batch.to(self.device)
with amp.autocast(enabled=use_fp16):
sole_tokens = self.vqgan.encode_to_tokens(batch)
shoe_tokens = self.gpt.sample(sole_tokens, top_k, top_p,
temperature)

# decode
shoe_imgs = self.vqgan.decode_from_tokens(shoe_tokens)
shoe_imgs = self._whiten_borders(shoe_imgs)
shoe_imgs = shoe_imgs.clamp_(-1, 1).add_(1).mul_(125.0).permute(
0, 2, 3, 1).cpu().numpy().astype(np.uint8)
shoe_imgs = [Image.fromarray(u) for u in shoe_imgs]

# append
out_imgs += shoe_imgs
return out_imgs

def _whiten_borders(self, imgs):
r"""Remove border artifacts.
"""
imgs[:, :, :18, :] = 1
imgs[:, :, :, :18] = 1
imgs[:, :, -18:, :] = 1
imgs[:, :, :, -18:] = 1
return imgs


class ImageParser(object):

def __init__(
self,
model='SPNet',
num_classes=2,
resolution=800,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
model_with_softmax=False,
checkpoint='models/spnet/sole_segmentation_211219.pth',
batch_size=16,
device=torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
self.batch_size = batch_size
self.device = device

# init model
if checkpoint.endswith('.pt'):
self.net = torch.jit.load(
DOWNLOAD_TO_CACHE(checkpoint)).eval().to(device)
[p.requires_grad_(False) for p in self.net.parameters()]
else:
self.net = getattr(models, model)(
num_classes=num_classes,
pretrained=False).eval().requires_grad_(False).to(device)
self.net.load_state_dict(
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device))
self.softmax = (lambda x, dim: x) if model_with_softmax else F.softmax

# data transforms
self.transforms = T.Compose([
data.PadToSquare(),
T.Resize(resolution),
T.ToTensor(),
T.Normalize(mean, std)
])

def __call__(self, imgs, num_workers=0):
# preprocess
if isinstance(imgs, Image.Image):
imgs = [imgs]
assert isinstance(imgs,
(tuple, list)) and isinstance(imgs[0], Image.Image)
sizes = [u.size for u in imgs]
imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0)

# forward
masks = []
for batch in imgs.split(self.batch_size, dim=0):
batch = batch.to(self.device, non_blocking=True)
masks.append(self.softmax(self.net(batch), dim=1))

# postprocess
masks = torch.cat(masks, dim=0).unsqueeze(1)
masks = [
F.interpolate(u, v, mode='bilinear', align_corners=False)
for u, v in zip(masks, sizes)
]
return masks


class TextImageMatch(object):

def __init__(
self,
embed_dim=512,
image_size=224,
patch_size=32,
vision_dim=768,
vision_heads=12,
vision_layers=12,
vocab_size=21128,
text_len=77,
text_dim=512,
text_heads=8,
text_layers=12,
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
checkpoint='models/clip/clip_shoes+apparels_step84k_210105.pth',
tokenizer=data.BertTokenizer(name='bert-base-chinese', length=77),
batch_size=64,
device=torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
self.tokenizer = tokenizer
self.batch_size = batch_size
self.device = device

# init model
self.clip = models.CLIP(
embed_dim=embed_dim,
image_size=image_size,
patch_size=patch_size,
vision_dim=vision_dim,
vision_heads=vision_heads,
vision_layers=vision_layers,
vocab_size=vocab_size,
text_len=text_len,
text_dim=text_dim,
text_heads=text_heads,
text_layers=text_layers).eval().requires_grad_(False).to(device)
self.clip.load_state_dict(
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device))

# transforms
scale_size = int(image_size * 8 / 7)
self.transforms = T.Compose([
data.PadToSquare(),
T.Resize(scale_size),
T.CenterCrop(image_size),
T.ToTensor(),
T.Normalize(mean, std)
])

def __call__(self, imgs, txts, num_workers=0):
# preprocess
assert isinstance(imgs,
(tuple, list)) and isinstance(imgs[0], Image.Image)
assert isinstance(txts, (tuple, list)) and isinstance(txts[0], str)
txt_tokens = torch.LongTensor([self.tokenizer(u) for u in txts])
imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0)

# forward
scores = []
for img_batch, txt_batch in zip(
imgs.split(self.batch_size, dim=0),
txt_tokens.split(self.batch_size, dim=0)):
img_batch = img_batch.to(self.device)
txt_batch = txt_batch.to(self.device)
xi = F.normalize(self.clip.visual(img_batch), p=2, dim=1)
xt = F.normalize(self.clip.textual(txt_batch), p=2, dim=1)
scores.append((xi * xt).sum(dim=1))
return torch.cat(scores, dim=0)


def taobao_feature_extractor(category='shoes', **kwargs):
r"""Pretrained taobao-search feature extractors.
"""
assert category in ['softall', 'shoes', 'bag']
checkpoint = osp.join(
'models/inception-v1', {
'softall': '1214softall_10.10.0.5000',
'shoes': '1218shoes.v9_7.140.0.1520000',
'bag': '0926bag.v9_6.29.0.140000'
}[category])
app = FeatureExtractor(
model='InceptionV1',
checkpoint=checkpoint,
resolution=224,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
**kwargs)
return app


def singleton_classifier(**kwargs):
r"""Pretrained classifier that finds single-object images.
Supports shoes, apparel, and bag images.
"""
app = Classifier(
model='InceptionV1',
checkpoint='models/classifier/shoes+apparel+bag-sgdetect-211230.pth',
num_classes=1,
resolution=224,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
**kwargs)
return app


def orientation_classifier(**kwargs):
r"""Shoes orientation classifier.
"""
app = Classifier(
model='InceptionV1',
checkpoint='models/classifier/shoes-oriendetect-20211026.pth',
num_classes=1,
resolution=224,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
**kwargs)
return app


def fashion_text2image(**kwargs):
r"""Fashion text-to-image generator.
Supports shoe and apparel image generation.
"""
app = Text2Image(
vqgan_dim=128,
vqgan_z_dim=256,
vqgan_dim_mult=[1, 1, 2, 2, 4],
vqgan_num_res_blocks=2,
vqgan_attn_scales=[1.0 / 16],
vqgan_codebook_size=975,
vqgan_beta=0.25,
gpt_txt_vocab_size=21128,
gpt_txt_seq_len=64,
gpt_img_seq_len=1024,
gpt_dim=1024,
gpt_num_heads=16,
gpt_num_layers=24,
vqgan_checkpoint= # noqa E251
'models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth',
gpt_checkpoint= # noqa E251
'models/seq2seq_gpt/text2image_shoes+apparels_step400k.pth',
tokenizer=data.BertTokenizer(name='bert-base-chinese', length=64),
**kwargs)
return app


def mindalle_text2image(**kwargs):
r"""Pretrained text2image generator with weights copied from minDALL-E.
"""
app = Text2Image(
vqgan_dim=128,
vqgan_z_dim=256,
vqgan_dim_mult=[1, 1, 2, 2, 4],
vqgan_num_res_blocks=2,
vqgan_attn_scales=[1.0 / 16],
vqgan_codebook_size=16384,
vqgan_beta=0.25,
gpt_txt_vocab_size=16384,
gpt_txt_seq_len=64,
gpt_img_seq_len=256,
gpt_dim=1536,
gpt_num_heads=24,
gpt_num_layers=42,
vqgan_checkpoint='models/minDALLE/1.3B_vqgan.pth',
gpt_checkpoint='models/minDALLE/1.3B_gpt.pth',
tokenizer=data.BPETokenizer(length=64),
**kwargs)
return app


def sole2shoe(**kwargs):
app = Sole2Shoe(
vqgan_dim=128,
vqgan_z_dim=256,
vqgan_dim_mult=[1, 1, 2, 2, 4],
vqgan_num_res_blocks=2,
vqgan_attn_scales=[1.0 / 16],
vqgan_codebook_size=975,
vqgan_beta=0.25,
src_resolution=256,
tar_resolution=512,
gpt_dim=1024,
gpt_num_heads=16,
gpt_num_layers=24,
vqgan_checkpoint= # noqa E251
'models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth',
gpt_checkpoint='models/seq2seq_gpt/sole2shoe-step300k-220104.pth',
**kwargs)
return app


def sole_parser(**kwargs):
app = ImageParser(
model='SPNet',
num_classes=2,
resolution=800,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
model_with_softmax=False,
checkpoint='models/spnet/sole_segmentation_211219.pth',
**kwargs)
return app


def sod_foreground_parser(**kwargs):
app = ImageParser(
model=None,
num_classes=None,
resolution=448,
mean=[0.488431, 0.466275, 0.403686],
std=[0.222627, 0.21949, 0.22549],
model_with_softmax=True,
checkpoint='models/semseg/sod_model_20201228.pt',
**kwargs)
return app


def fashion_text_image_match(**kwargs):
app = TextImageMatch(
embed_dim=512,
image_size=224,
patch_size=32,
vision_dim=768,
vision_heads=12,
vision_layers=12,
vocab_size=21128,
text_len=77,
text_dim=512,
text_heads=8,
text_layers=12,
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
checkpoint='models/clip/clip_shoes+apparels_step84k_210105.pth',
tokenizer=data.BertTokenizer(name='bert-base-chinese', length=77),
**kwargs)
return app
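
The wrappers above assume the external artist package and access to the referenced checkpoints; under those assumptions, usage reduces to building an app and calling it. A hypothetical sketch (prompts and file names are placeholders):

from PIL import Image

t2i = fashion_text2image(batch_size=4)
generated = t2i(['a red running shoe'], top_k=64, temperature=0.6)

matcher = fashion_text_image_match()
scores = matcher([Image.open('shoe.jpg').convert('RGB')], ['red running shoe'])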

+ 1074
- 0
modelscope/models/cv/image_to_image_translation/ops/degradation.py
File diff suppressed because it is too large
View File


+ 598
- 0
modelscope/models/cv/image_to_image_translation/ops/diffusion.py View File

@@ -0,0 +1,598 @@
import math

import torch

from .losses import discretized_gaussian_log_likelihood, kl_divergence

__all__ = ['GaussianDiffusion', 'beta_schedule']


def _i(tensor, t, x):
r"""Index tensor using t and format the output according to x.
"""
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
return tensor[t].view(shape).to(x)


def beta_schedule(schedule,
num_timesteps=1000,
init_beta=None,
last_beta=None):
if schedule == 'linear':
scale = 1000.0 / num_timesteps
init_beta = init_beta or scale * 0.0001
last_beta = last_beta or scale * 0.02
return torch.linspace(
init_beta, last_beta, num_timesteps, dtype=torch.float64)
elif schedule == 'quadratic':
init_beta = init_beta or 0.0015
last_beta = last_beta or 0.0195
return torch.linspace(
init_beta**0.5, last_beta**0.5, num_timesteps,
dtype=torch.float64)**2
elif schedule == 'cosine':
betas = []
for step in range(num_timesteps):
t1 = step / num_timesteps
t2 = (step + 1) / num_timesteps

# fn = lambda u: math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
def fn(u):
return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2

betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
return torch.tensor(betas, dtype=torch.float64)
else:
raise ValueError(f'Unsupported schedule: {schedule}')
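
# Usage sketch (illustrative, not part of the commit): the schedules above pair
# with GaussianDiffusion defined below, e.g.
#   betas = beta_schedule('cosine', num_timesteps=1000)
#   diffusion = GaussianDiffusion(betas, mean_type='eps', var_type='fixed_small')
#   xt = diffusion.q_sample(x0, t)  # noised version of clean images x0 at steps t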


class GaussianDiffusion(object):

def __init__(self,
betas,
mean_type='eps',
var_type='learned_range',
loss_type='mse',
rescale_timesteps=False):
# check input
if not isinstance(betas, torch.DoubleTensor):
betas = torch.tensor(betas, dtype=torch.float64)
assert min(betas) > 0 and max(betas) <= 1
assert mean_type in ['x0', 'x_{t-1}', 'eps']
assert var_type in [
'learned', 'learned_range', 'fixed_large', 'fixed_small'
]
assert loss_type in [
'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1'
]
self.betas = betas
self.num_timesteps = len(betas)
self.mean_type = mean_type
self.var_type = var_type
self.loss_type = loss_type
self.rescale_timesteps = rescale_timesteps

# alphas
alphas = 1 - self.betas
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
self.alphas_cumprod_prev = torch.cat(
[alphas.new_ones([1]), self.alphas_cumprod[:-1]])
self.alphas_cumprod_next = torch.cat(
[self.alphas_cumprod[1:],
alphas.new_zeros([1])])

# q(x_t | x_{t-1})
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
- self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = torch.log(1.0
- self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
- 1)

# q(x_{t-1} | x_t, x_0)
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
1.0 - self.alphas_cumprod)
self.posterior_log_variance_clipped = torch.log(
self.posterior_variance.clamp(1e-20))
self.posterior_mean_coef1 = betas * torch.sqrt(
self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.posterior_mean_coef2 = (
1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
1.0 - self.alphas_cumprod)

def q_sample(self, x0, t, noise=None):
r"""Sample from q(x_t | x_0).
"""
noise = torch.randn_like(x0) if noise is None else noise
return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i(
self.sqrt_one_minus_alphas_cumprod, t, x0) * noise

def q_mean_variance(self, x0, t):
r"""Distribution of q(x_t | x_0).
"""
mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
var = _i(1.0 - self.alphas_cumprod, t, x0)
log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
return mu, var, log_var

def q_posterior_mean_variance(self, x0, xt, t):
r"""Distribution of q(x_{t-1} | x_t, x_0).
"""
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
self.posterior_mean_coef2, t, xt) * xt
var = _i(self.posterior_variance, t, xt)
log_var = _i(self.posterior_log_variance_clipped, t, xt)
return mu, var, log_var

@torch.no_grad()
def p_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None):
r"""Sample from p(x_{t-1} | x_t).
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
"""
# predict distribution of p(x_{t-1} | x_t)
mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile,
guide_scale)

# random sample (with optional conditional function)
noise = torch.randn_like(xt)
# no noise when t == 0
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
if condition_fn is not None:
grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
mu = mu.float() + var * grad.float()
xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
return xt_1, x0

@torch.no_grad()
def p_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None):
r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
"""
# prepare input
b, c, h, w = noise.size()
xt = noise

# diffusion process
for step in torch.arange(self.num_timesteps).flip(0):
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn, guide_scale)
return xt

def p_mean_variance(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None):
r"""Distribution of p(x_{t-1} | x_t).
"""
# predict distribution
if guide_scale is None:
out = model(xt, self._scale_timesteps(t), **model_kwargs)
else:
# classifier-free guidance
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
assert self.mean_type == 'eps'
y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
out = torch.cat(
[
u_out[:, :3] + guide_scale * # noqa W504
(y_out[:, :3] - u_out[:, :3]),
y_out[:, 3:]
],
dim=1)

# compute variance
if self.var_type == 'learned':
out, log_var = out.chunk(2, dim=1)
var = torch.exp(log_var)
elif self.var_type == 'learned_range':
out, fraction = out.chunk(2, dim=1)
min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
max_log_var = _i(torch.log(self.betas), t, xt)
fraction = (fraction + 1) / 2.0
log_var = fraction * max_log_var + (1 - fraction) * min_log_var
var = torch.exp(log_var)
elif self.var_type == 'fixed_large':
var = _i(
torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
xt)
log_var = torch.log(var)
elif self.var_type == 'fixed_small':
var = _i(self.posterior_variance, t, xt)
log_var = _i(self.posterior_log_variance_clipped, t, xt)

# compute mean and x0
if self.mean_type == 'x_{t-1}':
mu = out # x_{t-1}
x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i(
self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
xt) * xt
elif self.mean_type == 'x0':
x0 = out
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
elif self.mean_type == 'eps':
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
self.sqrt_recipm1_alphas_cumprod, t, xt) * out
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)

# restrict the range of x0
if percentile is not None:
assert percentile > 0 and percentile <= 1 # e.g., 0.995
s = torch.quantile(
x0.flatten(1).abs(), percentile,
dim=1).clamp_(1.0).view(-1, 1, 1, 1)
x0 = torch.min(s, torch.max(-s, x0)) / s
elif clamp is not None:
x0 = x0.clamp(-clamp, clamp)
return mu, var, log_var, x0

@torch.no_grad()
def ddim_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
ddim_timesteps=20,
eta=0.0):
stride = self.num_timesteps // ddim_timesteps

# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
percentile, guide_scale)
if condition_fn is not None:
# x0 -> eps
alpha = _i(self.alphas_cumprod, t, xt)
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
self.sqrt_recipm1_alphas_cumprod, t, xt)
eps = eps - (1 - alpha).sqrt() * condition_fn(
xt, self._scale_timesteps(t), **model_kwargs)

# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps

# derive variables
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
self.sqrt_recipm1_alphas_cumprod, t, xt)
alphas = _i(self.alphas_cumprod, t, xt)
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
sigmas = eta * torch.sqrt((1 - alphas_prev) / # noqa W504
(1 - alphas) * # noqa W504
(1 - alphas / alphas_prev))

# random sample
noise = torch.randn_like(xt)
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
return xt_1, x0

@torch.no_grad()
def ddim_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
ddim_timesteps=20,
eta=0.0):
# prepare input
b, c, h, w = noise.size()
xt = noise

# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
steps = (1 + torch.arange(0, self.num_timesteps,
self.num_timesteps // ddim_timesteps)).clamp(
0, self.num_timesteps - 1).flip(0)
for step in steps:
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn, guide_scale,
ddim_timesteps, eta)
return xt

@torch.no_grad()
def ddim_reverse_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None,
ddim_timesteps=20):
r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
"""
stride = self.num_timesteps // ddim_timesteps

# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
percentile, guide_scale)

# derive variables
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
self.sqrt_recipm1_alphas_cumprod, t, xt)
alphas_next = _i(
torch.cat(
[self.alphas_cumprod,
self.alphas_cumprod.new_zeros([1])]),
(t + stride).clamp(0, self.num_timesteps), xt)

# reverse sample
mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
return mu, x0

@torch.no_grad()
def ddim_reverse_sample_loop(self,
x0,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None,
ddim_timesteps=20):
# prepare input
b, c, h, w = x0.size()
xt = x0

# reconstruction steps
steps = torch.arange(0, self.num_timesteps,
self.num_timesteps // ddim_timesteps)
for step in steps:
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
percentile, guide_scale,
ddim_timesteps)
return xt

@torch.no_grad()
def plms_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
                    plms_timesteps=20,
                    eps_cache=None):
r"""Sample from p(x_{t-1} | x_t) using PLMS.
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
"""
        stride = self.num_timesteps // plms_timesteps
        eps_cache = [] if eps_cache is None else eps_cache

        # helper to compute eps
def compute_eps(xt, t):
# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile, guide_scale)

# condition
if condition_fn is not None:
# x0 -> eps
alpha = _i(self.alphas_cumprod, t, xt)
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
- x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
eps = eps - (1 - alpha).sqrt() * condition_fn(
xt, self._scale_timesteps(t), **model_kwargs)

# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps

# derive eps
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
self.sqrt_recipm1_alphas_cumprod, t, xt)
return eps

        # helper to compute x_0 and x_{t-1}
def compute_x0(eps, t):
# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps

# deterministic sample
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
direction = torch.sqrt(1 - alphas_prev) * eps
# mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
return xt_1, x0

# PLMS sample
eps = compute_eps(xt, t)
if len(eps_cache) == 0:
# 2nd order pseudo improved Euler
xt_1, x0 = compute_x0(eps, t)
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
eps_prime = (eps + eps_next) / 2.0
elif len(eps_cache) == 1:
# 2nd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
elif len(eps_cache) == 2:
            # 3rd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (23 * eps - 16 * eps_cache[-1]
+ 5 * eps_cache[-2]) / 12.0
elif len(eps_cache) >= 3:
            # 4th order pseudo linear multistep (Adams-Bashforth)
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
- 9 * eps_cache[-3]) / 24.0
xt_1, x0 = compute_x0(eps_prime, t)
return xt_1, x0, eps

@torch.no_grad()
def plms_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
plms_timesteps=20):
# prepare input
b, c, h, w = noise.size()
xt = noise

# diffusion process
steps = (1 + torch.arange(0, self.num_timesteps,
self.num_timesteps // plms_timesteps)).clamp(
0, self.num_timesteps - 1).flip(0)
eps_cache = []
for step in steps:
# PLMS sampling step
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn,
guide_scale, plms_timesteps,
eps_cache)

# update eps cache
eps_cache.append(eps)
if len(eps_cache) >= 4:
eps_cache.pop(0)
return xt

def loss(self, x0, t, model, model_kwargs={}, noise=None):
noise = torch.randn_like(x0) if noise is None else noise
xt = self.q_sample(x0, t, noise=noise)

# compute loss
if self.loss_type in ['kl', 'rescaled_kl']:
loss, _ = self.variational_lower_bound(x0, xt, t, model,
model_kwargs)
if self.loss_type == 'rescaled_kl':
loss = loss * self.num_timesteps
elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
out = model(xt, self._scale_timesteps(t), **model_kwargs)

            # VLB term for learning the variance
loss_vlb = 0.0
if self.var_type in ['learned', 'learned_range']:
out, var = out.chunk(2, dim=1)
frozen = torch.cat([
out.detach(), var
], dim=1) # learn var without affecting the prediction of mean
loss_vlb, _ = self.variational_lower_bound(
x0, xt, t, model=lambda *args, **kwargs: frozen)
if self.loss_type.startswith('rescaled_'):
loss_vlb = loss_vlb * self.num_timesteps / 1000.0

# MSE/L1 for x0/eps
target = {
'eps': noise,
'x0': x0,
'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
}[self.mean_type]
loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2
).abs().flatten(1).mean(dim=1)

# total loss
loss = loss + loss_vlb
return loss

def variational_lower_bound(self,
x0,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None):
# compute groundtruth and predicted distributions
mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile)

# compute KL loss
kl = kl_divergence(mu1, log_var1, mu2, log_var2)
kl = kl.flatten(1).mean(dim=1) / math.log(2.0)

# compute discretized NLL loss (for p(x0 | x1) only)
nll = -discretized_gaussian_log_likelihood(
x0, mean=mu2, log_scale=0.5 * log_var2)
nll = nll.flatten(1).mean(dim=1) / math.log(2.0)

# NLL for p(x0 | x1) and KL otherwise
vlb = torch.where(t == 0, nll, kl)
return vlb, x0

@torch.no_grad()
def variational_lower_bound_loop(self,
x0,
model,
model_kwargs={},
clamp=None,
percentile=None):
r"""Compute the entire variational lower bound, measured in bits-per-dim.
"""
# prepare input and output
b, c, h, w = x0.size()
metrics = {'vlb': [], 'mse': [], 'x0_mse': []}

# loop
for step in torch.arange(self.num_timesteps).flip(0):
# compute VLB
t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
noise = torch.randn_like(x0)
xt = self.q_sample(x0, t, noise)
vlb, pred_x0 = self.variational_lower_bound(
x0, xt, t, model, model_kwargs, clamp, percentile)

# predict eps from x0
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
self.sqrt_recipm1_alphas_cumprod, t, xt)

# collect metrics
metrics['vlb'].append(vlb)
metrics['x0_mse'].append(
(pred_x0 - x0).square().flatten(1).mean(dim=1))
metrics['mse'].append(
(eps - noise).square().flatten(1).mean(dim=1))
metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}

# compute the prior KL term for VLB, measured in bits-per-dim
mu, _, log_var = self.q_mean_variance(x0, t)
kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
torch.zeros_like(log_var))
kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)

# update metrics
metrics['prior_bits_per_dim'] = kl_prior
metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
return metrics

def _scale_timesteps(self, t):
if self.rescale_timesteps:
return t.float() * 1000.0 / self.num_timesteps
return t
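
A minimal usage sketch of the sampler above (not part of the diff): with a guide_scale, p_mean_variance expects model_kwargs as a [conditional, unconditional] pair, exactly as the pipeline later in this commit passes it. The toy model, the linear beta schedule and the constructor keywords are illustrative assumptions only.

import torch

import modelscope.models.cv.image_to_image_translation.ops as ops


class ToyEpsModel(torch.nn.Module):
    # hypothetical stand-in with the model(xt, t, **kwargs) interface assumed by GaussianDiffusion
    def forward(self, xt, t, y=None):
        return torch.zeros_like(xt)  # always predicts eps = 0


betas = torch.linspace(1e-4, 0.02, 1000)  # assumed linear schedule; the pipeline uses ops.beta_schedule
diffusion = ops.GaussianDiffusion(
    betas=betas,
    mean_type='eps',
    var_type='fixed_small',
    loss_type='mse',
    rescale_timesteps=False)
x0 = diffusion.ddim_sample_loop(
    noise=torch.randn(2, 3, 16, 16),
    model=ToyEpsModel(),
    model_kwargs=[{'y': torch.zeros(2, dtype=torch.long)},          # conditional kwargs
                  {'y': torch.full((2, ), 10, dtype=torch.long)}],  # "null"-class kwargs
    guide_scale=3.0,
    ddim_timesteps=10,
    eta=0.0)
print(x0.shape)  # torch.Size([2, 3, 16, 16])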

+ 35
- 0
modelscope/models/cv/image_to_image_translation/ops/losses.py View File

@@ -0,0 +1,35 @@
import math

import torch

__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood']


def kl_divergence(mu1, logvar1, mu2, logvar2):
return 0.5 * (
-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + # noqa W504
((mu1 - mu2)**2) * torch.exp(-logvar2))


def standard_normal_cdf(x):
r"""A fast approximation of the cumulative distribution function of the standard normal.
"""
return 0.5 * (1.0 + torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def discretized_gaussian_log_likelihood(x0, mean, log_scale):
assert x0.shape == mean.shape == log_scale.shape
cx = x0 - mean
inv_stdv = torch.exp(-log_scale)
cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x0 < -0.999, log_cdf_plus,
torch.where(x0 > 0.999, log_one_minus_cdf_min,
torch.log(cdf_delta.clamp(min=1e-12))))
assert log_probs.shape == x0.shape
return log_probs
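
A quick sanity check of the closed-form KL above (import path follows this diff; both helpers work elementwise on log-variances): two unit-variance Gaussians whose means differ by 1 give 0.5 nats per dimension.

import torch

from modelscope.models.cv.image_to_image_translation.ops.losses import (
    discretized_gaussian_log_likelihood, kl_divergence)

mu1, logvar1 = torch.zeros(4), torch.zeros(4)
mu2, logvar2 = torch.ones(4), torch.zeros(4)
kl = kl_divergence(mu1, logvar1, mu2, logvar2)
assert torch.allclose(kl, torch.full((4, ), 0.5))  # reduces to 0.5 * (mu1 - mu2) ** 2 here

# per-pixel log-probability of x0 under a narrow discretized Gaussian centred on it
x0 = torch.zeros(1, 3, 8, 8)
ll = discretized_gaussian_log_likelihood(
    x0, mean=x0, log_scale=torch.full_like(x0, -5.0))
print(ll.mean())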

+ 126
- 0
modelscope/models/cv/image_to_image_translation/ops/metrics.py View File

@@ -0,0 +1,126 @@
import numpy as np
import scipy.linalg as linalg
import torch
from scipy.stats import entropy

__all__ = [
'get_fid_net', 'get_is_net', 'compute_fid', 'compute_prdc', 'compute_is'
]


def get_fid_net(resize_input=True, normalize_input=True):
r"""InceptionV3 network for the evaluation of Fréchet Inception Distance (FID).

Args:
resize_input: whether or not to resize the input to (299, 299).
        normalize_input: whether or not to normalize the input from range (0, 1) to range (-1, 1).
"""
from artist.models import InceptionV3
return InceptionV3(
output_blocks=(3, ),
resize_input=resize_input,
normalize_input=normalize_input,
requires_grad=False,
use_fid_inception=True).eval().requires_grad_(False)


def get_is_net(resize_input=True, normalize_input=True):
r"""InceptionV3 network for the evaluation of Inception Score (IS).

Args:
resize_input: whether or not to resize the input to (299, 299).
        normalize_input: whether or not to normalize the input from range (0, 1) to range (-1, 1).
"""
from artist.models import InceptionV3
return InceptionV3(
output_blocks=(4, ),
resize_input=resize_input,
normalize_input=normalize_input,
requires_grad=False,
use_fid_inception=False).eval().requires_grad_(False)


@torch.no_grad()
def compute_fid(real_feats, fake_feats, eps=1e-6):
r"""Compute Fréchet Inception Distance (FID).

Args:
real_feats: [N, C].
fake_feats: [N, C].
"""
# check inputs
if isinstance(real_feats, torch.Tensor):
real_feats = real_feats.cpu().numpy().astype(np.float_)
if isinstance(fake_feats, torch.Tensor):
fake_feats = fake_feats.cpu().numpy().astype(np.float_)

# real statistics
mu1 = np.mean(real_feats, axis=0)
sigma1 = np.cov(real_feats, rowvar=False)

# fake statistics
mu2 = np.mean(fake_feats, axis=0)
sigma2 = np.cov(fake_feats, rowvar=False)

# compute covmean
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
print(
f'FID calculation produces singular product; adding {eps} to diagonal of cov',
flush=True)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

# numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError('Imaginary component {}'.format(m))
covmean = covmean.real

# compute Fréchet distance
diff = mu1 - mu2
fid = diff.dot(diff) + np.trace(sigma1) + np.trace(
sigma2) - 2 * np.trace(covmean)
return fid.item()


@torch.no_grad()
def compute_prdc(real_feats, fake_feats, knn=5):
r"""Compute precision, recall, density, and coverage given two manifolds.

Args:
real_feats: [N, C].
fake_feats: [N, C].
knn: the number of nearest neighbors to consider.
"""
# distances
real_kth = -(-torch.cdist(real_feats, real_feats)).topk(
k=knn, dim=1)[0][:, -1]
fake_kth = -(-torch.cdist(fake_feats, fake_feats)).topk(
k=knn, dim=1)[0][:, -1]
dists = torch.cdist(real_feats, fake_feats)

# metrics
precision = (dists < real_kth.unsqueeze(1)).any(
dim=0).float().mean().item()
recall = (dists < fake_kth.unsqueeze(0)).any(dim=1).float().mean().item()
density = (dists < real_kth.unsqueeze(1)).float().sum(
dim=0).mean().item() / knn
coverage = (dists.min(dim=1)[0] < real_kth).float().mean().item()
return precision, recall, density, coverage


@torch.no_grad()
def compute_is(logits, num_splits=10):
preds = logits.softmax(dim=1).cpu().numpy()
split_scores = []
for k in range(num_splits):
part = preds[k * (len(logits) // num_splits):(k + 1)
* (len(logits) // num_splits), :]
py = np.mean(part, axis=0)
scores = []
for i in range(part.shape[0]):
pyx = part[i, :]
scores.append(entropy(pyx, py))
split_scores.append(np.exp(np.mean(scores)))
return np.mean(split_scores), np.std(split_scores)
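
The distance-based metrics above can be smoke-tested on random features as below (import path per this diff); note that get_fid_net/get_is_net additionally require the external artist package.

import torch

from modelscope.models.cv.image_to_image_translation.ops.metrics import (
    compute_fid, compute_prdc)

real = torch.randn(256, 64)                 # stand-in Inception features, [N, C]
fake = real + 0.05 * torch.randn(256, 64)   # nearly identical feature set
print('fid:', compute_fid(real, fake))      # close to 0 for similar distributions
print('prdc:', compute_prdc(real, fake, knn=5))  # (precision, recall, density, coverage)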

+ 220
- 0
modelscope/models/cv/image_to_image_translation/ops/random_color.py View File

@@ -0,0 +1,220 @@
import colorsys
import random

__all__ = ['RandomColor', 'rand_color']

COLORMAP = {
'blue': {
'hue_range': [179, 257],
'lower_bounds': [[20, 100], [30, 86], [40, 80], [50, 74], [60, 60],
[70, 52], [80, 44], [90, 39], [100, 35]]
},
'green': {
'hue_range': [63, 178],
'lower_bounds': [[30, 100], [40, 90], [50, 85], [60, 81], [70, 74],
[80, 64], [90, 50], [100, 40]]
},
'monochrome': {
'hue_range': [0, 0],
'lower_bounds': [[0, 0], [100, 0]]
},
'orange': {
'hue_range': [19, 46],
'lower_bounds': [[20, 100], [30, 93], [40, 88], [50, 86], [60, 85],
[70, 70], [100, 70]]
},
'pink': {
'hue_range': [283, 334],
'lower_bounds': [[20, 100], [30, 90], [40, 86], [60, 84], [80, 80],
[90, 75], [100, 73]]
},
'purple': {
'hue_range': [258, 282],
'lower_bounds': [[20, 100], [30, 87], [40, 79], [50, 70], [60, 65],
[70, 59], [80, 52], [90, 45], [100, 42]]
},
'red': {
'hue_range': [-26, 18],
'lower_bounds': [[20, 100], [30, 92], [40, 89], [50, 85], [60, 78],
[70, 70], [80, 60], [90, 55], [100, 50]]
},
'yellow': {
'hue_range': [47, 62],
'lower_bounds': [[25, 100], [40, 94], [50, 89], [60, 86], [70, 84],
[80, 82], [90, 80], [100, 75]]
}
}


class RandomColor(object):

def __init__(self, seed=None):
self.colormap = COLORMAP
self.random = random.Random(seed)

for color_name, color_attrs in self.colormap.items():
lower_bounds = color_attrs['lower_bounds']
s_min = lower_bounds[0][0]
s_max = lower_bounds[len(lower_bounds) - 1][0]

b_min = lower_bounds[len(lower_bounds) - 1][1]
b_max = lower_bounds[0][1]

self.colormap[color_name]['saturation_range'] = [s_min, s_max]
self.colormap[color_name]['brightness_range'] = [b_min, b_max]

def generate(self, hue=None, luminosity=None, count=1, format_='hex'):
colors = []
for _ in range(count):
# First we pick a hue (H)
H = self.pick_hue(hue)

# Then use H to determine saturation (S)
S = self.pick_saturation(H, hue, luminosity)

# Then use S and H to determine brightness (B).
B = self.pick_brightness(H, S, luminosity)

# Then we return the HSB color in the desired format
colors.append(self.set_format([H, S, B], format_))

return colors

def pick_hue(self, hue):
hue_range = self.get_hue_range(hue)
hue = self.random_within(hue_range)

        # Instead of storing red as two separate ranges,
# we group them, using negative numbers
if (hue < 0):
hue += 360

return hue

def pick_saturation(self, hue, hue_name, luminosity):

if luminosity == 'random':
return self.random_within([0, 100])

if hue_name == 'monochrome':
return 0

saturation_range = self.get_saturation_range(hue)

s_min = saturation_range[0]
s_max = saturation_range[1]

if luminosity == 'bright':
s_min = 55
elif luminosity == 'dark':
s_min = s_max - 10
elif luminosity == 'light':
s_max = 55

return self.random_within([s_min, s_max])

def pick_brightness(self, H, S, luminosity):
b_min = self.get_minimum_brightness(H, S)
b_max = 100

if luminosity == 'dark':
b_max = b_min + 20
elif luminosity == 'light':
b_min = (b_max + b_min) / 2
elif luminosity == 'random':
b_min = 0
b_max = 100

return self.random_within([b_min, b_max])

def set_format(self, hsv, format_):
if 'hsv' in format_:
color = hsv
elif 'rgb' in format_:
color = self.hsv_to_rgb(hsv)
elif 'hex' in format_:
r, g, b = self.hsv_to_rgb(hsv)
return '#%02x%02x%02x' % (r, g, b)
else:
return 'unrecognized format'

if 'Array' in format_ or format_ == 'hex':
return color
else:
prefix = format_[:3]
color_values = [str(x) for x in color]
return '%s(%s)' % (prefix, ', '.join(color_values))

def get_minimum_brightness(self, H, S):
lower_bounds = self.get_color_info(H)['lower_bounds']

for i in range(len(lower_bounds) - 1):
s1 = lower_bounds[i][0]
v1 = lower_bounds[i][1]

s2 = lower_bounds[i + 1][0]
v2 = lower_bounds[i + 1][1]

if s1 <= S <= s2:
m = (v2 - v1) / (s2 - s1)
b = v1 - m * s1

return m * S + b

return 0

def get_hue_range(self, color_input):
if color_input and color_input.isdigit():
number = int(color_input)

if 0 < number < 360:
return [number, number]

elif color_input and color_input in self.colormap:
color = self.colormap[color_input]
if 'hue_range' in color:
return color['hue_range']

else:
return [0, 360]

def get_saturation_range(self, hue):
return self.get_color_info(hue)['saturation_range']

def get_color_info(self, hue):
# Maps red colors to make picking hue easier
if 334 <= hue <= 360:
hue -= 360

for color_name, color in self.colormap.items():
if color['hue_range'] and color['hue_range'][0] <= hue <= color[
'hue_range'][1]:
return self.colormap[color_name]

# this should probably raise an exception
return 'Color not found'

def random_within(self, r):
return self.random.randint(int(r[0]), int(r[1]))

@classmethod
def hsv_to_rgb(cls, hsv):
h, s, v = hsv
h = 1 if h == 0 else h
h = 359 if h == 360 else h

h = float(h) / 360
s = float(s) / 100
v = float(v) / 100

rgb = colorsys.hsv_to_rgb(h, s, v)
return [int(c * 255) for c in rgb]


def rand_color():
generator = RandomColor()
hue = random.choice(list(COLORMAP.keys()))
color = generator.generate(hue=hue, count=1, format_='rgb')[0]
color = color[color.find('(') + 1:color.find(')')]
color = tuple([int(u) for u in color.split(',')])
return color
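
A short usage sketch of the generator above (import path per this diff); the concrete values depend on the seed.

from modelscope.models.cv.image_to_image_translation.ops.random_color import (
    RandomColor, rand_color)

gen = RandomColor(seed=42)
print(gen.generate(hue='green', luminosity='bright', count=2, format_='hex'))  # two hex strings
print(gen.generate(hue='monochrome', count=1, format_='rgb'))  # a greyscale 'rgb(...)' string
print(rand_color())  # (r, g, b) tuple drawn from a random hue family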

+ 79
- 0
modelscope/models/cv/image_to_image_translation/ops/random_mask.py View File

@@ -0,0 +1,79 @@
import cv2
import numpy as np

__all__ = ['make_irregular_mask', 'make_rectangle_mask', 'make_uncrop']


def make_irregular_mask(w,
h,
max_angle=4,
max_length=200,
max_width=100,
min_strokes=1,
max_strokes=5,
mode='line'):
# initialize mask
assert mode in ['line', 'circle', 'square']
mask = np.zeros((h, w), np.float32)

# draw strokes
num_strokes = np.random.randint(min_strokes, max_strokes + 1)
for i in range(num_strokes):
x1 = np.random.randint(w)
y1 = np.random.randint(h)
for j in range(1 + np.random.randint(5)):
angle = 0.01 + np.random.randint(max_angle)
if i % 2 == 0:
angle = 2 * 3.1415926 - angle
length = 10 + np.random.randint(max_length)
radius = 5 + np.random.randint(max_width)
x2 = np.clip((x1 + length * np.sin(angle)).astype(np.int32), 0, w)
y2 = np.clip((y1 + length * np.cos(angle)).astype(np.int32), 0, h)
if mode == 'line':
cv2.line(mask, (x1, y1), (x2, y2), 1.0, radius)
elif mode == 'circle':
cv2.circle(
mask, (x1, y1), radius=radius, color=1.0, thickness=-1)
elif mode == 'square':
radius = radius // 2
mask[y1 - radius:y1 + radius, x1 - radius:x1 + radius] = 1
x1, y1 = x2, y2
return mask


def make_rectangle_mask(w,
h,
margin=10,
min_size=30,
max_size=150,
min_strokes=1,
max_strokes=4):
# initialize mask
mask = np.zeros((h, w), np.float32)

# draw rectangles
num_strokes = np.random.randint(min_strokes, max_strokes + 1)
for i in range(num_strokes):
box_w = np.random.randint(min_size, max_size)
box_h = np.random.randint(min_size, max_size)
x1 = np.random.randint(margin, w - margin - box_w + 1)
y1 = np.random.randint(margin, h - margin - box_h + 1)
mask[y1:y1 + box_h, x1:x1 + box_w] = 1
return mask


def make_uncrop(w, h):
# initialize mask
mask = np.zeros((h, w), np.float32)

# randomly halve the image
side = np.random.choice([0, 1, 2, 3])
if side == 0:
mask[:h // 2, :] = 1
elif side == 1:
mask[h // 2:, :] = 1
elif side == 2:
mask[:, :w // 2] = 1
    elif side == 3:
mask[:, w // 2:] = 1
return mask
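
A quick smoke test of the three mask generators above (import path per this diff); each returns a float32 array of shape (h, w) with 1s marking the masked region.

import numpy as np

from modelscope.models.cv.image_to_image_translation.ops.random_mask import (
    make_irregular_mask, make_rectangle_mask, make_uncrop)

w, h = 256, 256
for mask in (make_irregular_mask(w, h, mode='line'),
             make_rectangle_mask(w, h),
             make_uncrop(w, h)):
    assert mask.shape == (h, w) and mask.dtype == np.float32
    print('masked fraction:', mask.mean())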

+ 152
- 0
modelscope/models/cv/image_to_image_translation/ops/svd.py View File

@@ -0,0 +1,152 @@
r"""SVD of linear degradation matrices described in the paper
``Denoising Diffusion Restoration Models.''
@article{kawar2022denoising,
title={Denoising Diffusion Restoration Models},
author={Bahjat Kawar and Michael Elad and Stefano Ermon and Jiaming Song},
year={2022},
journal={arXiv preprint arXiv:2201.11793},
}
"""
import torch

__all__ = ['SVD', 'IdentitySVD', 'DenoiseSVD', 'ColorizationSVD']


class SVD(object):
r"""SVD decomposition of a matrix, i.e., H = UDV^T.
NOTE: assume that all inputs (i.e., h, x) are of shape [B, CHW].
"""

def __init__(self, h):
self.u, self.d, self.v = torch.svd(h, some=False)
self.ut = self.u.t()
self.vt = self.v.t()
self.d[self.d < 1e-3] = 0

def U(self, x):
return torch.matmul(self.u, x)

def Ut(self, x):
return torch.matmul(self.ut, x)

def V(self, x):
return torch.matmul(self.v, x)

def Vt(self, x):
return torch.matmul(self.vt, x)

@property
def D(self):
return self.d

def H(self, x):
return self.U(self.D * self.Vt(x)[:, :self.D.size(0)])

def Ht(self, x):
return self.V(self._pad(self.D * self.Ut(x)[:, :self.D.size(0)]))

def Hinv(self, x):
r"""Multiplies x by the pseudo inverse of H.
"""
x = self.Ut(x)
x[:, :self.D.size(0)] = x[:, :self.D.size(0)] / self.D
return self.V(self._pad(x))

def _pad(self, x):
o = x.new_zeros(x.size(0), self.v.size(0))
o[:, :self.u.size(0)] = x.view(x.size(0), -1)
return o

def to(self, *args, **kwargs):
r"""Update the data type and device of UDV matrices.
"""
for k, v in self.__dict__.items():
if isinstance(v, torch.Tensor):
setattr(self, k, v.to(*args, **kwargs))
return self


class IdentitySVD(SVD):

def __init__(self, c, h, w):
self.d = torch.ones(c * h * w)

def U(self, x):
return x.clone()

def Ut(self, x):
return x.clone()

def V(self, x):
return x.clone()

def Vt(self, x):
return x.clone()

def H(self, x):
return x.clone()

def Ht(self, x):
return x.clone()

def Hinv(self, x):
return x.clone()

def _pad(self, x):
return x.clone()


class DenoiseSVD(SVD):

def __init__(self, c, h, w):
self.num_entries = c * h * w
self.d = torch.ones(self.num_entries)

def U(self, x):
return x.clone()

def Ut(self, x):
return x.clone()

def V(self, x):
return x.clone()

def Vt(self, x):
return x.clone()

def _pad(self, x):
return x.clone()


class ColorizationSVD(SVD):

def __init__(self, c, h, w):
self.color_dim = c
self.num_pixels = h * w
self.u, self.d, self.v = torch.svd(torch.ones(1, c) / c, some=False)
self.vt = self.v.t()

def U(self, x):
return self.u[0, 0] * x

def Ut(self, x):
return self.u[0, 0] * x

def V(self, x):
return torch.einsum('ij,bjn->bin', self.v,
x.view(x.size(0), self.color_dim,
self.num_pixels)).flatten(1)

def Vt(self, x):
return torch.einsum('ij,bjn->bin', self.vt,
x.view(x.size(0), self.color_dim,
self.num_pixels)).flatten(1)

@property
def D(self):
return self.d.repeat(self.num_pixels)

def _pad(self, x):
o = x.new_zeros(x.size(0), self.color_dim * self.num_pixels)
o[:, :self.num_pixels] = x
return o
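
A small numerical check of the colourization operator above, assuming flattened [B, CHW] inputs as stated in the SVD docstring: H averages the colour channels per pixel, and H(Hinv(y)) reproduces y.

import torch

from modelscope.models.cv.image_to_image_translation.ops.svd import ColorizationSVD

c, h, w = 3, 4, 4
svd = ColorizationSVD(c, h, w)
x = torch.rand(2, c * h * w)   # two flattened RGB images
gray = svd.H(x)                # [B, h * w]: per-pixel channel average
assert torch.allclose(gray, x.view(2, c, h * w).mean(dim=1), atol=1e-5)
assert torch.allclose(svd.H(svd.Hinv(gray)), gray, atol=1e-5)  # H @ pinv(H) = I on its range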

+ 224
- 0
modelscope/models/cv/image_to_image_translation/ops/utils.py View File

@@ -0,0 +1,224 @@
import base64
import binascii
import hashlib
import math
import os
import os.path as osp
import zipfile
from io import BytesIO
from multiprocessing.pool import ThreadPool as Pool

import cv2
import json
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from .random_color import rand_color

__all__ = [
'ceil_divide', 'to_device', 'rand_name', 'ema', 'parallel', 'unzip',
'load_state_dict', 'inverse_indices', 'detect_duplicates', 'md5', 'rope',
'format_state', 'breakup_grid', 'viz_anno_geometry', 'image_to_base64'
]

TFS_CLIENT = None


def ceil_divide(a, b):
return int(math.ceil(a / b))


def to_device(batch, device, non_blocking=False):
if isinstance(batch, (list, tuple)):
return type(batch)([to_device(u, device, non_blocking) for u in batch])
elif isinstance(batch, dict):
return type(batch)([(k, to_device(v, device, non_blocking))
for k, v in batch.items()])
elif isinstance(batch, torch.Tensor):
return batch.to(device, non_blocking=non_blocking)
return batch


def rand_name(length=8, suffix=''):
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
if suffix:
if not suffix.startswith('.'):
suffix = '.' + suffix
name += suffix
return name


@torch.no_grad()
def ema(net_ema, net, beta, copy_buffer=False):
assert 0.0 <= beta <= 1.0
for p_ema, p in zip(net_ema.parameters(), net.parameters()):
p_ema.copy_(p.lerp(p_ema, beta))
if copy_buffer:
for b_ema, b in zip(net_ema.buffers(), net.buffers()):
b_ema.copy_(b)


def parallel(func, args_list, num_workers=32, timeout=None):
assert isinstance(args_list, list)
if not isinstance(args_list[0], tuple):
args_list = [(args, ) for args in args_list]
if num_workers == 0:
return [func(*args) for args in args_list]
with Pool(processes=num_workers) as pool:
results = [pool.apply_async(func, args) for args in args_list]
results = [res.get(timeout=timeout) for res in results]
return results


def unzip(filename, dst_dir=None):
if dst_dir is None:
dst_dir = osp.dirname(filename)
with zipfile.ZipFile(filename, 'r') as zip_ref:
zip_ref.extractall(dst_dir)


def load_state_dict(module, state_dict, drop_prefix=''):
# find incompatible key-vals
src, dst = state_dict, module.state_dict()
if drop_prefix:
src = type(src)([
(k[len(drop_prefix):] if k.startswith(drop_prefix) else k, v)
for k, v in src.items()
])
missing = [k for k in dst if k not in src]
unexpected = [k for k in src if k not in dst]
unmatched = [
k for k in src.keys() & dst.keys() if src[k].shape != dst[k].shape
]

# keep only compatible key-vals
incompatible = set(unexpected + unmatched)
src = type(src)([(k, v) for k, v in src.items() if k not in incompatible])
module.load_state_dict(src, strict=False)

# report incompatible key-vals
if len(missing) != 0:
print(' Missing: ' + ', '.join(missing), flush=True)
if len(unexpected) != 0:
print(' Unexpected: ' + ', '.join(unexpected), flush=True)
if len(unmatched) != 0:
print(' Shape unmatched: ' + ', '.join(unmatched), flush=True)


def inverse_indices(indices):
r"""Inverse map of indices.
E.g., if A[indices] == B, then B[inv_indices] == A.
"""
inv_indices = torch.empty_like(indices)
inv_indices[indices] = torch.arange(len(indices)).to(indices)
return inv_indices


def detect_duplicates(feats, thr=0.9):
assert feats.ndim == 2

# compute simmat
feats = F.normalize(feats, p=2, dim=1)
simmat = torch.mm(feats, feats.T)
simmat.triu_(1)
torch.cuda.synchronize()

# detect duplicates
mask = ~simmat.gt(thr).any(dim=0)
return torch.where(mask)[0]


def md5(filename):
with open(filename, 'rb') as f:
return hashlib.md5(f.read()).hexdigest()


def rope(x):
r"""Apply rotary position embedding on x of shape [B, *(spatial dimensions), C].
"""
# reshape
shape = x.shape
x = x.view(x.size(0), -1, x.size(-1))
l, c = x.shape[-2:]
assert c % 2 == 0
half = c // 2

# apply rotary position embedding on x
sinusoid = torch.outer(
torch.arange(l).to(x),
torch.pow(10000, -torch.arange(half).to(x).div(half)))
sin, cos = torch.sin(sinusoid), torch.cos(sinusoid)
x1, x2 = x.chunk(2, dim=-1)
x = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

# reshape back
return x.view(shape)


def format_state(state, filename=None):
r"""For comparing/aligning state_dict.
"""
content = '\n'.join([f'{k}\t{tuple(v.shape)}' for k, v in state.items()])
    if filename:
        with open(filename, 'w') as f:
            f.write(content)
    return content


def breakup_grid(img, grid_size):
r"""The inverse operator of ``torchvision.utils.make_grid``.
"""
# params
nrow = img.height // grid_size
ncol = img.width // grid_size
wrow = wcol = 2 # NOTE: use default values here

# collect grids
grids = []
for i in range(nrow):
for j in range(ncol):
x1 = j * grid_size + (j + 1) * wcol
y1 = i * grid_size + (i + 1) * wrow
grids.append(img.crop((x1, y1, x1 + grid_size, y1 + grid_size)))
return grids


def viz_anno_geometry(item):
r"""Visualize an annotation item from SmartLabel.
"""
if isinstance(item, str):
item = json.loads(item)
assert isinstance(item, dict)

# read image
orig_img = read_image(item['image_url'], retry=100)
img = cv2.cvtColor(np.asarray(orig_img), cv2.COLOR_BGR2RGB)

# loop over geometries
for geometry in item['sd_result']['items']:
# params
poly_img = img.copy()
color = rand_color()
points = np.array(geometry['meta']['geometry']).round().astype(int)
line_color = tuple([int(u * 0.55) for u in color])

# draw polygons
poly_img = cv2.fillPoly(poly_img, pts=[points], color=color)
poly_img = cv2.polylines(
poly_img,
pts=[points],
isClosed=True,
color=line_color,
thickness=2)

# mixing
img = np.clip(0.25 * img + 0.75 * poly_img, 0, 255).astype(np.uint8)
return orig_img, Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))


def image_to_base64(img, format='JPEG'):
buffer = BytesIO()
img.save(buffer, format=format)
code = base64.b64encode(buffer.getvalue()).decode('utf-8')
return code
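
Two of the helpers above in action (import path per this diff): rope keeps the input shape, and inverse_indices undoes an index permutation.

import torch

from modelscope.models.cv.image_to_image_translation.ops.utils import (
    ceil_divide, inverse_indices, rope)

x = torch.randn(2, 16, 64)       # [B, tokens, C] with an even channel dimension
assert rope(x).shape == x.shape  # rotary position embedding is shape-preserving

idx = torch.randperm(10)
a = torch.arange(10)
assert torch.equal(a[idx][inverse_indices(idx)], a)
assert ceil_divide(10, 3) == 4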

+ 1
- 0
modelscope/pipelines/cv/__init__.py View File

@@ -15,6 +15,7 @@ if TYPE_CHECKING:
from .image_color_enhance_pipeline import ImageColorEnhancePipeline
from .image_colorization_pipeline import ImageColorizationPipeline
from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
from .video_category_pipeline import VideoCategoryPipeline
from .image_matting_pipeline import ImageMattingPipeline
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline


+ 325
- 0
modelscope/pipelines/cv/image_to_image_translation_pipeline.py View File

@@ -0,0 +1,325 @@
import io
import os.path as osp
import sys
from typing import Any, Dict

import cv2
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.utils import save_image

import modelscope.models.cv.image_to_image_translation.data as data
import modelscope.models.cv.image_to_image_translation.models as models
import modelscope.models.cv.image_to_image_translation.ops as ops
from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.models.cv.image_to_image_translation.model_translation import \
UNet
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


def save_grid(imgs, filename, nrow=5):
save_image(
imgs.clamp(-1, 1), filename, range=(-1, 1), normalize=True, nrow=nrow)


@PIPELINES.register_module(
Tasks.image_generation, module_name=Pipelines.image2image_translation)
class Image2ImageTranslationPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
        use `model` to create an image-to-image translation pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
config_path = osp.join(self.model, ModelFile.CONFIGURATION)
logger.info(f'loading config from {config_path}')
self.cfg = Config.from_file(config_path)
if torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self.repetition = 4
# load autoencoder model
ae_model_path = osp.join(self.model, self.cfg.ModelPath.ae_model_path)
logger.info(f'loading autoencoder model from {ae_model_path}')
self.autoencoder = models.VQAutoencoder(
dim=self.cfg.Params.ae.ae_dim,
z_dim=self.cfg.Params.ae.ae_z_dim,
dim_mult=self.cfg.Params.ae.ae_dim_mult,
attn_scales=self.cfg.Params.ae.ae_attn_scales,
codebook_size=self.cfg.Params.ae.ae_codebook_size).eval(
).requires_grad_(False).to(self._device) # noqa E123
self.autoencoder.load_state_dict(
torch.load(ae_model_path, map_location=self._device))
logger.info('load autoencoder model done')

# load palette model
palette_model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
logger.info(f'loading palette model from {palette_model_path}')
self.palette = UNet(
resolution=self.cfg.Params.unet.unet_resolution,
in_dim=self.cfg.Params.unet.unet_in_dim,
dim=self.cfg.Params.unet.unet_dim,
context_dim=self.cfg.Params.unet.unet_context_dim,
out_dim=self.cfg.Params.unet.unet_out_dim,
dim_mult=self.cfg.Params.unet.unet_dim_mult,
num_heads=self.cfg.Params.unet.unet_num_heads,
head_dim=None,
num_res_blocks=self.cfg.Params.unet.unet_res_blocks,
attn_scales=self.cfg.Params.unet.unet_attn_scales,
num_classes=self.cfg.Params.unet.unet_num_classes + 1,
dropout=self.cfg.Params.unet.unet_dropout).eval().requires_grad_(
False).to(self._device)
self.palette.load_state_dict(
torch.load(palette_model_path, map_location=self._device))
logger.info('load palette model done')

# diffusion
        logger.info('Initializing diffusion ...')
betas = ops.beta_schedule(self.cfg.Params.diffusion.schedule,
self.cfg.Params.diffusion.num_timesteps)
self.diffusion = ops.GaussianDiffusion(
betas=betas,
mean_type=self.cfg.Params.diffusion.mean_type,
var_type=self.cfg.Params.diffusion.var_type,
loss_type=self.cfg.Params.diffusion.loss_type,
rescale_timesteps=False)

self.transforms = T.Compose([
data.PadToSquare(),
T.Resize(
self.cfg.DATA.scale_size,
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std)
])

def preprocess(self, input: Input) -> Dict[str, Any]:
if len(input) == 3: # colorization
_, input_type, save_path = input
elif len(input) == 4: # uncropping or in-painting
_, meta, input_type, save_path = input
if input_type == 0: # uncropping
assert meta in ['up', 'down', 'left', 'right']
direction = meta

list_ = []
for i in range(len(input) - 2):
input_img = input[i]
if input_img in ['up', 'down', 'left', 'right']:
continue
if isinstance(input_img, str):
if input_type == 2 and i == 0:
                    logger.info('Loading raw image (keeping alpha channel) ...')
bytes = File.read(input_img)
img = Image.open(io.BytesIO(bytes))
assert len(img.split()) == 4
else:
img = load_image(input_img)
            elif isinstance(input_img, Image.Image):
img = input_img.convert('RGB')
elif isinstance(input_img, np.ndarray):
if len(input_img.shape) == 2:
input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR)
img = input_img[:, :, ::-1]
img = Image.fromarray(img.astype('uint8')).convert('RGB')
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')
list_.append(img)
img_list = []
if input_type != 2:
for img in list_:
img = self.transforms(img)
imgs = torch.unsqueeze(img, 0)
imgs = imgs.to(self._device)
img_list.append(imgs)
elif input_type == 2:
mask, masked_img = list_[0], list_[1]
img = self.transforms(masked_img.convert('RGB'))
mask = torch.from_numpy(
np.array(
mask.resize((img.shape[2], img.shape[1])),
dtype=np.float32)[:, :, -1] / 255.0).unsqueeze(0)
img = (1 - mask) * img + mask * torch.randn_like(img).clamp_(-1, 1)
imgs = img.unsqueeze(0).to(self._device)
b, c, h, w = imgs.shape
y = torch.LongTensor([self.cfg.Classes.class_id]).to(self._device)

if input_type == 0:
assert len(img_list) == 1
result = {
'image_data': img_list[0],
'c': c,
'h': h,
'w': w,
'direction': direction,
'type': input_type,
'y': y,
'save_path': save_path
}
elif input_type == 1:
assert len(img_list) == 1
result = {
'image_data': img_list[0],
'c': c,
'h': h,
'w': w,
'type': input_type,
'y': y,
'save_path': save_path
}
elif input_type == 2:
result = {
'image_data': imgs,
# 'image_mask': mask,
'c': c,
'h': h,
'w': w,
'type': input_type,
'y': y,
'save_path': save_path
}
return result

@torch.no_grad()
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
type_ = input['type']
if type_ == 0:
# Uncropping
img = input['image_data']
direction = input['direction']
y = input['y']

# fix seed
torch.manual_seed(1 * 8888)
torch.cuda.manual_seed(1 * 8888)

logger.info(f'Processing {direction} uncropping')
img = img.clone()
i_y = y.repeat(self.repetition, 1)
if direction == 'up':
img[:, :, input['h'] // 2:, :] = torch.randn_like(
img[:, :, input['h'] // 2:, :])
elif direction == 'down':
img[:, :, :input['h'] // 2, :] = torch.randn_like(
img[:, :, :input['h'] // 2, :])
elif direction == 'left':
img[:, :, :,
input['w'] // 2:] = torch.randn_like(img[:, :, :,
input['w'] // 2:])
elif direction == 'right':
img[:, :, :, :input['w'] // 2] = torch.randn_like(
img[:, :, :, :input['w'] // 2])
i_concat = self.autoencoder.encode(img).repeat(
self.repetition, 1, 1, 1)

# sample images
x0 = self.diffusion.ddim_sample_loop(
noise=torch.randn_like(i_concat),
model=self.palette,
model_kwargs=[{
'y': i_y,
'concat': i_concat
}, {
'y':
torch.full_like(i_y,
self.cfg.Params.unet.unet_num_classes),
'concat':
i_concat
}],
guide_scale=1.0,
clamp=None,
ddim_timesteps=50,
eta=1.0)
i_gen_imgs = self.autoencoder.decode(x0)
save_grid(i_gen_imgs, input['save_path'], nrow=4)
return {OutputKeys.OUTPUT_IMG: i_gen_imgs}

elif type_ == 1:
# Colorization #
img = input['image_data']
y = input['y']
# fix seed
torch.manual_seed(1 * 8888)
torch.cuda.manual_seed(1 * 8888)

logger.info('Processing Colorization')
img = img.clone()
img = img.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)
i_concat = self.autoencoder.encode(img).repeat(
self.repetition, 1, 1, 1)
i_y = y.repeat(self.repetition, 1)

# sample images
x0 = self.diffusion.ddim_sample_loop(
noise=torch.randn_like(i_concat),
model=self.palette,
model_kwargs=[{
'y': i_y,
'concat': i_concat
}, {
'y':
torch.full_like(i_y,
self.cfg.Params.unet.unet_num_classes),
'concat':
i_concat
}],
guide_scale=1.0,
clamp=None,
ddim_timesteps=50,
eta=0.0)
i_gen_imgs = self.autoencoder.decode(x0)
save_grid(i_gen_imgs, input['save_path'], nrow=4)
return {OutputKeys.OUTPUT_IMG: i_gen_imgs}
elif type_ == 2:
# Combination #
logger.info('Processing Combination')

# prepare inputs
img = input['image_data']
concat = self.autoencoder.encode(img).repeat(
self.repetition, 1, 1, 1)
y = torch.LongTensor([126]).unsqueeze(0).to(self._device).repeat(
self.repetition, 1)

# sample images
x0 = self.diffusion.ddim_sample_loop(
noise=torch.randn_like(concat),
model=self.palette,
model_kwargs=[{
'y': y,
'concat': concat
}, {
'y':
torch.full_like(y, self.cfg.Params.unet.unet_num_classes),
'concat':
concat
}],
guide_scale=1.0,
clamp=None,
ddim_timesteps=50,
eta=1.0)
i_gen_imgs = self.autoencoder.decode(x0)
save_grid(i_gen_imgs, input['save_path'], nrow=4)
return {OutputKeys.OUTPUT_IMG: i_gen_imgs}
else:
raise TypeError(
                f'input type should be 0 (Uncropping), 1 (Colorization), 2 (Combination)'
                f' but got {type_}')

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

+ 38
- 0
tests/pipelines/test_image2image_translation.py View File

@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import shutil
import unittest

from modelscope.fileio import File
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class Image2ImageTranslationTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
r"""We provide three translation modes, i.e., uncropping, colorization and combination.
        You can pass the following parameters for the different modes.
1. Uncropping Mode:
result = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', 'left', 0, 'result.jpg'))
2. Colorization Mode:
result = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', 1, 'result.jpg'))
3. Combination Mode:
just like the following code.
"""
img2img_gen_pipeline = pipeline(
Tasks.image_generation,
model='damo/cv_latent_diffusion_image2image_translation')
result = img2img_gen_pipeline(
('data/test/images/img2img_input_mask.png',
'data/test/images/img2img_input_masked_img.png', 2,
'result.jpg')) # combination mode

print(f'output: {result}.')


if __name__ == '__main__':
unittest.main()
