
DALL-E 2: fix issues on the dev/dalle2_1 branch, add test code; local tests pass

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10037492
Branch: master
xuangen.hlh  yingda.chen  3 years ago
Parent commit: f7f29ed1ff
13 changed files with 2638 additions and 3 deletions
  1. modelscope/metainfo.py (+1, -0)
  2. modelscope/models/multi_modal/__init__.py (+5, -1)
  3. modelscope/models/multi_modal/multi_stage_diffusion/__init__.py (+1, -0)
  4. modelscope/models/multi_modal/multi_stage_diffusion/clip.py (+318, -0)
  5. modelscope/models/multi_modal/multi_stage_diffusion/decoder.py (+322, -0)
  6. modelscope/models/multi_modal/multi_stage_diffusion/gaussian_diffusion.py (+641, -0)
  7. modelscope/models/multi_modal/multi_stage_diffusion/model.py (+265, -0)
  8. modelscope/models/multi_modal/multi_stage_diffusion/prior.py (+170, -0)
  9. modelscope/models/multi_modal/multi_stage_diffusion/tokenizer.py (+199, -0)
  10. modelscope/models/multi_modal/multi_stage_diffusion/upsampler.py (+466, -0)
  11. modelscope/models/multi_modal/multi_stage_diffusion/xglm.py (+205, -0)
  12. modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py (+5, -2)
  13. tests/pipelines/test_multi_stage_diffusion.py (+40, -0)
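
The new pipeline test itself is not reproduced on this page. Below is a minimal local-inference sketch of how the added model class can be exercised; the checkpoint path is a placeholder, and the directory layout (configuration.json, pytorch_model.bin, bpe_simple_vocab_16e6.txt.gz, XGLM tokenizer files) is what UnCLIP in model.py expects.

    import cv2
    import numpy as np

    from modelscope.models.multi_modal import MultiStageDiffusionForTextToImageSynthesis

    # placeholder path to a downloaded checkpoint directory
    model_dir = '/path/to/multi-stage-diffusion'
    model = MultiStageDiffusionForTextToImageSynthesis(model_dir, device_id=0)  # device_id=-1 for CPU

    # forward() takes a dict and returns a list of HxWx3 uint8 BGR arrays (see model.py below)
    imgs = model.forward({
        'text': 'A photo of a confused grizzly bear in calculus class.',
        'batch_size': 2,
        'timesteps_64': 50
    })
    cv2.imwrite('result_0.png', np.ascontiguousarray(imgs[0]))

The added tests/pipelines/test_multi_stage_diffusion.py presumably drives the same model through the registered text_to_image_synthesis pipeline.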

modelscope/metainfo.py  (+1, -0)

@@ -72,6 +72,7 @@ class Models(object):
gemm = 'gemm-generative-multi-modal'
mplug = 'mplug'
diffusion = 'diffusion-text-to-image-synthesis'
multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis'
team = 'team-multi-modal-similarity'
video_clip = 'video-clip-multi-modal-embedding'



modelscope/models/multi_modal/__init__.py  (+5, -1)

@@ -14,6 +14,8 @@ if TYPE_CHECKING:
from .ofa_for_all_tasks import OfaForAllTasks
from .ofa_for_text_to_image_synthesis_model import \
OfaForTextToImageSynthesis
from .multi_stage_diffusion import \
MultiStageDiffusionForTextToImageSynthesis

else:
_import_structure = {
@@ -25,7 +27,9 @@ else:
'mplug_for_all_tasks': ['MPlugForAllTasks'],
'ofa_for_all_tasks': ['OfaForAllTasks'],
'ofa_for_text_to_image_synthesis_model':
['OfaForTextToImageSynthesis']
['OfaForTextToImageSynthesis'],
'multi_stage_diffusion':
['MultiStageDiffusionForTextToImageSynthesis']
}

import sys


modelscope/models/multi_modal/multi_stage_diffusion/__init__.py  (+1, -0)

@@ -0,0 +1 @@
from .model import MultiStageDiffusionForTextToImageSynthesis

modelscope/models/multi_modal/multi_stage_diffusion/clip.py  (+318, -0)

@@ -0,0 +1,318 @@
# The implementation here is modified based on OpenAI CLIP, publicly available at https://github.com/openai/CLIP.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['CLIP']


def to_fp16(m):
if isinstance(m, (nn.Linear, nn.Conv2d)):
m.weight.data = m.weight.data.half()
if m.bias is not None:
m.bias.data = m.bias.data.half()
elif hasattr(m, 'head'):
p = getattr(m, 'head')
p.data = p.data.half()


class QuickGELU(nn.Module):

def forward(self, x):
return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):
r"""Subclass of nn.LayerNorm to handle fp16.
"""

def forward(self, x):
return super(LayerNorm, self).forward(x.float()).type_as(x)


class SelfAttention(nn.Module):

def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
assert dim % num_heads == 0
super(SelfAttention, self).__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = 1.0 / math.sqrt(self.head_dim)

# layers
self.to_qkv = nn.Linear(dim, dim * 3)
self.attn_dropout = nn.Dropout(attn_dropout)
self.proj = nn.Linear(dim, dim)
self.proj_dropout = nn.Dropout(proj_dropout)

def forward(self, x, mask=None):
r"""x: [B, L, C].
mask: [*, L, L].
"""
b, l, _, n = *x.size(), self.num_heads

# compute query, key, and value
q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1)
q = q.reshape(l, b * n, -1).transpose(0, 1)
k = k.reshape(l, b * n, -1).transpose(0, 1)
v = v.reshape(l, b * n, -1).transpose(0, 1)

# compute attention
attn = self.scale * torch.bmm(q, k.transpose(1, 2))
if mask is not None:
attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf'))
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
attn = self.attn_dropout(attn)

# gather context
x = torch.bmm(attn, v)
x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1)

# output
x = self.proj(x)
x = self.proj_dropout(x)
return x


class AttentionBlock(nn.Module):

def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
super(AttentionBlock, self).__init__()
self.dim = dim
self.num_heads = num_heads

# layers
self.norm1 = LayerNorm(dim)
self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout)
self.norm2 = LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim),
nn.Dropout(proj_dropout))

def forward(self, x, mask=None):
x = x + self.attn(self.norm1(x), mask)
x = x + self.mlp(self.norm2(x))
return x


class VisionTransformer(nn.Module):

def __init__(self,
image_size=224,
patch_size=16,
dim=768,
out_dim=512,
num_heads=12,
num_layers=12,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0):
assert image_size % patch_size == 0
super(VisionTransformer, self).__init__()
self.image_size = image_size
self.patch_size = patch_size
self.dim = dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.num_patches = (image_size // patch_size)**2

# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(
3, dim, kernel_size=patch_size, stride=patch_size, bias=False)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(
gain * torch.randn(1, self.num_patches + 1, dim))
self.dropout = nn.Dropout(embedding_dropout)

# transformer
self.pre_norm = LayerNorm(dim)
self.transformer = nn.Sequential(*[
AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
for _ in range(num_layers)
])
self.post_norm = LayerNorm(dim)

# head
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))

def forward(self, x):
b, dtype = x.size(0), self.head.dtype
x = x.type(dtype)

# patch-embedding
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c]
x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x],
dim=1)
x = self.dropout(x + self.pos_embedding.type(dtype))
x = self.pre_norm(x)

# transformer
x = self.transformer(x)

# head
x = self.post_norm(x)
x = torch.mm(x[:, 0, :], self.head)
return x

def fp16(self):
return self.apply(to_fp16)


class TextTransformer(nn.Module):

def __init__(self,
vocab_size,
text_len,
dim=512,
out_dim=512,
num_heads=8,
num_layers=12,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0):
super(TextTransformer, self).__init__()
self.vocab_size = vocab_size
self.text_len = text_len
self.dim = dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers

# embeddings
self.token_embedding = nn.Embedding(vocab_size, dim)
self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim))
self.dropout = nn.Dropout(embedding_dropout)

# transformer
self.transformer = nn.ModuleList([
AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
for _ in range(num_layers)
])
self.norm = LayerNorm(dim)

# head
gain = 1.0 / math.sqrt(dim)
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))

# causal attention mask
self.register_buffer('attn_mask',
torch.tril(torch.ones(1, text_len, text_len)))

def forward(self, x):
eot, dtype = x.argmax(dim=-1), self.head.dtype
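# x.argmax(dim=-1) above locates the end-of-text token, which has the highest id in the CLIP vocab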

# embeddings
x = self.dropout(
self.token_embedding(x).type(dtype)
+ self.pos_embedding.type(dtype))

# transformer
for block in self.transformer:
x = block(x, self.attn_mask)

# head
x = self.norm(x)
x = torch.mm(x[torch.arange(x.size(0)), eot], self.head)
return x

def fp16(self):
return self.apply(to_fp16)


class CLIP(nn.Module):

def __init__(self,
embed_dim=512,
image_size=224,
patch_size=16,
vision_dim=768,
vision_heads=12,
vision_layers=12,
vocab_size=49408,
text_len=77,
text_dim=512,
text_heads=8,
text_layers=12,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0):
super(CLIP, self).__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.vision_dim = vision_dim
self.vision_heads = vision_heads
self.vision_layers = vision_layers
self.vocab_size = vocab_size
self.text_len = text_len
self.text_dim = text_dim
self.text_heads = text_heads
self.text_layers = text_layers

# models
self.visual = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
out_dim=embed_dim,
num_heads=vision_heads,
num_layers=vision_layers,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout)
self.textual = TextTransformer(
vocab_size=vocab_size,
text_len=text_len,
dim=text_dim,
out_dim=embed_dim,
num_heads=text_heads,
num_layers=text_layers,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout)
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

def forward(self, imgs, txt_tokens):
r"""imgs: [B, C, H, W] of torch.float32.
txt_tokens: [B, T] of torch.long.
"""
xi = self.visual(imgs)
xt = self.textual(txt_tokens)

# normalize features
xi = F.normalize(xi, p=2, dim=1)
xt = F.normalize(xt, p=2, dim=1)

# logits
scale = self.log_scale.exp()
logits_i2t = scale * torch.mm(xi, xt.t())
logits_t2i = scale * torch.mm(xt, xi.t())
return logits_i2t, logits_t2i

def init_weights(self):
# embeddings
nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)

# attentions
for modality in ['visual', 'textual']:
dim = self.vision_dim if modality == 'visual' else self.text_dim
model = getattr(self, modality)
transformer = model.transformer
proj_gain = (1.0 / math.sqrt(dim)) * (
1.0 / math.sqrt(2 * model.num_layers))
attn_gain = 1.0 / math.sqrt(dim)
mlp_gain = 1.0 / math.sqrt(2.0 * dim)
for block in transformer:
nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
nn.init.normal_(block.attn.proj.weight, std=proj_gain)
nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
nn.init.normal_(block.mlp[2].weight, std=proj_gain)

def fp16(self):
return self.apply(to_fp16)

modelscope/models/multi_modal/multi_stage_diffusion/decoder.py  (+322, -0)

@@ -0,0 +1,322 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['Decoder']


def sinusoidal_embedding(timesteps, dim):
# check input
half = dim // 2
timesteps = timesteps.float()

# compute sinusoidal embedding
sinusoid = torch.outer(
timesteps, torch.pow(10000,
-torch.arange(half).to(timesteps).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
if dim % 2 != 0:
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
return x


class Resample(nn.Module):

def __init__(self, in_dim, out_dim, scale_factor, use_conv=False):
assert scale_factor in [0.5, 1.0, 2.0]
super(Resample, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.scale_factor = scale_factor
self.use_conv = use_conv

# layers
if scale_factor == 2.0:
self.resample = nn.Sequential(
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
nn.Conv2d(in_dim, out_dim, 3, padding=1)
if use_conv else nn.Identity())
elif scale_factor == 0.5:
self.resample = nn.Conv2d(
in_dim, out_dim, 3, stride=2,
padding=1) if use_conv else nn.AvgPool2d(
kernel_size=2, stride=2)
else:
self.resample = nn.Identity()

def forward(self, x):
return self.resample(x)


class ResidualBlock(nn.Module):

def __init__(self,
in_dim,
embed_dim,
out_dim,
use_scale_shift_norm=True,
scale_factor=1.0,
dropout=0.0):
super(ResidualBlock, self).__init__()
self.in_dim = in_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.use_scale_shift_norm = use_scale_shift_norm
self.scale_factor = scale_factor

# layers
self.layer1 = nn.Sequential(
nn.GroupNorm(32, in_dim), nn.SiLU(),
nn.Conv2d(in_dim, out_dim, 3, padding=1))
self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False)
self.embedding = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim,
out_dim * 2 if use_scale_shift_norm else out_dim))
self.layer2 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv2d(out_dim, out_dim, 3, padding=1))
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
in_dim, out_dim, 1)

# zero out the last layer params
nn.init.zeros_(self.layer2[-1].weight)

def forward(self, x, e):
identity = self.resample(x)
x = self.layer1[-1](self.resample(self.layer1[:-1](x)))
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
if self.use_scale_shift_norm:
scale, shift = e.chunk(2, dim=1)
x = self.layer2[0](x) * (1 + scale) + shift
x = self.layer2[1:](x)
else:
x = x + e
x = self.layer2(x)
x = x + self.shortcut(identity)
return x


class AttentionBlock(nn.Module):

def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
# consider head_dim first, then num_heads
num_heads = dim // head_dim if head_dim else num_heads
head_dim = dim // num_heads
assert num_heads * head_dim == dim
super(AttentionBlock, self).__init__()
self.dim = dim
self.context_dim = context_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = math.pow(head_dim, -0.25)

# layers
self.norm = nn.GroupNorm(32, dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
if context_dim is not None:
self.context_kv = nn.Linear(context_dim, dim * 2)
self.proj = nn.Conv2d(dim, dim, 1)

# zero out the last layer params
nn.init.zeros_(self.proj.weight)

def forward(self, x, context=None):
r"""x: [B, C, H, W].
context: [B, L, C] or None.
"""
identity = x
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim

# compute query, key, value
x = self.norm(x)
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
if context is not None:
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
d).permute(0, 2, 3,
1).chunk(
2, dim=1)
k = torch.cat([ck, k], dim=-1)
v = torch.cat([cv, v], dim=-1)

# compute attention
attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
attn = F.softmax(attn, dim=-1)

# gather context
x = torch.matmul(v, attn.transpose(-1, -2))
x = x.reshape(b, c, h, w)

# output
x = self.proj(x)
return x + identity


class Decoder(nn.Module):

def __init__(self,
in_dim=3,
dim=512,
y_dim=512,
context_dim=512,
out_dim=6,
dim_mult=[1, 2, 3, 4],
num_heads=None,
head_dim=64,
num_res_blocks=3,
attn_scales=[1 / 2, 1 / 4, 1 / 8],
resblock_resample=True,
use_scale_shift_norm=True,
dropout=0.1):
embed_dim = dim * 4
super(Decoder, self).__init__()
self.in_dim = in_dim
self.dim = dim
self.y_dim = y_dim
self.context_dim = context_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.dim_mult = dim_mult
self.num_heads = num_heads
self.head_dim = head_dim
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.resblock_resample = resblock_resample
self.use_scale_shift_norm = use_scale_shift_norm

# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0

# embeddings
self.time_embedding = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
self.y_embedding = nn.Sequential(
nn.Linear(y_dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
self.context_embedding = nn.Sequential(
nn.Linear(y_dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, context_dim * 4))

# encoder
self.encoder = nn.ModuleList(
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
shortcut_dims.append(dim)
for i, (in_dim,
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
# residual (+attention) blocks
block = nn.ModuleList([
ResidualBlock(in_dim, embed_dim, out_dim,
use_scale_shift_norm, 1.0, dropout)
])
if scale in attn_scales:
block.append(
AttentionBlock(out_dim, context_dim, num_heads,
head_dim))
in_dim = out_dim
self.encoder.append(block)
shortcut_dims.append(out_dim)

# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
if resblock_resample:
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 0.5,
dropout)
else:
downsample = Resample(
out_dim, out_dim, 0.5, use_conv=True)
shortcut_dims.append(out_dim)
scale /= 2.0
self.encoder.append(downsample)

# middle
self.middle = nn.ModuleList([
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout),
AttentionBlock(out_dim, context_dim, num_heads, head_dim),
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout)
])

# decoder
self.decoder = nn.ModuleList()
for i, (in_dim,
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
# residual (+attention) blocks
block = nn.ModuleList([
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
out_dim, use_scale_shift_norm, 1.0, dropout)
])
if scale in attn_scales:
block.append(
AttentionBlock(out_dim, context_dim, num_heads,
head_dim))
in_dim = out_dim

# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
if resblock_resample:
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 2.0,
dropout)
else:
upsample = Resample(
out_dim, out_dim, 2.0, use_conv=True)
scale *= 2.0
block.append(upsample)
self.decoder.append(block)

# head
self.head = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))

# zero out the last layer params
nn.init.zeros_(self.head[-1].weight)

def forward(self, x, t, y):
# embeddings
e = self.time_embedding(sinusoidal_embedding(
t, self.dim)) + self.y_embedding(y)
context = self.context_embedding(y).view(-1, 4, self.context_dim)

# encoder
xs = []
for block in self.encoder:
x = self._forward_single(block, x, e, context)
xs.append(x)

# middle
for block in self.middle:
x = self._forward_single(block, x, e, context)

# decoder
for block in self.decoder:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(block, x, e, context)

# head
x = self.head(x)
return x

def _forward_single(self, module, x, e, context):
if isinstance(module, ResidualBlock):
x = module(x, e)
elif isinstance(module, AttentionBlock):
x = module(x, context)
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e, context)
else:
x = module(x)
return x

modelscope/models/multi_modal/multi_stage_diffusion/gaussian_diffusion.py  (+641, -0)

@@ -0,0 +1,641 @@
# The implementation here is modified based on latent diffusion, publicly available
# at https://github.com/CompVis/latent-diffusion.

import math

import torch

__all__ = ['GaussianDiffusion', 'beta_schedule']


def kl_divergence(mu1, logvar1, mu2, logvar2):
u1 = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
u2 = ((mu1 - mu2)**2) * torch.exp(-logvar2)
return 0.5 * (u1 + u2)


def standard_normal_cdf(x):
r"""A fast approximation of the cumulative distribution function of the standard normal.
"""
return 0.5 * (1.0 + torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def discretized_gaussian_log_likelihood(x0, mean, log_scale):
assert x0.shape == mean.shape == log_scale.shape
cx = x0 - mean
inv_stdv = torch.exp(-log_scale)
cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x0 < -0.999, log_cdf_plus,
torch.where(x0 > 0.999, log_one_minus_cdf_min,
torch.log(cdf_delta.clamp(min=1e-12))))
assert log_probs.shape == x0.shape
return log_probs


def _i(tensor, t, x):
r"""Index tensor using t and format the output according to x.
"""
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
return tensor[t].view(shape).to(x)


def beta_schedule(schedule,
num_timesteps=1000,
init_beta=None,
last_beta=None):
if schedule == 'linear':
scale = 1000.0 / num_timesteps
init_beta = init_beta or scale * 0.0001
last_beta = last_beta or scale * 0.02
return torch.linspace(
init_beta, last_beta, num_timesteps, dtype=torch.float64)
elif schedule == 'quadratic':
init_beta = init_beta or 0.0015
last_beta = last_beta or 0.0195
return torch.linspace(
init_beta**0.5, last_beta**0.5, num_timesteps,
dtype=torch.float64)**2
elif schedule == 'cosine':
betas = []
for step in range(num_timesteps):
t1 = step / num_timesteps
t2 = (step + 1) / num_timesteps
fn_t1 = math.cos((t1 + 0.008) / 1.008 * math.pi / 2)**2
fn_t2 = math.cos((t2 + 0.008) / 1.008 * math.pi / 2)**2
betas.append(min(1.0 - fn_t2 / fn_t1, 0.999))
return torch.tensor(betas, dtype=torch.float64)
else:
raise ValueError(f'Unsupported schedule: {schedule}')


class GaussianDiffusion(object):

def __init__(self,
betas,
mean_type='eps',
var_type='learned_range',
loss_type='mse',
rescale_timesteps=False):
# check input
if not isinstance(betas, torch.DoubleTensor):
betas = torch.tensor(betas, dtype=torch.float64)
assert min(betas) > 0 and max(betas) <= 1
assert mean_type in ['x0', 'x_{t-1}', 'eps']
assert var_type in [
'learned', 'learned_range', 'fixed_large', 'fixed_small'
]
assert loss_type in [
'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1'
]
self.betas = betas
self.num_timesteps = len(betas)
self.mean_type = mean_type
self.var_type = var_type
self.loss_type = loss_type
self.rescale_timesteps = rescale_timesteps

# alphas
alphas = 1 - self.betas
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
self.alphas_cumprod_prev = torch.cat(
[alphas.new_ones([1]), self.alphas_cumprod[:-1]])
self.alphas_cumprod_next = torch.cat(
[self.alphas_cumprod[1:],
alphas.new_zeros([1])])

# q(x_t | x_{t-1})
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
- self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = torch.log(1.0
- self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
- 1)

# q(x_{t-1} | x_t, x_0)
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
1.0 - self.alphas_cumprod)
self.posterior_log_variance_clipped = torch.log(
self.posterior_variance.clamp(1e-20))
self.posterior_mean_coef1 = betas * torch.sqrt(
self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.posterior_mean_coef2 = (
1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
1.0 - self.alphas_cumprod)

def q_sample(self, x0, t, noise=None):
r"""Sample from q(x_t | x_0).
"""
noise = torch.randn_like(x0) if noise is None else noise
u1 = _i(self.sqrt_alphas_cumprod, t, x0) * x0
u2 = _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise
return u1 + u2

def q_mean_variance(self, x0, t):
r"""Distribution of q(x_t | x_0).
"""
mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
var = _i(1.0 - self.alphas_cumprod, t, x0)
log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
return mu, var, log_var

def q_posterior_mean_variance(self, x0, xt, t):
r"""Distribution of q(x_{t-1} | x_t, x_0).
"""
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
self.posterior_mean_coef2, t, xt) * xt
var = _i(self.posterior_variance, t, xt)
log_var = _i(self.posterior_log_variance_clipped, t, xt)
return mu, var, log_var

@torch.no_grad()
def p_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None):
r"""Sample from p(x_{t-1} | x_t).
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
"""
# predict distribution of p(x_{t-1} | x_t)
mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile,
guide_scale)

# random sample (with optional conditional function)
noise = torch.randn_like(xt)
shape = (-1, *((1, ) * (xt.ndim - 1)))
mask = t.ne(0).float().view(shape) # no noise when t == 0
if condition_fn is not None:
grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
mu = mu.float() + var * grad.float()
xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
return xt_1, x0

@torch.no_grad()
def p_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None):
r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
"""
# prepare input
b = noise.size(0)
xt = noise

# diffusion process
for step in torch.arange(self.num_timesteps).flip(0):
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn, guide_scale)
return xt

def p_mean_variance(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None):
r"""Distribution of p(x_{t-1} | x_t).
"""
# predict distribution
if guide_scale is None:
out = model(xt, self._scale_timesteps(t), **model_kwargs)
else:
# classifier-free guidance
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
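# the guided prediction combines the two branches as
#   out = u_out + guide_scale * (y_out - u_out)
# on the mean/eps channels; any predicted-variance channels are taken from the conditional branch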
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
cond = self.var_type.startswith('fixed')
dim = y_out.size(1) if cond else y_out.size(1) // 2
u1 = u_out[:, :dim]
u2 = guide_scale * (y_out[:, :dim] - u_out[:, :dim])
out = torch.cat([u1 + u2, y_out[:, dim:]], dim=1)

# compute variance
if self.var_type == 'learned':
out, log_var = out.chunk(2, dim=1)
var = torch.exp(log_var)
elif self.var_type == 'learned_range':
out, fraction = out.chunk(2, dim=1)
min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
max_log_var = _i(torch.log(self.betas), t, xt)
fraction = (fraction + 1) / 2.0
log_var = fraction * max_log_var + (1 - fraction) * min_log_var
var = torch.exp(log_var)
elif self.var_type == 'fixed_large':
var = _i(
torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
xt)
log_var = torch.log(var)
elif self.var_type == 'fixed_small':
var = _i(self.posterior_variance, t, xt)
log_var = _i(self.posterior_log_variance_clipped, t, xt)

# compute mean and x0
if self.mean_type == 'x_{t-1}':
mu = out # x_{t-1}
u1 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu
u2 = _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
xt) * xt
x0 = u1 - u2
elif self.mean_type == 'x0':
x0 = out
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
elif self.mean_type == 'eps':
u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
x0 = u1 - u2
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)

# restrict the range of x0
if percentile is not None:
assert percentile > 0 and percentile <= 1 # e.g., 0.995
s = torch.quantile(
x0.flatten(1).abs(), percentile,
dim=1).clamp_(1.0).view(-1, 1, 1, 1)
x0 = torch.min(s, torch.max(-s, x0)) / s
elif clamp is not None:
x0 = x0.clamp(-clamp, clamp)
return mu, var, log_var, x0

@torch.no_grad()
def ddim_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
ddim_timesteps=20,
eta=0.0):
r"""Sample from p(x_{t-1} | x_t) using DDIM.
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
"""
stride = self.num_timesteps // ddim_timesteps

# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
percentile, guide_scale)
if condition_fn is not None:
# x0 -> eps
alpha = _i(self.alphas_cumprod, t, xt)
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
eps = u1 / u2
eps = eps - (1 - alpha).sqrt() * condition_fn(
xt, self._scale_timesteps(t), **model_kwargs)

# eps -> x0
u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
x0 = u1 - u2

# derive variables
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
eps = u1 / u2
alphas = _i(self.alphas_cumprod, t, xt)
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
u1 = (1 - alphas_prev) / (1 - alphas)
u2 = (1 - alphas / alphas_prev)
sigmas = eta * torch.sqrt(u1 * u2)

# random sample
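# DDIM update: x_{t-1} = sqrt(alphas_prev) * x0 + sqrt(1 - alphas_prev - sigmas^2) * eps + sigmas * noise
# (the noise term is masked out when t == 0)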
noise = torch.randn_like(xt)
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
return xt_1, x0

@torch.no_grad()
def ddim_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
ddim_timesteps=20,
eta=0.0):
# prepare input
b = noise.size(0)
xt = noise

# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
steps = (1 + torch.arange(0, self.num_timesteps,
self.num_timesteps // ddim_timesteps)).clamp(
0, self.num_timesteps - 1).flip(0)
for step in steps:
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn, guide_scale,
ddim_timesteps, eta)
return xt

@torch.no_grad()
def ddim_reverse_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None,
ddim_timesteps=20):
r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
"""
stride = self.num_timesteps // ddim_timesteps

# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
percentile, guide_scale)

# derive variables
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
eps = u1 / u2

alphas_next = _i(
torch.cat(
[self.alphas_cumprod,
self.alphas_cumprod.new_zeros([1])]),
(t + stride).clamp(0, self.num_timesteps), xt)

# reverse sample
mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
return mu, x0

@torch.no_grad()
def ddim_reverse_sample_loop(self,
x0,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None,
ddim_timesteps=20):
# prepare input
b = x0.size(0)
xt = x0

# reconstruction steps
steps = torch.arange(0, self.num_timesteps,
self.num_timesteps // ddim_timesteps)
for step in steps:
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
percentile, guide_scale,
ddim_timesteps)
return xt

@torch.no_grad()
def plms_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
plms_timesteps=20,
eps_cache=[]):
r"""Sample from p(x_{t-1} | x_t) using PLMS.
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
"""
stride = self.num_timesteps // plms_timesteps

# helper to compute eps
def compute_eps(xt, t):
# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile, guide_scale)

# condition
if condition_fn is not None:
# x0 -> eps
alpha = _i(self.alphas_cumprod, t, xt)
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
eps = u1 / u2
eps = eps - (1 - alpha).sqrt() * condition_fn(
xt, self._scale_timesteps(t), **model_kwargs)

# eps -> x0
u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
x0 = u1 - u2

# derive eps
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
eps = u1 / u2
return eps

# helper to compute x_0 and x_{t-1}
def compute_x0(eps, t):
# eps -> x0
u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
x0 = u1 - u2

# deterministic sample
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
direction = torch.sqrt(1 - alphas_prev) * eps
# mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
return xt_1, x0

# PLMS sample
eps = compute_eps(xt, t)
if len(eps_cache) == 0:
# 2nd order pseudo improved Euler
xt_1, x0 = compute_x0(eps, t)
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
eps_prime = (eps + eps_next) / 2.0
elif len(eps_cache) == 1:
# 2nd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
elif len(eps_cache) == 2:
# 3rd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (23 * eps - 16 * eps_cache[-1]
+ 5 * eps_cache[-2]) / 12.0
elif len(eps_cache) >= 3:
# 4th order pseudo linear multistep (Adams-Bashforth)
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
- 9 * eps_cache[-3]) / 24.0
xt_1, x0 = compute_x0(eps_prime, t)
return xt_1, x0, eps

@torch.no_grad()
def plms_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
plms_timesteps=20):
# prepare input
b = noise.size(0)
xt = noise

# diffusion process
steps = (1 + torch.arange(0, self.num_timesteps,
self.num_timesteps // plms_timesteps)).clamp(
0, self.num_timesteps - 1).flip(0)
eps_cache = []
for step in steps:
# PLMS sampling step
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn,
guide_scale, plms_timesteps,
eps_cache)

# update eps cache
eps_cache.append(eps)
if len(eps_cache) >= 4:
eps_cache.pop(0)
return xt

def loss(self, x0, t, model, model_kwargs={}, noise=None, input_x0=None):
noise = torch.randn_like(x0) if noise is None else noise
input_x0 = x0 if input_x0 is None else input_x0
xt = self.q_sample(input_x0, t, noise=noise)

# compute loss
if self.loss_type in ['kl', 'rescaled_kl']:
loss, _ = self.variational_lower_bound(x0, xt, t, model,
model_kwargs)
if self.loss_type == 'rescaled_kl':
loss = loss * self.num_timesteps
elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
out = model(xt, self._scale_timesteps(t), **model_kwargs)

# VLB term for learning the variance
loss_vlb = 0.0
if self.var_type in ['learned', 'learned_range']:
out, var = out.chunk(2, dim=1)
frozen = torch.cat([
out.detach(), var
], dim=1) # learn var without affecting the prediction of mean
loss_vlb, _ = self.variational_lower_bound(
x0, xt, t, model=lambda *args, **kwargs: frozen)
if self.loss_type.startswith('rescaled_'):
loss_vlb = loss_vlb * self.num_timesteps / 1000.0

# MSE/L1 for x0/eps
target = {
'eps': noise,
'x0': x0,
'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
}[self.mean_type]
loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2
).abs().flatten(1).mean(dim=1)

# total loss
loss = loss + loss_vlb
return loss

def variational_lower_bound(self,
x0,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None):
# compute groundtruth and predicted distributions
mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile)

# compute KL loss
kl = kl_divergence(mu1, log_var1, mu2, log_var2)
kl = kl.flatten(1).mean(dim=1) / math.log(2.0)

# compute discretized NLL loss (for p(x0 | x1) only)
nll = -discretized_gaussian_log_likelihood(
x0, mean=mu2, log_scale=0.5 * log_var2)
nll = nll.flatten(1).mean(dim=1) / math.log(2.0)

# NLL for p(x0 | x1) and KL otherwise
vlb = torch.where(t == 0, nll, kl)
return vlb, x0

@torch.no_grad()
def variational_lower_bound_loop(self,
x0,
model,
model_kwargs={},
clamp=None,
percentile=None):
r"""Compute the entire variational lower bound, measured in bits-per-dim.
"""
# prepare input and output
b = x0.size(0)
metrics = {'vlb': [], 'mse': [], 'x0_mse': []}

# loop
for step in torch.arange(self.num_timesteps).flip(0):
# compute VLB
t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
noise = torch.randn_like(x0)
xt = self.q_sample(x0, t, noise)
vlb, pred_x0 = self.variational_lower_bound(
x0, xt, t, model, model_kwargs, clamp, percentile)

# predict eps from x0
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
eps = u1 / u2

# collect metrics
metrics['vlb'].append(vlb)
metrics['x0_mse'].append(
(pred_x0 - x0).square().flatten(1).mean(dim=1))
metrics['mse'].append(
(eps - noise).square().flatten(1).mean(dim=1))
metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}

# compute the prior KL term for VLB, measured in bits-per-dim
mu, _, log_var = self.q_mean_variance(x0, t)
kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
torch.zeros_like(log_var))
kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)

# update metrics
metrics['prior_bits_per_dim'] = kl_prior
metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
return metrics

def _scale_timesteps(self, t):
if self.rescale_timesteps:
return t.float() * 1000.0 / self.num_timesteps
return t

modelscope/models/multi_modal/multi_stage_diffusion/model.py  (+265, -0)

@@ -0,0 +1,265 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import math
import os.path as osp
from typing import Any, Dict

import json
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.multi_stage_diffusion.clip import CLIP
from modelscope.models.multi_modal.multi_stage_diffusion.decoder import Decoder
from modelscope.models.multi_modal.multi_stage_diffusion.gaussian_diffusion import (
GaussianDiffusion, beta_schedule)
from modelscope.models.multi_modal.multi_stage_diffusion.prior import Prior
from modelscope.models.multi_modal.multi_stage_diffusion.tokenizer import (
CLIPTokenizer, XGLMTokenizer)
from modelscope.models.multi_modal.multi_stage_diffusion.upsampler import (
Upsampler256, Upsampler1024)
from modelscope.models.multi_modal.multi_stage_diffusion.xglm import XGLM
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['MultiStageDiffusionForTextToImageSynthesis']


def make_diffusion(schedule,
num_timesteps=1000,
init_beta=None,
last_beta=None,
mean_type='eps',
var_type='fixed_small'):
betas = beta_schedule(schedule, num_timesteps, init_beta, last_beta)
diffusion = GaussianDiffusion(
betas, mean_type=mean_type, var_type=var_type)
return diffusion


def to_device(batch, device):
# minimal helper assumed by the XGLM branch of UnCLIP.synthesis below: moves the
# (input_ids, attention_mask) pair returned by XGLMTokenizer onto the target device
return [u.to(device) for u in batch]


class UnCLIP(nn.Module):

def __init__(self, model_dir):
super(UnCLIP, self).__init__()
self.model_dir = model_dir
self.config = json.load(open(f'{model_dir}/{ModelFile.CONFIGURATION}'))

# modules
self.clip = CLIP(**self.config['clip']).fp16()
self.xglm = XGLM(**self.config['xglm'])
self.prior = Prior(**self.config['prior'])
self.decoder = Decoder(**self.config['decoder'])
self.upsampler256 = Upsampler256(**self.config['upsampler256'])
self.upsampler1024 = Upsampler1024(**self.config['upsampler1024'])

# diffusions
self.prior_diffusion = make_diffusion(**self.config['prior_diffusion'])
self.decoder_diffusion = make_diffusion(
**self.config['decoder_diffusion'])
self.upsampler256_diffusion = make_diffusion(
**self.config['upsampler256_diffusion'])
self.upsampler1024_diffusion = make_diffusion(
**self.config['upsampler1024_diffusion'])

# tokenizers
self.clip_tokenizer = CLIPTokenizer(
bpe_path=f'{model_dir}/bpe_simple_vocab_16e6.txt.gz')
self.xglm_tokenizer = XGLMTokenizer(model_dir=model_dir)

def forward(self, *args, **kwargs):
raise NotImplementedError(
'"forward" is not implemented. Use "synthesis" instead.')

@torch.no_grad()
def synthesis(self,
text='A photo of a confused grizzly bear in calculus class.',
tokenizer='clip',
batch_size=4,
timesteps_prior=100,
timesteps_64=50,
timesteps_256=20,
timesteps_1024=20,
guide_prior=3.0,
guide_64=7.0,
guide_256=3.0,
guide_1024=3.0,
eta_prior=0.0,
eta_64=0.0,
eta_256=0.0,
eta_1024=0.0):
device = next(self.parameters()).device

# check params
assert all([
t > 0 and t <= 1000 for t in
[timesteps_prior, timesteps_64, timesteps_256, timesteps_1024]
])
assert all([
g > 1 and g < 15
for g in [guide_prior, guide_64, guide_256, guide_1024]
])
assert all([
e >= 0 and e <= 1.0
for e in [eta_prior, eta_64, eta_256, eta_1024]
])
assert batch_size >= 1 and batch_size <= 16

# tokenize the text
if tokenizer == 'clip':
y = F.normalize(
self.clip.textual(self.clip_tokenizer([text]).to(device)),
p=2,
dim=1)
zero_y = F.normalize(
self.clip.textual(self.clip_tokenizer(['']).to(device)),
p=2,
dim=1)
elif tokenizer == 'xglm':
y = F.normalize(
self.xglm(*to_device(self.xglm_tokenizer([text]), device)),
p=2,
dim=1)
zero_y = F.normalize(
self.xglm(*to_device(self.xglm_tokenizer(['']), device)),
p=2,
dim=1)
else:
raise ValueError(
f'Expected tokenizer to be one of "clip" or "xglm", but got {tokenizer}'
)
y = math.sqrt(y.size(1)) * y.repeat(batch_size, 1)
zero_y = math.sqrt(zero_y.size(1)) * zero_y.repeat(batch_size, 1)
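# scaling the unit-norm text embeddings by sqrt(dim) undoes the shrinkage of F.normalize,
# giving entries with roughly unit variance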

# synthesis
with amp.autocast(enabled=True):
# prior
x0 = self.prior_diffusion.ddim_sample_loop(
noise=torch.randn_like(y),
model=self.prior,
model_kwargs=[{
'y': y
}, {
'y': zero_y
}],
guide_scale=guide_prior,
ddim_timesteps=timesteps_prior,
eta=eta_prior)

# decoder
imgs64 = self.decoder_diffusion.ddim_sample_loop(
noise=torch.randn(batch_size, 3, 64, 64).to(device),
model=self.decoder,
model_kwargs=[{
'y': x0
}, {
'y': torch.zeros_like(x0)
}],
guide_scale=guide_64,
percentile=0.995,
ddim_timesteps=timesteps_64,
eta=eta_64).clamp_(-1, 1)

# upsampler256
imgs256 = F.interpolate(
imgs64, scale_factor=4.0, mode='bilinear', align_corners=False)
imgs256 = self.upsampler256_diffusion.ddim_sample_loop(
noise=torch.randn_like(imgs256),
model=self.upsampler256,
model_kwargs=[{
'y': y,
'concat': imgs256
}, {
'y': zero_y,
'concat': imgs256
}],
guide_scale=guide_256,
percentile=0.995,
ddim_timesteps=timesteps_256,
eta=eta_256).clamp_(-1, 1)

# upsampler1024
imgs1024 = F.interpolate(
imgs256,
scale_factor=4.0,
mode='bilinear',
align_corners=False)
imgs1024 = self.upsampler1024_diffusion.ddim_sample_loop(
noise=torch.randn_like(imgs1024),
model=self.upsampler1024,
model_kwargs=[{
'y': y,
'concat': imgs1024
}, {
'y': zero_y,
'concat': imgs1024
}],
guide_scale=guide_1024,
percentile=0.995,
ddim_timesteps=timesteps_1024,
eta=eta_1024).clamp_(-1, 1)

# output: map from [-1, 1] to [0, 255] and convert to a list of PIL images
imgs1024 = imgs1024.add_(1).mul_(255 / 2.0).permute(0, 2, 3, 1).cpu()
imgs1024 = [
Image.fromarray(np.array(u, dtype=np.uint8)) for u in imgs1024
]
return imgs1024


@MODELS.register_module(
Tasks.text_to_image_synthesis, module_name=Models.multi_stage_diffusion)
class MultiStageDiffusionForTextToImageSynthesis(TorchModel):

def __init__(self, model_dir, device_id=-1):
super().__init__(model_dir=model_dir, device_id=device_id)
model = UnCLIP(model_dir=model_dir)
pretrained_params = torch.load(
osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu')
model.load_state_dict(pretrained_params)
model.eval()

self.device_id = device_id
if self.device_id >= 0:
self.device = torch.device(f'cuda:{self.device_id}')
model.to('cuda:{}'.format(self.device_id))
logger.info('Use GPU: {}'.format(self.device_id))
else:
self.device = torch.device('cpu')
logger.info('Use CPU for inference')
self.model = model

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
if not isinstance(input, dict):
raise ValueError(
f'Expected the input to be a dictionary, but got {type(input)}'
)
if 'text' not in input:
raise ValueError('input should contain "text", but not found')

# ddim sampling
imgs = self.model.synthesis(
text=input.get('text'),
tokenizer=input.get('tokenizer', 'clip'),
batch_size=input.get('batch_size', 4),
timesteps_prior=input.get('timesteps_prior', 100),
timesteps_64=input.get('timesteps_64', 50),
timesteps_256=input.get('timesteps_256', 20),
timesteps_1024=input.get('timesteps_1024', 20),
guide_prior=input.get('guide_prior', 3.0),
guide_64=input.get('guide_64', 7.0),
guide_256=input.get('guide_256', 3.0),
guide_1024=input.get('guide_1024', 3.0),
eta_prior=input.get('eta_prior', 0.0),
eta_64=input.get('eta_64', 0.0),
eta_256=input.get('eta_256', 0.0),
eta_1024=input.get('eta_1024', 0.0))
imgs = [np.array(u)[..., ::-1] for u in imgs]
return imgs

modelscope/models/multi_modal/multi_stage_diffusion/prior.py  (+170, -0)

@@ -0,0 +1,170 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['Prior']


def sinusoidal_embedding(timesteps, dim):
# check input
half = dim // 2
timesteps = timesteps.float()

# compute sinusoidal embedding
sinusoid = torch.outer(
timesteps, torch.pow(10000,
-torch.arange(half).to(timesteps).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
if dim % 2 != 0:
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
return x


class SelfAttention(nn.Module):

def __init__(self, dim, num_heads):
assert dim % num_heads == 0
super(SelfAttention, self).__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = math.pow(self.head_dim, -0.25)

# layers
self.to_qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)

def forward(self, x, mask):
b, l, n, c = *x.shape[:2], self.num_heads, self.head_dim

# compute query, key, value
q, k, v = self.to_qkv(x).view(b, l, n * 3, c).chunk(3, dim=2)

# compute attention
attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale)
if mask is not None:
attn = attn.masked_fill(mask[:, :, :l, :l] == 0, float('-inf'))
attn = F.softmax(attn.float(), dim=-1).type(attn.dtype)

# gather context
x = torch.einsum('bnij,bjnc->binc', attn, v)
x = x.reshape(b, l, -1)

# output
x = self.proj(x)
return x


class AttentionBlock(nn.Module):

def __init__(self, dim, num_heads):
super(AttentionBlock, self).__init__()
self.dim = dim
self.num_heads = num_heads

# layers
self.norm1 = nn.LayerNorm(dim)
self.attn = SelfAttention(dim, num_heads)
self.norm2 = nn.LayerNorm(dim)
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

def forward(self, x, mask=None):
x = x + self.attn(self.norm1(x), mask)
x = x + self.ffn(self.norm2(x))
return x


class Prior(nn.Module):

def __init__(self, dim=2048, clip_dim=768, num_heads=32, num_layers=24):
super(Prior, self).__init__()
self.dim = dim
self.clip_dim = clip_dim
self.num_heads = num_heads
self.num_layers = num_layers

# embeddings
self.text_embedding = nn.Sequential(
nn.Linear(clip_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(
nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.vision_embedding = nn.Sequential(
nn.Linear(clip_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.eos_embedding = nn.Parameter(torch.zeros(1, 1, dim))
self.pos_embedding = nn.Parameter(torch.zeros(1, 4, dim))

# transformer
self.blocks = nn.ModuleList(
[AttentionBlock(dim, num_heads) for _ in range(num_layers)])
self.norm = nn.LayerNorm(dim)

# head
self.head = nn.Linear(dim, clip_dim)

# causal attention mask
self.register_buffer('attn_mask', torch.tril(torch.ones(1, 1, 4, 4)))
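# with this causal mask the last of the 4 tokens (text, time, image, eos) attends to all others,
# which is why the head reads the final position x[:, -1]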

# initialize weights
self.init_weights()

def forward(self, x, t, y):
r"""x: [B, C].
t: [B].
y: [B, C].
"""
b = x.size(0)

# embeddings of shape [B, 4, C]: text, time, (noised) image embedding and a final query token
u1 = sinusoidal_embedding(t, self.dim)
u2 = [
self.text_embedding(y).unsqueeze(1),
self.time_embedding(u1).unsqueeze(1),
self.vision_embedding(x).unsqueeze(1),
self.eos_embedding.repeat(b, 1, 1)
]
x = self.pos_embedding + torch.cat(u2, dim=1)

# transformer
for block in self.blocks:
x = block(x, self.attn_mask)
x = self.norm(x)

# head
x = self.head(x[:, -1])
return x

def init_weights(self):
std = 0.02 / math.sqrt(2.0 * self.num_layers)
for name, m in self.named_modules():
if name.endswith('attn.proj') or name.endswith('ffn.2'):
# smaller std for output layers
nn.init.normal_(m.weight, std=std)
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.Linear, nn.Embedding)):
nn.init.normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)

def param_groups(self):
groups = [{
'params': [
p for n, p in self.named_parameters()
if 'norm' in n or n.endswith('bias')
],
'weight_decay':
0.0
}, {
'params': [
p for n, p in self.named_parameters()
if not ('norm' in n or n.endswith('bias'))
]
}]
return groups

modelscope/models/multi_modal/multi_stage_diffusion/tokenizer.py  (+199, -0)

@@ -0,0 +1,199 @@
# The implementation here is modified based on OpenAI CLIP, publicly available at https://github.com/openai/CLIP.

import gzip
import html
from functools import lru_cache

import ftfy
import regex as re
import torch
from transformers import AutoTokenizer

__all__ = ['CLIPTokenizer', 'XGLMTokenizer']


@lru_cache()
def bytes_to_unicode():
"""
Returns a dict mapping utf-8 bytes to unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
while avoiding the whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord('!'),
ord('~') + 1)) + list(range(
ord('¡'),
ord('¬') + 1)) + list(range(ord('®'),
ord('ÿ') + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))


def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs


def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()


def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text


class SimpleTokenizer(object):

def __init__(self, bpe_path):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {
'<|startoftext|>': '<|startoftext|>',
'<|endoftext|>': '<|endoftext|>'
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE)

def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word)

if not pairs:
return token + '</w>'

while True:
bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except Exception:
new_word.extend(word[i:])
break

if word[i] == first and i < len(word) - 1 and word[
i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word

def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b]
for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token]
for bpe_token in self.bpe(token).split(' '))
return bpe_tokens

def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
'utf-8', errors='replace').replace('</w>', ' ')
return text


class CLIPTokenizer(object):
r"""CLIP tokenizer, adapted from https://github.com/openai/CLIP.
"""

def __init__(self, bpe_path, length=77):
self.bpe_path = bpe_path
self.length = length

# init tokenizer
self.tokenizer = SimpleTokenizer(bpe_path=bpe_path)
self.sos_token = self.tokenizer.encoder['<|startoftext|>']
self.eos_token = self.tokenizer.encoder['<|endoftext|>']
self.vocab_size = len(self.tokenizer.encoder)

def __call__(self, sequence):
if isinstance(sequence, str):
return torch.LongTensor(self._tokenizer(sequence))
elif isinstance(sequence, list):
return torch.LongTensor([self._tokenizer(u) for u in sequence])
else:
raise TypeError(
f'Expected the "sequence" to be a string or a list, but got {type(sequence)}'
)

def _tokenizer(self, text):
tokens = self.tokenizer.encode(text)[:self.length - 2]
tokens = [self.sos_token] + tokens + [self.eos_token]
tokens = tokens + [0] * (self.length - len(tokens))
return tokens


class XGLMTokenizer(object):
r"""A wrapper of HuggingFace's XGLM tokenizer.
"""

def __init__(self, model_dir, length=77, **kwargs):
self.length = length
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, **kwargs)
self.vocab_size = self.tokenizer.vocab_size

def __call__(self, sequence, **kwargs):
_kwargs = {
'return_tensors': 'pt',
'padding': 'max_length',
'truncation': True,
'max_length': self.length
}
_kwargs.update(**kwargs)
tokens = self.tokenizer(sequence, **_kwargs)
return tokens.input_ids, tokens.attention_mask

modelscope/models/multi_modal/multi_stage_diffusion/upsampler.py  (+466, -0)

@@ -0,0 +1,466 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['Upsampler256', 'Upsampler1024']


def sinusoidal_embedding(timesteps, dim):
# check input
half = dim // 2
timesteps = timesteps.float()

# compute sinusoidal embedding
sinusoid = torch.outer(
timesteps, torch.pow(10000,
-torch.arange(half).to(timesteps).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
if dim % 2 != 0:
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
return x


class Resample(nn.Module):

def __init__(self, in_dim, out_dim, scale_factor, use_conv=False):
assert scale_factor in [0.5, 1.0, 2.0]
super(Resample, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.scale_factor = scale_factor
self.use_conv = use_conv

# layers
if scale_factor == 2.0:
self.resample = nn.Sequential(
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
nn.Conv2d(in_dim, out_dim, 3, padding=1)
if use_conv else nn.Identity())
elif scale_factor == 0.5:
self.resample = nn.Conv2d(
in_dim, out_dim, 3, stride=2,
padding=1) if use_conv else nn.AvgPool2d(
kernel_size=2, stride=2)
else:
self.resample = nn.Identity()

def forward(self, x):
return self.resample(x)


class ResidualBlock(nn.Module):

def __init__(self,
in_dim,
embed_dim,
out_dim,
use_scale_shift_norm=True,
scale_factor=1.0,
dropout=0.0):
super(ResidualBlock, self).__init__()
self.in_dim = in_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.use_scale_shift_norm = use_scale_shift_norm
self.scale_factor = scale_factor

# layers
self.layer1 = nn.Sequential(
nn.GroupNorm(32, in_dim), nn.SiLU(),
nn.Conv2d(in_dim, out_dim, 3, padding=1))
self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False)
self.embedding = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim,
out_dim * 2 if use_scale_shift_norm else out_dim))
self.layer2 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv2d(out_dim, out_dim, 3, padding=1))
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
in_dim, out_dim, 1)

# zero out the last layer params
nn.init.zeros_(self.layer2[-1].weight)

def forward(self, x, e):
identity = self.resample(x)
x = self.layer1[-1](self.resample(self.layer1[:-1](x)))
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
if self.use_scale_shift_norm:
scale, shift = e.chunk(2, dim=1)
x = self.layer2[0](x) * (1 + scale) + shift
x = self.layer2[1:](x)
else:
x = x + e
x = self.layer2(x)
x = x + self.shortcut(identity)
return x


class AttentionBlock(nn.Module):

def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
# consider head_dim first, then num_heads
num_heads = dim // head_dim if head_dim else num_heads
head_dim = dim // num_heads
assert num_heads * head_dim == dim
super(AttentionBlock, self).__init__()
self.dim = dim
self.context_dim = context_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = math.pow(head_dim, -0.25)

# layers
self.norm = nn.GroupNorm(32, dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
if context_dim is not None:
self.context_kv = nn.Linear(context_dim, dim * 2)
self.proj = nn.Conv2d(dim, dim, 1)

# zero out the last layer params
nn.init.zeros_(self.proj.weight)

def forward(self, x, context=None):
r"""x: [B, C, H, W].
context: [B, L, C] or None.
"""
identity = x
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim

# compute query, key, value
x = self.norm(x)
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
if context is not None:
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
d).permute(0, 2, 3,
1).chunk(
2, dim=1)
k = torch.cat([ck, k], dim=-1)
v = torch.cat([cv, v], dim=-1)

# compute attention
attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
attn = F.softmax(attn, dim=-1)

# gather context
x = torch.matmul(v, attn.transpose(-1, -2))
x = x.reshape(b, c, h, w)

# output
x = self.proj(x)
return x + identity


class Upsampler256(nn.Module):

def __init__(self,
in_dim=6,
dim=320,
y_dim=768,
context_dim=512,
out_dim=3,
dim_mult=[1, 2, 3, 4],
num_heads=None,
head_dim=64,
num_res_blocks=3,
attn_scales=[1 / 8],
resblock_resample=True,
use_scale_shift_norm=True,
dropout=0.1):
embed_dim = dim * 4
super(Upsampler256, self).__init__()
self.in_dim = in_dim
self.dim = dim
self.y_dim = y_dim
self.context_dim = context_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.dim_mult = dim_mult
self.num_heads = num_heads
self.head_dim = head_dim
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.resblock_resample = resblock_resample
self.use_scale_shift_norm = use_scale_shift_norm

# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0

# embeddings
self.time_embedding = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
self.y_embedding = nn.Sequential(
nn.Linear(y_dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
self.context_embedding = nn.Sequential(
nn.Linear(y_dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, context_dim * 4))

# encoder
self.encoder = nn.ModuleList(
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
shortcut_dims.append(dim)
for i, (in_dim,
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
# residual (+attention) blocks
block = nn.ModuleList([
ResidualBlock(in_dim, embed_dim, out_dim,
use_scale_shift_norm, 1.0, dropout)
])
if scale in attn_scales:
block.append(
AttentionBlock(out_dim, context_dim, num_heads,
head_dim))
in_dim = out_dim
self.encoder.append(block)
shortcut_dims.append(out_dim)

# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
if resblock_resample:
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 0.5,
dropout)
else:
downsample = Resample(
out_dim, out_dim, 0.5, use_conv=True)
shortcut_dims.append(out_dim)
scale /= 2.0
self.encoder.append(downsample)

# middle
self.middle = nn.ModuleList([
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout),
AttentionBlock(out_dim, context_dim, num_heads, head_dim),
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout)
])

# decoder
self.decoder = nn.ModuleList()
for i, (in_dim,
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
# residual (+attention) blocks
block = nn.ModuleList([
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
out_dim, use_scale_shift_norm, 1.0, dropout)
])
if scale in attn_scales:
block.append(
AttentionBlock(out_dim, context_dim, num_heads,
head_dim))
in_dim = out_dim

# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
if resblock_resample:
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 2.0,
dropout)
else:
upsample = Resample(
out_dim, out_dim, 2.0, use_conv=True)
scale *= 2.0
block.append(upsample)
self.decoder.append(block)

# head
self.head = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))

# zero out the last layer params
nn.init.zeros_(self.head[-1].weight)

def forward(self, x, t, y, concat):
# embeddings
x = torch.cat([x, concat], dim=1)
e = self.time_embedding(sinusoidal_embedding(
t, self.dim)) + self.y_embedding(y)
context = self.context_embedding(y).view(-1, 4, self.context_dim)

# encoder
xs = []
for block in self.encoder:
x = self._forward_single(block, x, e, context)
xs.append(x)

# middle
for block in self.middle:
x = self._forward_single(block, x, e, context)

# decoder
for block in self.decoder:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(block, x, e, context)

# head
x = self.head(x)
return x

def _forward_single(self, module, x, e, context):
if isinstance(module, ResidualBlock):
x = module(x, e)
elif isinstance(module, AttentionBlock):
x = module(x, context)
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e, context)
else:
x = module(x)
return x


class Upsampler1024(nn.Module):

def __init__(self,
in_dim=6,
dim=192,
y_dim=768,
out_dim=3,
dim_mult=[1, 1, 2, 2, 4, 4],
num_res_blocks=2,
resblock_resample=True,
use_scale_shift_norm=True,
dropout=0.0):
embed_dim = dim * 4
super(Upsampler1024, self).__init__()
self.in_dim = in_dim
self.dim = dim
self.y_dim = y_dim
self.out_dim = out_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.resblock_resample = resblock_resample
self.use_scale_shift_norm = use_scale_shift_norm

# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0

# embedding
self.time_embedding = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
self.y_embedding = nn.Sequential(
nn.Linear(y_dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))

# encoder
self.encoder = nn.ModuleList(
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
shortcut_dims.append(dim)
for i, (in_dim,
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
# residual block
block = nn.ModuleList([
ResidualBlock(in_dim, embed_dim, out_dim,
use_scale_shift_norm, 1.0, dropout)
])
shortcut_dims.append(out_dim)
in_dim = out_dim
self.encoder.append(block)

# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
if resblock_resample:
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 0.5,
dropout)
else:
downsample = Resample(
out_dim, out_dim, 0.5, use_conv=True)
shortcut_dims.append(out_dim)
scale /= 2.0
self.encoder.append(downsample)

# middle
self.middle = nn.ModuleList([
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout),
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout)
])

# decoder
self.decoder = nn.ModuleList()
for i, (in_dim,
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
# residual block
block = nn.ModuleList([
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
out_dim, use_scale_shift_norm, 1.0, dropout)
])
in_dim = out_dim

# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
if resblock_resample:
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 2.0,
dropout)
else:
upsample = Resample(
out_dim, out_dim, 2.0, use_conv=True)
scale *= 2.0
block.append(upsample)
self.decoder.append(block)

# head
self.head = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))

# zero out the last layer params
nn.init.zeros_(self.head[-1].weight)

def forward(self, x, t, y, concat):
# embedding
x = torch.cat([x, concat], dim=1)
e = self.time_embedding(sinusoidal_embedding(
t, self.dim)) + self.y_embedding(y)

# encoder
xs = []
for block in self.encoder:
x = self._forward_single(block, x, e)
xs.append(x)

# middle
for block in self.middle:
x = self._forward_single(block, x, e)

# decoder
for block in self.decoder:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(block, x, e)

# head
x = self.head(x)
return x

def _forward_single(self, module, x, e):
if isinstance(module, ResidualBlock):
x = module(x, e)
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e)
else:
x = module(x)
return x
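
A minimal smoke test of Upsampler256 as defined above: random weights and a reduced spatial size, purely to check tensor shapes. Everything below is illustrative; the released model runs at its native resolution with pretrained weights. Upsampler1024 exposes the same forward signature (x, t, y, concat).

# Illustrative shape check only: random weights, reduced 64x64 spatial size.
import torch

model = Upsampler256().eval()
x = torch.randn(1, 3, 64, 64)        # noisy image (reduced size for a quick check)
concat = torch.randn(1, 3, 64, 64)   # conditioning image, resized to match x
t = torch.randint(0, 1000, (1, ))    # diffusion timestep
y = torch.randn(1, 768)              # conditioning embedding (y_dim = 768)
with torch.no_grad():
    out = model(x, t, y, concat)
print(out.shape)                     # torch.Size([1, 3, 64, 64])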

+ 205
- 0
modelscope/models/multi_modal/multi_stage_diffusion/xglm.py View File

@@ -0,0 +1,205 @@
# The implementation here is modified based on HuggingFace XGLM, publicly available
# at https://github.com/huggingface/transformers.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['XGLM']


def sinusoidal_embedding(seq_len, dim, pad_token=None):
half = dim // 2
sinusoid = torch.outer(
torch.arange(seq_len, dtype=torch.float32),
torch.pow(10000,
-torch.arange(half, dtype=torch.float32).div(half - 1)))
x = torch.cat([torch.sin(sinusoid), torch.cos(sinusoid)], dim=1)
if dim % 2 == 1:
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
if pad_token is not None:
x[pad_token, :] = 0
return x


class SinusoidalEmbedding(nn.Module):

def __init__(self, seq_len, dim, pad_token):
super(SinusoidalEmbedding, self).__init__()
self.seq_len = seq_len
self.dim = dim
self.pad_token = pad_token
self.register_buffer('weight',
sinusoidal_embedding(seq_len + 2, dim, pad_token))

def forward(self, tokens):
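        # fairseq-style position indices: padding tokens map to `pad_token`, real
        # tokens map to pad_token + their 1-based position among non-pad tokens.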
mask = tokens.ne(self.pad_token).long()
indices = torch.cumsum(mask, dim=1) * mask + self.pad_token
pos_embeds = self.weight.index_select(0, indices.view(-1)).view(
*tokens.shape, -1)
return pos_embeds


class GELU(nn.Module):

def forward(self, x):
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class SelfAttention(nn.Module):

def __init__(self, dim, num_heads, dropout=0.1):
assert dim % num_heads == 0
super(SelfAttention, self).__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = 1.0 / math.sqrt(self.head_dim)

# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)

def forward(self, x, mask=None):
r"""x: [B, L, C].
mask: [B, *, L, L] or None.
"""
b, l, n, c = *x.shape[:2], self.num_heads, self.head_dim

# compute query, key, value
q = self.q(x).view(b, l, n, c)
k = self.k(x).view(b, l, n, c)
v = self.v(x).view(b, l, n, c)

# compute attention
attn = self.scale * torch.einsum('binc,bjnc->bnij', q, k)
if mask is not None:
attn = attn.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(attn, dim=-1)
attn = self.dropout(attn)

# gather context
x = torch.einsum('bnij,bjnc->binc', attn, v)
x = x.reshape(b, l, -1)

# output
x = self.o(x)
x = self.dropout(x)
return x


class AttentionBlock(nn.Module):

def __init__(self, dim, ffn_dim, ffn_act, num_heads, dropout=0.1):
assert ffn_act in ['gelu', 'relu']
super(AttentionBlock, self).__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.ffn_act = ffn_act
self.num_heads = num_heads

# layers
self.norm1 = nn.LayerNorm(dim)
self.attn = SelfAttention(dim, num_heads, dropout)
self.norm2 = nn.LayerNorm(dim)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim),
GELU() if ffn_act == 'gelu' else nn.ReLU(inplace=True),
nn.Linear(ffn_dim, dim), nn.Dropout(dropout))

def forward(self, x, mask=None):
x = x + self.attn(self.norm1(x), mask)
x = x + self.ffn(self.norm2(x))
return x


class XGLM(nn.Module):
r"""A multilingual GPT model with an embedding head.
"""

def __init__(self,
vocab_size=256008,
max_seq_len=2048,
dim=1024,
ffn_dim=4096,
ffn_act='gelu',
embed_dim=768,
num_heads=16,
num_layers=24,
pad_token=1,
dropout=0.1):
super(XGLM, self).__init__()
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.dim = dim
self.ffn_dim = ffn_dim
self.ffn_act = ffn_act
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.pad_token = pad_token
        self.scale = math.sqrt(dim)  # rescale token embeddings

# layers
self.token_embedding = nn.Embedding(vocab_size, dim, pad_token)
self.pos_embedding = SinusoidalEmbedding(max_seq_len, dim, pad_token)
self.eos_embedding = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
AttentionBlock(dim, ffn_dim, ffn_act, num_heads, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(dim)
self.head = nn.Linear(dim, embed_dim, bias=False)

# causal attention mask
self.register_buffer(
'attn_mask',
torch.tril(torch.ones(1, 1, 1 + max_seq_len, 1 + max_seq_len)))

# init weights
self.apply(self.init_weights)

def forward(self, tokens, mask=None):
r"""tokens: [B, L].
mask: [B, L].
"""
b, seq_len = tokens.size(0), 1 + tokens.size(1)

# embeddings
x = self.scale * self.token_embedding(tokens)
x = torch.cat([x, self.eos_embedding.repeat(b, 1, 1)], dim=1)
        # NOTE: the sinusoidal positional embedding is left disabled here
        # x = x + self.pos_embedding(tokens)
x = self.dropout(x)

# attention mask
if mask is None:
mask = self.attn_mask[:, :, :seq_len, :seq_len].repeat(b, 1, 1, 1)
else:
mask = self.attn_mask[:, :, :seq_len, :seq_len] * torch.cat(
[mask, torch.zeros_like(mask[:, :1])], dim=1).view(
b, 1, 1, seq_len)

# transformer
for block in self.blocks:
x = block(x, mask)
x = self.norm(x)

# head
logits = self.head(x[:, -1])
return logits

def init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, std=0.02)
if m.padding_idx is not None:
nn.init.zeros_(m.weight[m.padding_idx])
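
A small shape check for the XGLM module above, using random weights and a deliberately tiny configuration so it runs instantly (the real model uses the defaults above). The returned tensor is the embed_dim-sized embedding read off at the appended eos position.

# Illustrative only: tiny configuration, random weights.
import torch

model = XGLM(vocab_size=1000, max_seq_len=77, dim=64, ffn_dim=128,
             embed_dim=32, num_heads=4, num_layers=2).eval()
tokens = torch.randint(2, 1000, (2, 77))     # [B, L] token ids (avoiding pad_token=1)
mask = torch.ones(2, 77, dtype=torch.long)   # [B, L] padding mask
with torch.no_grad():
    emb = model(tokens, mask)
print(emb.shape)                             # torch.Size([2, 32]) == [B, embed_dim]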

+ 5
- 2
modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py View File

@@ -3,7 +3,8 @@ from typing import Any, Dict, Optional
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.multi_modal import OfaForTextToImageSynthesis
from modelscope.models.multi_modal import (
MultiStageDiffusionForTextToImageSynthesis, OfaForTextToImageSynthesis)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
@@ -48,7 +49,9 @@ class TextToImageSynthesisPipeline(Pipeline):
return input

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(self.model, OfaForTextToImageSynthesis):
if isinstance(self.model,
(OfaForTextToImageSynthesis,
MultiStageDiffusionForTextToImageSynthesis)):
return self.model(input)
return self.model.generate(input)
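
With this change, the multi-stage diffusion model is called directly via self.model(input), like OFA, instead of through generate(). End to end, usage looks like the sketch below (same model id as the test further down; the tests are skipped because the pretrained weights are not yet publicly available):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(
    task=Tasks.text_to_image_synthesis,
    model='damo/cv_diffusion_text-to-image-synthesis')
img = pipe({'text': 'Photograph of a baby chicken wearing sunglasses'})[
    OutputKeys.OUTPUT_IMG]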



+ 40
- 0
tests/pipelines/test_multi_stage_diffusion.py View File

@@ -0,0 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

import numpy as np
import torch

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class MultiStageDiffusionTest(unittest.TestCase):
model_id = 'damo/cv_diffusion_text-to-image-synthesis'
test_text = {'text': 'Photograph of a baby chicken wearing sunglasses'}

@unittest.skip(
'skip test since the pretrained model is not publicly available')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
pipe_line_text_to_image_synthesis = pipeline(
task=Tasks.text_to_image_synthesis, model=model)
img = pipe_line_text_to_image_synthesis(
self.test_text)[OutputKeys.OUTPUT_IMG]
print(np.sum(np.abs(img)))

@unittest.skip(
'skip test since the pretrained model is not publicly available')
def test_run_with_model_name(self):
pipe_line_text_to_image_synthesis = pipeline(
task=Tasks.text_to_image_synthesis, model=self.model_id)
img = pipe_line_text_to_image_synthesis(
self.test_text)[OutputKeys.OUTPUT_IMG]
print(np.sum(np.abs(img)))


if __name__ == '__main__':
unittest.main()
