Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10037492master
| @@ -72,6 +72,7 @@ class Models(object): | |||||
| gemm = 'gemm-generative-multi-modal' | gemm = 'gemm-generative-multi-modal' | ||||
| mplug = 'mplug' | mplug = 'mplug' | ||||
| diffusion = 'diffusion-text-to-image-synthesis' | diffusion = 'diffusion-text-to-image-synthesis' | ||||
| multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis' | |||||
| team = 'team-multi-modal-similarity' | team = 'team-multi-modal-similarity' | ||||
| video_clip = 'video-clip-multi-modal-embedding' | video_clip = 'video-clip-multi-modal-embedding' | ||||
| @@ -14,6 +14,8 @@ if TYPE_CHECKING: | |||||
| from .ofa_for_all_tasks import OfaForAllTasks | from .ofa_for_all_tasks import OfaForAllTasks | ||||
| from .ofa_for_text_to_image_synthesis_model import \ | from .ofa_for_text_to_image_synthesis_model import \ | ||||
| OfaForTextToImageSynthesis | OfaForTextToImageSynthesis | ||||
| from .multi_stage_diffusion import \ | |||||
| MultiStageDiffusionForTextToImageSynthesis | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -25,7 +27,9 @@ else: | |||||
| 'mplug_for_all_tasks': ['MPlugForAllTasks'], | 'mplug_for_all_tasks': ['MPlugForAllTasks'], | ||||
| 'ofa_for_all_tasks': ['OfaForAllTasks'], | 'ofa_for_all_tasks': ['OfaForAllTasks'], | ||||
| 'ofa_for_text_to_image_synthesis_model': | 'ofa_for_text_to_image_synthesis_model': | ||||
| ['OfaForTextToImageSynthesis'] | |||||
| ['OfaForTextToImageSynthesis'], | |||||
| 'multi_stage_diffusion': | |||||
| ['MultiStageDiffusionForTextToImageSynthesis'] | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1 @@ | |||||
| from .model import MultiStageDiffusionForTextToImageSynthesis | |||||
| @@ -0,0 +1,318 @@ | |||||
| # The implementation here is modified based on OpenAI CLIP, publicly available at https://github.com/openai/CLIP. | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| __all__ = ['CLIP'] | |||||
| def to_fp16(m): | |||||
| if isinstance(m, (nn.Linear, nn.Conv2d)): | |||||
| m.weight.data = m.weight.data.half() | |||||
| if m.bias is not None: | |||||
| m.bias.data = m.bias.data.half() | |||||
| elif hasattr(m, 'head'): | |||||
| p = getattr(m, 'head') | |||||
| p.data = p.data.half() | |||||
| class QuickGELU(nn.Module): | |||||
| def forward(self, x): | |||||
| return x * torch.sigmoid(1.702 * x) | |||||
| class LayerNorm(nn.LayerNorm): | |||||
| r"""Subclass of nn.LayerNorm to handle fp16. | |||||
| """ | |||||
| def forward(self, x): | |||||
| return super(LayerNorm, self).forward(x.float()).type_as(x) | |||||
| class SelfAttention(nn.Module): | |||||
| def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): | |||||
| assert dim % num_heads == 0 | |||||
| super(SelfAttention, self).__init__() | |||||
| self.dim = dim | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = dim // num_heads | |||||
| self.scale = 1.0 / math.sqrt(self.head_dim) | |||||
| # layers | |||||
| self.to_qkv = nn.Linear(dim, dim * 3) | |||||
| self.attn_dropout = nn.Dropout(attn_dropout) | |||||
| self.proj = nn.Linear(dim, dim) | |||||
| self.proj_dropout = nn.Dropout(proj_dropout) | |||||
| def forward(self, x, mask=None): | |||||
| r"""x: [B, L, C]. | |||||
| mask: [*, L, L]. | |||||
| """ | |||||
| b, l, _, n = *x.size(), self.num_heads | |||||
| # compute query, key, and value | |||||
| q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1) | |||||
| q = q.reshape(l, b * n, -1).transpose(0, 1) | |||||
| k = k.reshape(l, b * n, -1).transpose(0, 1) | |||||
| v = v.reshape(l, b * n, -1).transpose(0, 1) | |||||
| # compute attention | |||||
| attn = self.scale * torch.bmm(q, k.transpose(1, 2)) | |||||
| if mask is not None: | |||||
| attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf')) | |||||
| attn = F.softmax(attn.float(), dim=-1).type_as(attn) | |||||
| attn = self.attn_dropout(attn) | |||||
| # gather context | |||||
| x = torch.bmm(attn, v) | |||||
| x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1) | |||||
| # output | |||||
| x = self.proj(x) | |||||
| x = self.proj_dropout(x) | |||||
| return x | |||||
| class AttentionBlock(nn.Module): | |||||
| def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): | |||||
| super(AttentionBlock, self).__init__() | |||||
| self.dim = dim | |||||
| self.num_heads = num_heads | |||||
| # layers | |||||
| self.norm1 = LayerNorm(dim) | |||||
| self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout) | |||||
| self.norm2 = LayerNorm(dim) | |||||
| self.mlp = nn.Sequential( | |||||
| nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim), | |||||
| nn.Dropout(proj_dropout)) | |||||
| def forward(self, x, mask=None): | |||||
| x = x + self.attn(self.norm1(x), mask) | |||||
| x = x + self.mlp(self.norm2(x)) | |||||
| return x | |||||
| class VisionTransformer(nn.Module): | |||||
| def __init__(self, | |||||
| image_size=224, | |||||
| patch_size=16, | |||||
| dim=768, | |||||
| out_dim=512, | |||||
| num_heads=12, | |||||
| num_layers=12, | |||||
| attn_dropout=0.0, | |||||
| proj_dropout=0.0, | |||||
| embedding_dropout=0.0): | |||||
| assert image_size % patch_size == 0 | |||||
| super(VisionTransformer, self).__init__() | |||||
| self.image_size = image_size | |||||
| self.patch_size = patch_size | |||||
| self.dim = dim | |||||
| self.out_dim = out_dim | |||||
| self.num_heads = num_heads | |||||
| self.num_layers = num_layers | |||||
| self.num_patches = (image_size // patch_size)**2 | |||||
| # embeddings | |||||
| gain = 1.0 / math.sqrt(dim) | |||||
| self.patch_embedding = nn.Conv2d( | |||||
| 3, dim, kernel_size=patch_size, stride=patch_size, bias=False) | |||||
| self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) | |||||
| self.pos_embedding = nn.Parameter( | |||||
| gain * torch.randn(1, self.num_patches + 1, dim)) | |||||
| self.dropout = nn.Dropout(embedding_dropout) | |||||
| # transformer | |||||
| self.pre_norm = LayerNorm(dim) | |||||
| self.transformer = nn.Sequential(*[ | |||||
| AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) | |||||
| for _ in range(num_layers) | |||||
| ]) | |||||
| self.post_norm = LayerNorm(dim) | |||||
| # head | |||||
| self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) | |||||
| def forward(self, x): | |||||
| b, dtype = x.size(0), self.head.dtype | |||||
| x = x.type(dtype) | |||||
| # patch-embedding | |||||
| x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c] | |||||
| x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x], | |||||
| dim=1) | |||||
| x = self.dropout(x + self.pos_embedding.type(dtype)) | |||||
| x = self.pre_norm(x) | |||||
| # transformer | |||||
| x = self.transformer(x) | |||||
| # head | |||||
| x = self.post_norm(x) | |||||
| x = torch.mm(x[:, 0, :], self.head) | |||||
| return x | |||||
| def fp16(self): | |||||
| return self.apply(to_fp16) | |||||
| class TextTransformer(nn.Module): | |||||
| def __init__(self, | |||||
| vocab_size, | |||||
| text_len, | |||||
| dim=512, | |||||
| out_dim=512, | |||||
| num_heads=8, | |||||
| num_layers=12, | |||||
| attn_dropout=0.0, | |||||
| proj_dropout=0.0, | |||||
| embedding_dropout=0.0): | |||||
| super(TextTransformer, self).__init__() | |||||
| self.vocab_size = vocab_size | |||||
| self.text_len = text_len | |||||
| self.dim = dim | |||||
| self.out_dim = out_dim | |||||
| self.num_heads = num_heads | |||||
| self.num_layers = num_layers | |||||
| # embeddings | |||||
| self.token_embedding = nn.Embedding(vocab_size, dim) | |||||
| self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim)) | |||||
| self.dropout = nn.Dropout(embedding_dropout) | |||||
| # transformer | |||||
| self.transformer = nn.ModuleList([ | |||||
| AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) | |||||
| for _ in range(num_layers) | |||||
| ]) | |||||
| self.norm = LayerNorm(dim) | |||||
| # head | |||||
| gain = 1.0 / math.sqrt(dim) | |||||
| self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) | |||||
| # causal attention mask | |||||
| self.register_buffer('attn_mask', | |||||
| torch.tril(torch.ones(1, text_len, text_len))) | |||||
| def forward(self, x): | |||||
| eot, dtype = x.argmax(dim=-1), self.head.dtype | |||||
| # embeddings | |||||
| x = self.dropout( | |||||
| self.token_embedding(x).type(dtype) | |||||
| + self.pos_embedding.type(dtype)) | |||||
| # transformer | |||||
| for block in self.transformer: | |||||
| x = block(x, self.attn_mask) | |||||
| # head | |||||
| x = self.norm(x) | |||||
| x = torch.mm(x[torch.arange(x.size(0)), eot], self.head) | |||||
| return x | |||||
| def fp16(self): | |||||
| return self.apply(to_fp16) | |||||
| class CLIP(nn.Module): | |||||
| def __init__(self, | |||||
| embed_dim=512, | |||||
| image_size=224, | |||||
| patch_size=16, | |||||
| vision_dim=768, | |||||
| vision_heads=12, | |||||
| vision_layers=12, | |||||
| vocab_size=49408, | |||||
| text_len=77, | |||||
| text_dim=512, | |||||
| text_heads=8, | |||||
| text_layers=12, | |||||
| attn_dropout=0.0, | |||||
| proj_dropout=0.0, | |||||
| embedding_dropout=0.0): | |||||
| super(CLIP, self).__init__() | |||||
| self.embed_dim = embed_dim | |||||
| self.image_size = image_size | |||||
| self.patch_size = patch_size | |||||
| self.vision_dim = vision_dim | |||||
| self.vision_heads = vision_heads | |||||
| self.vision_layers = vision_layers | |||||
| self.vocab_size = vocab_size | |||||
| self.text_len = text_len | |||||
| self.text_dim = text_dim | |||||
| self.text_heads = text_heads | |||||
| self.text_layers = text_layers | |||||
| # models | |||||
| self.visual = VisionTransformer( | |||||
| image_size=image_size, | |||||
| patch_size=patch_size, | |||||
| dim=vision_dim, | |||||
| out_dim=embed_dim, | |||||
| num_heads=vision_heads, | |||||
| num_layers=vision_layers, | |||||
| attn_dropout=attn_dropout, | |||||
| proj_dropout=proj_dropout, | |||||
| embedding_dropout=embedding_dropout) | |||||
| self.textual = TextTransformer( | |||||
| vocab_size=vocab_size, | |||||
| text_len=text_len, | |||||
| dim=text_dim, | |||||
| out_dim=embed_dim, | |||||
| num_heads=text_heads, | |||||
| num_layers=text_layers, | |||||
| attn_dropout=attn_dropout, | |||||
| proj_dropout=proj_dropout, | |||||
| embedding_dropout=embedding_dropout) | |||||
| self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) | |||||
| def forward(self, imgs, txt_tokens): | |||||
| r"""imgs: [B, C, H, W] of torch.float32. | |||||
| txt_tokens: [B, T] of torch.long. | |||||
| """ | |||||
| xi = self.visual(imgs) | |||||
| xt = self.textual(txt_tokens) | |||||
| # normalize features | |||||
| xi = F.normalize(xi, p=2, dim=1) | |||||
| xt = F.normalize(xt, p=2, dim=1) | |||||
| # logits | |||||
| scale = self.log_scale.exp() | |||||
| logits_i2t = scale * torch.mm(xi, xt.t()) | |||||
| logits_t2i = scale * torch.mm(xt, xi.t()) | |||||
| return logits_i2t, logits_t2i | |||||
| def init_weights(self): | |||||
| # embeddings | |||||
| nn.init.normal_(self.textual.token_embedding.weight, std=0.02) | |||||
| nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1) | |||||
| # attentions | |||||
| for modality in ['visual', 'textual']: | |||||
| dim = self.vision_dim if modality == 'visual' else 'textual' | |||||
| transformer = getattr(self, modality).transformer | |||||
| proj_gain = (1.0 / math.sqrt(dim)) * ( | |||||
| 1.0 / math.sqrt(2 * transformer.num_layers)) | |||||
| attn_gain = 1.0 / math.sqrt(dim) | |||||
| mlp_gain = 1.0 / math.sqrt(2.0 * dim) | |||||
| for block in transformer.layers: | |||||
| nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) | |||||
| nn.init.normal_(block.attn.proj.weight, std=proj_gain) | |||||
| nn.init.normal_(block.mlp[0].weight, std=mlp_gain) | |||||
| nn.init.normal_(block.mlp[2].weight, std=proj_gain) | |||||
| def fp16(self): | |||||
| return self.apply(to_fp16) | |||||
| @@ -0,0 +1,322 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| __all__ = ['Decoder'] | |||||
| def sinusoidal_embedding(timesteps, dim): | |||||
| # check input | |||||
| half = dim // 2 | |||||
| timesteps = timesteps.float() | |||||
| # compute sinusoidal embedding | |||||
| sinusoid = torch.outer( | |||||
| timesteps, torch.pow(10000, | |||||
| -torch.arange(half).to(timesteps).div(half))) | |||||
| x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) | |||||
| if dim % 2 != 0: | |||||
| x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) | |||||
| return x | |||||
| class Resample(nn.Module): | |||||
| def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): | |||||
| assert scale_factor in [0.5, 1.0, 2.0] | |||||
| super(Resample, self).__init__() | |||||
| self.in_dim = in_dim | |||||
| self.out_dim = out_dim | |||||
| self.scale_factor = scale_factor | |||||
| self.use_conv = use_conv | |||||
| # layers | |||||
| if scale_factor == 2.0: | |||||
| self.resample = nn.Sequential( | |||||
| nn.Upsample(scale_factor=scale_factor, mode='nearest'), | |||||
| nn.Conv2d(in_dim, out_dim, 3, padding=1) | |||||
| if use_conv else nn.Identity()) | |||||
| elif scale_factor == 0.5: | |||||
| self.resample = nn.Conv2d( | |||||
| in_dim, out_dim, 3, stride=2, | |||||
| padding=1) if use_conv else nn.AvgPool2d( | |||||
| kernel_size=2, stride=2) | |||||
| else: | |||||
| self.resample = nn.Identity() | |||||
| def forward(self, x): | |||||
| return self.resample(x) | |||||
| class ResidualBlock(nn.Module): | |||||
| def __init__(self, | |||||
| in_dim, | |||||
| embed_dim, | |||||
| out_dim, | |||||
| use_scale_shift_norm=True, | |||||
| scale_factor=1.0, | |||||
| dropout=0.0): | |||||
| super(ResidualBlock, self).__init__() | |||||
| self.in_dim = in_dim | |||||
| self.embed_dim = embed_dim | |||||
| self.out_dim = out_dim | |||||
| self.use_scale_shift_norm = use_scale_shift_norm | |||||
| self.scale_factor = scale_factor | |||||
| # layers | |||||
| self.layer1 = nn.Sequential( | |||||
| nn.GroupNorm(32, in_dim), nn.SiLU(), | |||||
| nn.Conv2d(in_dim, out_dim, 3, padding=1)) | |||||
| self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False) | |||||
| self.embedding = nn.Sequential( | |||||
| nn.SiLU(), | |||||
| nn.Linear(embed_dim, | |||||
| out_dim * 2 if use_scale_shift_norm else out_dim)) | |||||
| self.layer2 = nn.Sequential( | |||||
| nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), | |||||
| nn.Conv2d(out_dim, out_dim, 3, padding=1)) | |||||
| self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( | |||||
| in_dim, out_dim, 1) | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.layer2[-1].weight) | |||||
| def forward(self, x, e): | |||||
| identity = self.resample(x) | |||||
| x = self.layer1[-1](self.resample(self.layer1[:-1](x))) | |||||
| e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) | |||||
| if self.use_scale_shift_norm: | |||||
| scale, shift = e.chunk(2, dim=1) | |||||
| x = self.layer2[0](x) * (1 + scale) + shift | |||||
| x = self.layer2[1:](x) | |||||
| else: | |||||
| x = x + e | |||||
| x = self.layer2(x) | |||||
| x = x + self.shortcut(identity) | |||||
| return x | |||||
| class AttentionBlock(nn.Module): | |||||
| def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): | |||||
| # consider head_dim first, then num_heads | |||||
| num_heads = dim // head_dim if head_dim else num_heads | |||||
| head_dim = dim // num_heads | |||||
| assert num_heads * head_dim == dim | |||||
| super(AttentionBlock, self).__init__() | |||||
| self.dim = dim | |||||
| self.context_dim = context_dim | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = head_dim | |||||
| self.scale = math.pow(head_dim, -0.25) | |||||
| # layers | |||||
| self.norm = nn.GroupNorm(32, dim) | |||||
| self.to_qkv = nn.Conv2d(dim, dim * 3, 1) | |||||
| if context_dim is not None: | |||||
| self.context_kv = nn.Linear(context_dim, dim * 2) | |||||
| self.proj = nn.Conv2d(dim, dim, 1) | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.proj.weight) | |||||
| def forward(self, x, context=None): | |||||
| r"""x: [B, C, H, W]. | |||||
| context: [B, L, C] or None. | |||||
| """ | |||||
| identity = x | |||||
| b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim | |||||
| # compute query, key, value | |||||
| x = self.norm(x) | |||||
| q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) | |||||
| if context is not None: | |||||
| ck, cv = self.context_kv(context).reshape(b, -1, n * 2, | |||||
| d).permute(0, 2, 3, | |||||
| 1).chunk( | |||||
| 2, dim=1) | |||||
| k = torch.cat([ck, k], dim=-1) | |||||
| v = torch.cat([cv, v], dim=-1) | |||||
| # compute attention | |||||
| attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) | |||||
| attn = F.softmax(attn, dim=-1) | |||||
| # gather context | |||||
| x = torch.matmul(v, attn.transpose(-1, -2)) | |||||
| x = x.reshape(b, c, h, w) | |||||
| # output | |||||
| x = self.proj(x) | |||||
| return x + identity | |||||
| class Decoder(nn.Module): | |||||
| def __init__(self, | |||||
| in_dim=3, | |||||
| dim=512, | |||||
| y_dim=512, | |||||
| context_dim=512, | |||||
| out_dim=6, | |||||
| dim_mult=[1, 2, 3, 4], | |||||
| num_heads=None, | |||||
| head_dim=64, | |||||
| num_res_blocks=3, | |||||
| attn_scales=[1 / 2, 1 / 4, 1 / 8], | |||||
| resblock_resample=True, | |||||
| use_scale_shift_norm=True, | |||||
| dropout=0.1): | |||||
| embed_dim = dim * 4 | |||||
| super(Decoder, self).__init__() | |||||
| self.in_dim = in_dim | |||||
| self.dim = dim | |||||
| self.y_dim = y_dim | |||||
| self.context_dim = context_dim | |||||
| self.embed_dim = embed_dim | |||||
| self.out_dim = out_dim | |||||
| self.dim_mult = dim_mult | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = head_dim | |||||
| self.num_res_blocks = num_res_blocks | |||||
| self.attn_scales = attn_scales | |||||
| self.resblock_resample = resblock_resample | |||||
| self.use_scale_shift_norm = use_scale_shift_norm | |||||
| # params | |||||
| enc_dims = [dim * u for u in [1] + dim_mult] | |||||
| dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] | |||||
| shortcut_dims = [] | |||||
| scale = 1.0 | |||||
| # embeddings | |||||
| self.time_embedding = nn.Sequential( | |||||
| nn.Linear(dim, embed_dim), nn.SiLU(), | |||||
| nn.Linear(embed_dim, embed_dim)) | |||||
| self.y_embedding = nn.Sequential( | |||||
| nn.Linear(y_dim, embed_dim), nn.SiLU(), | |||||
| nn.Linear(embed_dim, embed_dim)) | |||||
| self.context_embedding = nn.Sequential( | |||||
| nn.Linear(y_dim, embed_dim), nn.SiLU(), | |||||
| nn.Linear(embed_dim, context_dim * 4)) | |||||
| # encoder | |||||
| self.encoder = nn.ModuleList( | |||||
| [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) | |||||
| shortcut_dims.append(dim) | |||||
| for i, (in_dim, | |||||
| out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): | |||||
| for j in range(num_res_blocks): | |||||
| # residual (+attention) blocks | |||||
| block = nn.ModuleList([ | |||||
| ResidualBlock(in_dim, embed_dim, out_dim, | |||||
| use_scale_shift_norm, 1.0, dropout) | |||||
| ]) | |||||
| if scale in attn_scales: | |||||
| block.append( | |||||
| AttentionBlock(out_dim, context_dim, num_heads, | |||||
| head_dim)) | |||||
| in_dim = out_dim | |||||
| self.encoder.append(block) | |||||
| shortcut_dims.append(out_dim) | |||||
| # downsample | |||||
| if i != len(dim_mult) - 1 and j == num_res_blocks - 1: | |||||
| if resblock_resample: | |||||
| downsample = ResidualBlock(out_dim, embed_dim, out_dim, | |||||
| use_scale_shift_norm, 0.5, | |||||
| dropout) | |||||
| else: | |||||
| downsample = Resample( | |||||
| out_dim, out_dim, 0.5, use_conv=True) | |||||
| shortcut_dims.append(out_dim) | |||||
| scale /= 2.0 | |||||
| self.encoder.append(downsample) | |||||
| # middle | |||||
| self.middle = nn.ModuleList([ | |||||
| ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, | |||||
| 1.0, dropout), | |||||
| AttentionBlock(out_dim, context_dim, num_heads, head_dim), | |||||
| ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, | |||||
| 1.0, dropout) | |||||
| ]) | |||||
| # decoder | |||||
| self.decoder = nn.ModuleList() | |||||
| for i, (in_dim, | |||||
| out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): | |||||
| for j in range(num_res_blocks + 1): | |||||
| # residual (+attention) blocks | |||||
| block = nn.ModuleList([ | |||||
| ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, | |||||
| out_dim, use_scale_shift_norm, 1.0, dropout) | |||||
| ]) | |||||
| if scale in attn_scales: | |||||
| block.append( | |||||
| AttentionBlock(out_dim, context_dim, num_heads, | |||||
| head_dim)) | |||||
| in_dim = out_dim | |||||
| # upsample | |||||
| if i != len(dim_mult) - 1 and j == num_res_blocks: | |||||
| if resblock_resample: | |||||
| upsample = ResidualBlock(out_dim, embed_dim, out_dim, | |||||
| use_scale_shift_norm, 2.0, | |||||
| dropout) | |||||
| else: | |||||
| upsample = Resample( | |||||
| out_dim, out_dim, 2.0, use_conv=True) | |||||
| scale *= 2.0 | |||||
| block.append(upsample) | |||||
| self.decoder.append(block) | |||||
| # head | |||||
| self.head = nn.Sequential( | |||||
| nn.GroupNorm(32, out_dim), nn.SiLU(), | |||||
| nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.head[-1].weight) | |||||
| def forward(self, x, t, y): | |||||
| # embeddings | |||||
| e = self.time_embedding(sinusoidal_embedding( | |||||
| t, self.dim)) + self.y_embedding(y) | |||||
| context = self.context_embedding(y).view(-1, 4, self.context_dim) | |||||
| # encoder | |||||
| xs = [] | |||||
| for block in self.encoder: | |||||
| x = self._forward_single(block, x, e, context) | |||||
| xs.append(x) | |||||
| # middle | |||||
| for block in self.middle: | |||||
| x = self._forward_single(block, x, e, context) | |||||
| # decoder | |||||
| for block in self.decoder: | |||||
| x = torch.cat([x, xs.pop()], dim=1) | |||||
| x = self._forward_single(block, x, e, context) | |||||
| # head | |||||
| x = self.head(x) | |||||
| return x | |||||
| def _forward_single(self, module, x, e, context): | |||||
| if isinstance(module, ResidualBlock): | |||||
| x = module(x, e) | |||||
| elif isinstance(module, AttentionBlock): | |||||
| x = module(x, context) | |||||
| elif isinstance(module, nn.ModuleList): | |||||
| for block in module: | |||||
| x = self._forward_single(block, x, e, context) | |||||
| else: | |||||
| x = module(x) | |||||
| return x | |||||
| @@ -0,0 +1,641 @@ | |||||
| # The implementation here is modified based on latent diffusion, publicly available | |||||
| # at https://github.com/CompVis/latent-diffusion. | |||||
| import math | |||||
| import torch | |||||
| __all__ = ['GaussianDiffusion', 'beta_schedule'] | |||||
| def kl_divergence(mu1, logvar1, mu2, logvar2): | |||||
| u1 = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) | |||||
| u2 = ((mu1 - mu2)**2) * torch.exp(-logvar2) | |||||
| return 0.5 * (u1 + u2) | |||||
| def standard_normal_cdf(x): | |||||
| r"""A fast approximation of the cumulative distribution function of the standard normal. | |||||
| """ | |||||
| return 0.5 * (1.0 + torch.tanh( | |||||
| math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) | |||||
| def discretized_gaussian_log_likelihood(x0, mean, log_scale): | |||||
| assert x0.shape == mean.shape == log_scale.shape | |||||
| cx = x0 - mean | |||||
| inv_stdv = torch.exp(-log_scale) | |||||
| cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) | |||||
| cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) | |||||
| log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) | |||||
| log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) | |||||
| cdf_delta = cdf_plus - cdf_min | |||||
| log_probs = torch.where( | |||||
| x0 < -0.999, log_cdf_plus, | |||||
| torch.where(x0 > 0.999, log_one_minus_cdf_min, | |||||
| torch.log(cdf_delta.clamp(min=1e-12)))) | |||||
| assert log_probs.shape == x0.shape | |||||
| return log_probs | |||||
| def _i(tensor, t, x): | |||||
| r"""Index tensor using t and format the output according to x. | |||||
| """ | |||||
| shape = (x.size(0), ) + (1, ) * (x.ndim - 1) | |||||
| return tensor[t].view(shape).to(x) | |||||
| def beta_schedule(schedule, | |||||
| num_timesteps=1000, | |||||
| init_beta=None, | |||||
| last_beta=None): | |||||
| if schedule == 'linear': | |||||
| scale = 1000.0 / num_timesteps | |||||
| init_beta = init_beta or scale * 0.0001 | |||||
| last_beta = last_beta or scale * 0.02 | |||||
| return torch.linspace( | |||||
| init_beta, last_beta, num_timesteps, dtype=torch.float64) | |||||
| elif schedule == 'quadratic': | |||||
| init_beta = init_beta or 0.0015 | |||||
| last_beta = last_beta or 0.0195 | |||||
| return torch.linspace( | |||||
| init_beta**0.5, last_beta**0.5, num_timesteps, | |||||
| dtype=torch.float64)**2 | |||||
| elif schedule == 'cosine': | |||||
| betas = [] | |||||
| for step in range(num_timesteps): | |||||
| t1 = step / num_timesteps | |||||
| t2 = (step + 1) / num_timesteps | |||||
| fn_t1 = math.cos((t1 + 0.008) / 1.008 * math.pi / 2)**2 | |||||
| fn_t2 = math.cos((t2 + 0.008) / 1.008 * math.pi / 2)**2 | |||||
| betas.append(min(1.0 - fn_t2 / fn_t1, 0.999)) | |||||
| return torch.tensor(betas, dtype=torch.float64) | |||||
| else: | |||||
| raise ValueError(f'Unsupported schedule: {schedule}') | |||||
| class GaussianDiffusion(object): | |||||
| def __init__(self, | |||||
| betas, | |||||
| mean_type='eps', | |||||
| var_type='learned_range', | |||||
| loss_type='mse', | |||||
| rescale_timesteps=False): | |||||
| # check input | |||||
| if not isinstance(betas, torch.DoubleTensor): | |||||
| betas = torch.tensor(betas, dtype=torch.float64) | |||||
| assert min(betas) > 0 and max(betas) <= 1 | |||||
| assert mean_type in ['x0', 'x_{t-1}', 'eps'] | |||||
| assert var_type in [ | |||||
| 'learned', 'learned_range', 'fixed_large', 'fixed_small' | |||||
| ] | |||||
| assert loss_type in [ | |||||
| 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' | |||||
| ] | |||||
| self.betas = betas | |||||
| self.num_timesteps = len(betas) | |||||
| self.mean_type = mean_type | |||||
| self.var_type = var_type | |||||
| self.loss_type = loss_type | |||||
| self.rescale_timesteps = rescale_timesteps | |||||
| # alphas | |||||
| alphas = 1 - self.betas | |||||
| self.alphas_cumprod = torch.cumprod(alphas, dim=0) | |||||
| self.alphas_cumprod_prev = torch.cat( | |||||
| [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) | |||||
| self.alphas_cumprod_next = torch.cat( | |||||
| [self.alphas_cumprod[1:], | |||||
| alphas.new_zeros([1])]) | |||||
| # q(x_t | x_{t-1}) | |||||
| self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) | |||||
| self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 | |||||
| - self.alphas_cumprod) | |||||
| self.log_one_minus_alphas_cumprod = torch.log(1.0 | |||||
| - self.alphas_cumprod) | |||||
| self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) | |||||
| self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod | |||||
| - 1) | |||||
| # q(x_{t-1} | x_t, x_0) | |||||
| self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( | |||||
| 1.0 - self.alphas_cumprod) | |||||
| self.posterior_log_variance_clipped = torch.log( | |||||
| self.posterior_variance.clamp(1e-20)) | |||||
| self.posterior_mean_coef1 = betas * torch.sqrt( | |||||
| self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) | |||||
| self.posterior_mean_coef2 = ( | |||||
| 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( | |||||
| 1.0 - self.alphas_cumprod) | |||||
| def q_sample(self, x0, t, noise=None): | |||||
| r"""Sample from q(x_t | x_0). | |||||
| """ | |||||
| noise = torch.randn_like(x0) if noise is None else noise | |||||
| u1 = _i(self.sqrt_alphas_cumprod, t, x0) * x0 | |||||
| u2 = _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise | |||||
| return u1 + u2 | |||||
| def q_mean_variance(self, x0, t): | |||||
| r"""Distribution of q(x_t | x_0). | |||||
| """ | |||||
| mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 | |||||
| var = _i(1.0 - self.alphas_cumprod, t, x0) | |||||
| log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) | |||||
| return mu, var, log_var | |||||
| def q_posterior_mean_variance(self, x0, xt, t): | |||||
| r"""Distribution of q(x_{t-1} | x_t, x_0). | |||||
| """ | |||||
| mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( | |||||
| self.posterior_mean_coef2, t, xt) * xt | |||||
| var = _i(self.posterior_variance, t, xt) | |||||
| log_var = _i(self.posterior_log_variance_clipped, t, xt) | |||||
| return mu, var, log_var | |||||
| @torch.no_grad() | |||||
| def p_sample(self, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None): | |||||
| r"""Sample from p(x_{t-1} | x_t). | |||||
| - condition_fn: for classifier-based guidance (guided-diffusion). | |||||
| - guide_scale: for classifier-free guidance (glide/dalle-2). | |||||
| """ | |||||
| # predict distribution of p(x_{t-1} | x_t) | |||||
| mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, | |||||
| clamp, percentile, | |||||
| guide_scale) | |||||
| # random sample (with optional conditional function) | |||||
| noise = torch.randn_like(xt) | |||||
| shape = (-1, *((1, ) * (xt.ndim - 1))) | |||||
| mask = t.ne(0).float().view(shape) # no noise when t == 0 | |||||
| if condition_fn is not None: | |||||
| grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) | |||||
| mu = mu.float() + var * grad.float() | |||||
| xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise | |||||
| return xt_1, x0 | |||||
| @torch.no_grad() | |||||
| def p_sample_loop(self, | |||||
| noise, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None): | |||||
| r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). | |||||
| """ | |||||
| # prepare input | |||||
| b = noise.size(0) | |||||
| xt = noise | |||||
| # diffusion process | |||||
| for step in torch.arange(self.num_timesteps).flip(0): | |||||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||||
| xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, | |||||
| percentile, condition_fn, guide_scale) | |||||
| return xt | |||||
| def p_mean_variance(self, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| guide_scale=None): | |||||
| r"""Distribution of p(x_{t-1} | x_t). | |||||
| """ | |||||
| # predict distribution | |||||
| if guide_scale is None: | |||||
| out = model(xt, self._scale_timesteps(t), **model_kwargs) | |||||
| else: | |||||
| # classifier-free guidance | |||||
| # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) | |||||
| assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 | |||||
| y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) | |||||
| u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) | |||||
| cond = self.var_type.startswith('fixed') | |||||
| dim = y_out.size(1) if cond else y_out.size(1) // 2 | |||||
| u1 = u_out[:, :dim] | |||||
| u2 = guide_scale * (y_out[:, :dim] - u_out[:, :dim]) | |||||
| out = torch.cat([u1 + u2, y_out[:, dim:]], dim=1) | |||||
| # compute variance | |||||
| if self.var_type == 'learned': | |||||
| out, log_var = out.chunk(2, dim=1) | |||||
| var = torch.exp(log_var) | |||||
| elif self.var_type == 'learned_range': | |||||
| out, fraction = out.chunk(2, dim=1) | |||||
| min_log_var = _i(self.posterior_log_variance_clipped, t, xt) | |||||
| max_log_var = _i(torch.log(self.betas), t, xt) | |||||
| fraction = (fraction + 1) / 2.0 | |||||
| log_var = fraction * max_log_var + (1 - fraction) * min_log_var | |||||
| var = torch.exp(log_var) | |||||
| elif self.var_type == 'fixed_large': | |||||
| var = _i( | |||||
| torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, | |||||
| xt) | |||||
| log_var = torch.log(var) | |||||
| elif self.var_type == 'fixed_small': | |||||
| var = _i(self.posterior_variance, t, xt) | |||||
| log_var = _i(self.posterior_log_variance_clipped, t, xt) | |||||
| # compute mean and x0 | |||||
| if self.mean_type == 'x_{t-1}': | |||||
| mu = out # x_{t-1} | |||||
| u1 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu | |||||
| u2 = _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, | |||||
| xt) * xt | |||||
| x0 = u1 - u2 | |||||
| elif self.mean_type == 'x0': | |||||
| x0 = out | |||||
| mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) | |||||
| elif self.mean_type == 'eps': | |||||
| u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt | |||||
| u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out | |||||
| x0 = u1 - u2 | |||||
| mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) | |||||
| # restrict the range of x0 | |||||
| if percentile is not None: | |||||
| assert percentile > 0 and percentile <= 1 # e.g., 0.995 | |||||
| s = torch.quantile( | |||||
| x0.flatten(1).abs(), percentile, | |||||
| dim=1).clamp_(1.0).view(-1, 1, 1, 1) | |||||
| x0 = torch.min(s, torch.max(-s, x0)) / s | |||||
| elif clamp is not None: | |||||
| x0 = x0.clamp(-clamp, clamp) | |||||
| return mu, var, log_var, x0 | |||||
| @torch.no_grad() | |||||
| def ddim_sample(self, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None, | |||||
| ddim_timesteps=20, | |||||
| eta=0.0): | |||||
| r"""Sample from p(x_{t-1} | x_t) using DDIM. | |||||
| - condition_fn: for classifier-based guidance (guided-diffusion). | |||||
| - guide_scale: for classifier-free guidance (glide/dalle-2). | |||||
| """ | |||||
| stride = self.num_timesteps // ddim_timesteps | |||||
| # predict distribution of p(x_{t-1} | x_t) | |||||
| _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, | |||||
| percentile, guide_scale) | |||||
| if condition_fn is not None: | |||||
| # x0 -> eps | |||||
| alpha = _i(self.alphas_cumprod, t, xt) | |||||
| u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) | |||||
| u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| eps = u1 / u2 | |||||
| eps = eps - (1 - alpha).sqrt() * condition_fn( | |||||
| xt, self._scale_timesteps(t), **model_kwargs) | |||||
| # eps -> x0 | |||||
| u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt | |||||
| u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps | |||||
| x0 = u1 - u2 | |||||
| # derive variables | |||||
| u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) | |||||
| u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| eps = u1 / u2 | |||||
| alphas = _i(self.alphas_cumprod, t, xt) | |||||
| alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) | |||||
| u1 = (1 - alphas_prev) / (1 - alphas) | |||||
| u2 = (1 - alphas / alphas_prev) | |||||
| sigmas = eta * torch.sqrt(u1 * u2) | |||||
| # random sample | |||||
| noise = torch.randn_like(xt) | |||||
| direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps | |||||
| mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) | |||||
| xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise | |||||
| return xt_1, x0 | |||||
| @torch.no_grad() | |||||
| def ddim_sample_loop(self, | |||||
| noise, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None, | |||||
| ddim_timesteps=20, | |||||
| eta=0.0): | |||||
| # prepare input | |||||
| b = noise.size(0) | |||||
| xt = noise | |||||
| # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) | |||||
| steps = (1 + torch.arange(0, self.num_timesteps, | |||||
| self.num_timesteps // ddim_timesteps)).clamp( | |||||
| 0, self.num_timesteps - 1).flip(0) | |||||
| for step in steps: | |||||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||||
| xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, | |||||
| percentile, condition_fn, guide_scale, | |||||
| ddim_timesteps, eta) | |||||
| return xt | |||||
| @torch.no_grad() | |||||
| def ddim_reverse_sample(self, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| guide_scale=None, | |||||
| ddim_timesteps=20): | |||||
| r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). | |||||
| """ | |||||
| stride = self.num_timesteps // ddim_timesteps | |||||
| # predict distribution of p(x_{t-1} | x_t) | |||||
| _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, | |||||
| percentile, guide_scale) | |||||
| # derive variables | |||||
| u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) | |||||
| u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| eps = u1 / u2 | |||||
| alphas_next = _i( | |||||
| torch.cat( | |||||
| [self.alphas_cumprod, | |||||
| self.alphas_cumprod.new_zeros([1])]), | |||||
| (t + stride).clamp(0, self.num_timesteps), xt) | |||||
| # reverse sample | |||||
| mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps | |||||
| return mu, x0 | |||||
| @torch.no_grad() | |||||
| def ddim_reverse_sample_loop(self, | |||||
| x0, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| guide_scale=None, | |||||
| ddim_timesteps=20): | |||||
| # prepare input | |||||
| b = x0.size(0) | |||||
| xt = x0 | |||||
| # reconstruction steps | |||||
| steps = torch.arange(0, self.num_timesteps, | |||||
| self.num_timesteps // ddim_timesteps) | |||||
| for step in steps: | |||||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||||
| xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, | |||||
| percentile, guide_scale, | |||||
| ddim_timesteps) | |||||
| return xt | |||||
| @torch.no_grad() | |||||
| def plms_sample(self, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None, | |||||
| plms_timesteps=20): | |||||
| r"""Sample from p(x_{t-1} | x_t) using PLMS. | |||||
| - condition_fn: for classifier-based guidance (guided-diffusion). | |||||
| - guide_scale: for classifier-free guidance (glide/dalle-2). | |||||
| """ | |||||
| stride = self.num_timesteps // plms_timesteps | |||||
| # function for compute eps | |||||
| def compute_eps(xt, t): | |||||
| # predict distribution of p(x_{t-1} | x_t) | |||||
| _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, | |||||
| clamp, percentile, guide_scale) | |||||
| # condition | |||||
| if condition_fn is not None: | |||||
| # x0 -> eps | |||||
| alpha = _i(self.alphas_cumprod, t, xt) | |||||
| u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) | |||||
| u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| eps = u1 / u2 | |||||
| eps = eps - (1 - alpha).sqrt() * condition_fn( | |||||
| xt, self._scale_timesteps(t), **model_kwargs) | |||||
| # eps -> x0 | |||||
| u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt | |||||
| u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps | |||||
| x0 = u1 - u2 | |||||
| # derive eps | |||||
| u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) | |||||
| u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| eps = u1 / u2 | |||||
| return eps | |||||
| # function for compute x_0 and x_{t-1} | |||||
| def compute_x0(eps, t): | |||||
| # eps -> x0 | |||||
| u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt | |||||
| u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps | |||||
| x0 = u1 - u2 | |||||
| # deterministic sample | |||||
| alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) | |||||
| direction = torch.sqrt(1 - alphas_prev) * eps | |||||
| # mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) | |||||
| xt_1 = torch.sqrt(alphas_prev) * x0 + direction | |||||
| return xt_1, x0 | |||||
| # PLMS sample | |||||
| eps = compute_eps(xt, t) | |||||
| if len(eps_cache) == 0: | |||||
| # 2nd order pseudo improved Euler | |||||
| xt_1, x0 = compute_x0(eps, t) | |||||
| eps_next = compute_eps(xt_1, (t - stride).clamp(0)) | |||||
| eps_prime = (eps + eps_next) / 2.0 | |||||
| elif len(eps_cache) == 1: | |||||
| # 2nd order pseudo linear multistep (Adams-Bashforth) | |||||
| eps_prime = (3 * eps - eps_cache[-1]) / 2.0 | |||||
| elif len(eps_cache) == 2: | |||||
| # 3nd order pseudo linear multistep (Adams-Bashforth) | |||||
| eps_prime = (23 * eps - 16 * eps_cache[-1] | |||||
| + 5 * eps_cache[-2]) / 12.0 | |||||
| elif len(eps_cache) >= 3: | |||||
| # 4nd order pseudo linear multistep (Adams-Bashforth) | |||||
| eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] | |||||
| - 9 * eps_cache[-3]) / 24.0 | |||||
| xt_1, x0 = compute_x0(eps_prime, t) | |||||
| return xt_1, x0, eps | |||||
| @torch.no_grad() | |||||
| def plms_sample_loop(self, | |||||
| noise, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None, | |||||
| plms_timesteps=20): | |||||
| # prepare input | |||||
| b = noise.size(0) | |||||
| xt = noise | |||||
| # diffusion process | |||||
| steps = (1 + torch.arange(0, self.num_timesteps, | |||||
| self.num_timesteps // plms_timesteps)).clamp( | |||||
| 0, self.num_timesteps - 1).flip(0) | |||||
| eps_cache = [] | |||||
| for step in steps: | |||||
| # PLMS sampling step | |||||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||||
| xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, | |||||
| percentile, condition_fn, | |||||
| guide_scale, plms_timesteps, | |||||
| eps_cache) | |||||
| # update eps cache | |||||
| eps_cache.append(eps) | |||||
| if len(eps_cache) >= 4: | |||||
| eps_cache.pop(0) | |||||
| return xt | |||||
| def loss(self, x0, t, model, model_kwargs={}, noise=None, input_x0=None): | |||||
| noise = torch.randn_like(x0) if noise is None else noise | |||||
| input_x0 = x0 if input_x0 is None else input_x0 | |||||
| xt = self.q_sample(input_x0, t, noise=noise) | |||||
| # compute loss | |||||
| if self.loss_type in ['kl', 'rescaled_kl']: | |||||
| loss, _ = self.variational_lower_bound(x0, xt, t, model, | |||||
| model_kwargs) | |||||
| if self.loss_type == 'rescaled_kl': | |||||
| loss = loss * self.num_timesteps | |||||
| elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: | |||||
| out = model(xt, self._scale_timesteps(t), **model_kwargs) | |||||
| # VLB for variation | |||||
| loss_vlb = 0.0 | |||||
| if self.var_type in ['learned', 'learned_range']: | |||||
| out, var = out.chunk(2, dim=1) | |||||
| frozen = torch.cat([ | |||||
| out.detach(), var | |||||
| ], dim=1) # learn var without affecting the prediction of mean | |||||
| loss_vlb, _ = self.variational_lower_bound( | |||||
| x0, xt, t, model=lambda *args, **kwargs: frozen) | |||||
| if self.loss_type.startswith('rescaled_'): | |||||
| loss_vlb = loss_vlb * self.num_timesteps / 1000.0 | |||||
| # MSE/L1 for x0/eps | |||||
| target = { | |||||
| 'eps': noise, | |||||
| 'x0': x0, | |||||
| 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] | |||||
| }[self.mean_type] | |||||
| loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 | |||||
| ).abs().flatten(1).mean(dim=1) | |||||
| # total loss | |||||
| loss = loss + loss_vlb | |||||
| return loss | |||||
| def variational_lower_bound(self, | |||||
| x0, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None): | |||||
| # compute groundtruth and predicted distributions | |||||
| mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) | |||||
| mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, | |||||
| clamp, percentile) | |||||
| # compute KL loss | |||||
| kl = kl_divergence(mu1, log_var1, mu2, log_var2) | |||||
| kl = kl.flatten(1).mean(dim=1) / math.log(2.0) | |||||
| # compute discretized NLL loss (for p(x0 | x1) only) | |||||
| nll = -discretized_gaussian_log_likelihood( | |||||
| x0, mean=mu2, log_scale=0.5 * log_var2) | |||||
| nll = nll.flatten(1).mean(dim=1) / math.log(2.0) | |||||
| # NLL for p(x0 | x1) and KL otherwise | |||||
| vlb = torch.where(t == 0, nll, kl) | |||||
| return vlb, x0 | |||||
| @torch.no_grad() | |||||
| def variational_lower_bound_loop(self, | |||||
| x0, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None): | |||||
| r"""Compute the entire variational lower bound, measured in bits-per-dim. | |||||
| """ | |||||
| # prepare input and output | |||||
| b = x0.size(0) | |||||
| metrics = {'vlb': [], 'mse': [], 'x0_mse': []} | |||||
| # loop | |||||
| for step in torch.arange(self.num_timesteps).flip(0): | |||||
| # compute VLB | |||||
| t = torch.full((b, ), step, dtype=torch.long, device=x0.device) | |||||
| noise = torch.randn_like(x0) | |||||
| xt = self.q_sample(x0, t, noise) | |||||
| vlb, pred_x0 = self.variational_lower_bound( | |||||
| x0, xt, t, model, model_kwargs, clamp, percentile) | |||||
| # predict eps from x0 | |||||
| u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) | |||||
| u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| eps = u1 / u2 | |||||
| # collect metrics | |||||
| metrics['vlb'].append(vlb) | |||||
| metrics['x0_mse'].append( | |||||
| (pred_x0 - x0).square().flatten(1).mean(dim=1)) | |||||
| metrics['mse'].append( | |||||
| (eps - noise).square().flatten(1).mean(dim=1)) | |||||
| metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} | |||||
| # compute the prior KL term for VLB, measured in bits-per-dim | |||||
| mu, _, log_var = self.q_mean_variance(x0, t) | |||||
| kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), | |||||
| torch.zeros_like(log_var)) | |||||
| kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) | |||||
| # update metrics | |||||
| metrics['prior_bits_per_dim'] = kl_prior | |||||
| metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior | |||||
| return metrics | |||||
| def _scale_timesteps(self, t): | |||||
| if self.rescale_timesteps: | |||||
| return t.float() * 1000.0 / self.num_timesteps | |||||
| return t | |||||
| @@ -0,0 +1,265 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import math | |||||
| import os.path as osp | |||||
| from typing import Any, Dict | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.cuda.amp as amp | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from PIL import Image | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models import TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.models.multi_modal.multi_stage_diffusion.clip import CLIP | |||||
| from modelscope.models.multi_modal.multi_stage_diffusion.decoder import Decoder | |||||
| from modelscope.models.multi_modal.multi_stage_diffusion.gaussian_diffusion import ( | |||||
| GaussianDiffusion, beta_schedule) | |||||
| from modelscope.models.multi_modal.multi_stage_diffusion.prior import Prior | |||||
| from modelscope.models.multi_modal.multi_stage_diffusion.tokenizer import ( | |||||
| CLIPTokenizer, XGLMTokenizer) | |||||
| from modelscope.models.multi_modal.multi_stage_diffusion.upsampler import ( | |||||
| Upsampler256, Upsampler1024) | |||||
| from modelscope.models.multi_modal.multi_stage_diffusion.xglm import XGLM | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| __all__ = ['MultiStageDiffusionForTextToImageSynthesis'] | |||||
| def make_diffusion(schedule, | |||||
| num_timesteps=1000, | |||||
| init_beta=None, | |||||
| last_beta=None, | |||||
| mean_type='eps', | |||||
| var_type='fixed_small'): | |||||
| betas = beta_schedule(schedule, num_timesteps, init_beta, last_beta) | |||||
| diffusion = GaussianDiffusion( | |||||
| betas, mean_type=mean_type, var_type=var_type) | |||||
| return diffusion | |||||
| class UnCLIP(nn.Module): | |||||
| def __init__(self, model_dir): | |||||
| super(UnCLIP, self).__init__() | |||||
| self.model_dir = model_dir | |||||
| self.config = json.load(open(f'{model_dir}/{ModelFile.CONFIGURATION}')) | |||||
| # modules | |||||
| self.clip = CLIP(**self.config['clip']).fp16() | |||||
| self.xglm = XGLM(**self.config['xglm']) | |||||
| self.prior = Prior(**self.config['prior']) | |||||
| self.decoder = Decoder(**self.config['decoder']) | |||||
| self.upsampler256 = Upsampler256(**self.config['upsampler256']) | |||||
| self.upsampler1024 = Upsampler1024(**self.config['upsampler1024']) | |||||
| # diffusions | |||||
| self.prior_diffusion = make_diffusion(**self.config['prior_diffusion']) | |||||
| self.decoder_diffusion = make_diffusion( | |||||
| **self.config['decoder_diffusion']) | |||||
| self.upsampler256_diffusion = make_diffusion( | |||||
| **self.config['upsampler256_diffusion']) | |||||
| self.upsampler1024_diffusion = make_diffusion( | |||||
| **self.config['upsampler1024_diffusion']) | |||||
| # tokenizers | |||||
| self.clip_tokenizer = CLIPTokenizer( | |||||
| bpe_path=f'{model_dir}/bpe_simple_vocab_16e6.txt.gz') | |||||
| self.xglm_tokenizer = XGLMTokenizer(model_dir=model_dir) | |||||
| def forward(self, *args, **kwargs): | |||||
| raise NotImplementedError( | |||||
| '"forward" is not implemented. Use "synthesis" instead.') | |||||
| @torch.no_grad() | |||||
| def synthesis(self, | |||||
| text='A photo of a confused grizzly bear in calculus class.', | |||||
| tokenizer='clip', | |||||
| batch_size=4, | |||||
| timesteps_prior=100, | |||||
| timesteps_64=50, | |||||
| timesteps_256=20, | |||||
| timesteps_1024=20, | |||||
| guide_prior=3.0, | |||||
| guide_64=7.0, | |||||
| guide_256=3.0, | |||||
| guide_1024=3.0, | |||||
| eta_prior=0.0, | |||||
| eta_64=0.0, | |||||
| eta_256=0.0, | |||||
| eta_1024=0.0): | |||||
| device = next(self.parameters()).device | |||||
| # check params | |||||
| assert all([ | |||||
| t > 0 and t <= 1000 for t in | |||||
| [timesteps_prior, timesteps_64, timesteps_256, timesteps_1024] | |||||
| ]) | |||||
| assert all([ | |||||
| g > 1 and g < 15 | |||||
| for g in [guide_prior, guide_64, guide_256, guide_1024] | |||||
| ]) | |||||
| assert all([ | |||||
| e >= 0 and e <= 1.0 | |||||
| for e in [eta_prior, eta_64, eta_256, eta_1024] | |||||
| ]) | |||||
| assert batch_size >= 1 and batch_size <= 16 | |||||
| # tokenize the text | |||||
| if tokenizer == 'clip': | |||||
| y = F.normalize( | |||||
| self.clip.textual(self.clip_tokenizer([text]).to(device)), | |||||
| p=2, | |||||
| dim=1) | |||||
| zero_y = F.normalize( | |||||
| self.clip.textual(self.clip_tokenizer(['']).to(device)), | |||||
| p=2, | |||||
| dim=1) | |||||
| elif tokenizer == 'xglm': | |||||
| y = F.normalize( | |||||
| self.xglm(*to_device(self.xglm_tokenizer([text]), device)), | |||||
| p=2, | |||||
| dim=1) | |||||
| zero_y = F.normalize( | |||||
| self.xglm(*to_device(self.xglm_tokenizer(['']), device)), | |||||
| p=2, | |||||
| dim=1) | |||||
| else: | |||||
| raise ValueError( | |||||
| f'Expected tokenizer to be one of "clip" or "xglm", but got {tokenizer}' | |||||
| ) | |||||
| y = math.sqrt(y.size(1)) * y.repeat(batch_size, 1) | |||||
| zero_y = math.sqrt(zero_y.size(1)) * zero_y.repeat(batch_size, 1) | |||||
| # synthesis | |||||
| with amp.autocast(enabled=True): | |||||
| # prior | |||||
| x0 = self.prior_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn_like(y), | |||||
| model=self.prior, | |||||
| model_kwargs=[{ | |||||
| 'y': y | |||||
| }, { | |||||
| 'y': zero_y | |||||
| }], | |||||
| guide_scale=guide_prior, | |||||
| ddim_timesteps=timesteps_prior, | |||||
| eta=eta_prior) | |||||
| # decoder | |||||
| imgs64 = self.decoder_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn(batch_size, 3, 64, 64).to(device), | |||||
| model=self.decoder, | |||||
| model_kwargs=[{ | |||||
| 'y': x0 | |||||
| }, { | |||||
| 'y': torch.zeros_like(x0) | |||||
| }], | |||||
| guide_scale=guide_64, | |||||
| percentile=0.995, | |||||
| ddim_timesteps=timesteps_64, | |||||
| eta=eta_64).clamp_(-1, 1) | |||||
| # upsampler256 | |||||
| imgs256 = F.interpolate( | |||||
| imgs64, scale_factor=4.0, mode='bilinear', align_corners=False) | |||||
| imgs256 = self.upsampler256_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn_like(imgs256), | |||||
| model=self.upsampler256, | |||||
| model_kwargs=[{ | |||||
| 'y': y, | |||||
| 'concat': imgs256 | |||||
| }, { | |||||
| 'y': zero_y, | |||||
| 'concat': imgs256 | |||||
| }], | |||||
| guide_scale=guide_256, | |||||
| percentile=0.995, | |||||
| ddim_timesteps=timesteps_256, | |||||
| eta=eta_256).clamp_(-1, 1) | |||||
| # upsampler1024 | |||||
| imgs1024 = F.interpolate( | |||||
| imgs256, | |||||
| scale_factor=4.0, | |||||
| mode='bilinear', | |||||
| align_corners=False) | |||||
| imgs1024 = self.upsampler1024_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn_like(imgs1024), | |||||
| model=self.upsampler1024, | |||||
| model_kwargs=[{ | |||||
| 'y': y, | |||||
| 'concat': imgs1024 | |||||
| }, { | |||||
| 'y': zero_y, | |||||
| 'concat': imgs1024 | |||||
| }], | |||||
| guide_scale=guide_1024, | |||||
| percentile=0.995, | |||||
| ddim_timesteps=timesteps_1024, | |||||
| eta=eta_1024).clamp_(-1, 1) | |||||
| # output ([B, C, H, W] within range [0, 1]) | |||||
| imgs1024 = imgs1024.add_(1).mul_(255 / 2.0).permute(0, 2, 3, 1).cpu() | |||||
| imgs1024 = [ | |||||
| Image.fromarray(np.array(u, dtype=np.uint8)) for u in imgs1024 | |||||
| ] | |||||
| return imgs1024 | |||||
| @MODELS.register_module( | |||||
| Tasks.text_to_image_synthesis, module_name=Models.multi_stage_diffusion) | |||||
| class MultiStageDiffusionForTextToImageSynthesis(TorchModel): | |||||
| def __init__(self, model_dir, device_id=-1): | |||||
| super().__init__(model_dir=model_dir, device_id=device_id) | |||||
| model = UnCLIP(model_dir=model_dir) | |||||
| pretrained_params = torch.load( | |||||
| osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu') | |||||
| model.load_state_dict(pretrained_params) | |||||
| model.eval() | |||||
| self.device_id = device_id | |||||
| if self.device_id >= 0: | |||||
| self.device = torch.device(f'cuda:{self.device_id}') | |||||
| model.to('cuda:{}'.format(self.device_id)) | |||||
| logger.info('Use GPU: {}'.format(self.device_id)) | |||||
| else: | |||||
| self.device = torch.device('cpu') | |||||
| logger.info('Use CPU for inference') | |||||
| self.model = model | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| if not isinstance(input, dict): | |||||
| raise ValueError( | |||||
| f'Expected the input to be a dictionary, but got {type(input)}' | |||||
| ) | |||||
| if 'text' not in input: | |||||
| raise ValueError('input should contain "text", but not found') | |||||
| # ddim sampling | |||||
| imgs = self.model.synthesis( | |||||
| text=input.get('text'), | |||||
| tokenizer=input.get('tokenizer', 'clip'), | |||||
| batch_size=input.get('batch_size', 4), | |||||
| timesteps_prior=input.get('timesteps_prior', 100), | |||||
| timesteps_64=input.get('timesteps_64', 50), | |||||
| timesteps_256=input.get('timesteps_256', 20), | |||||
| timesteps_1024=input.get('timesteps_1024', 20), | |||||
| guide_prior=input.get('guide_prior', 3.0), | |||||
| guide_64=input.get('guide_64', 7.0), | |||||
| guide_256=input.get('guide_256', 3.0), | |||||
| guide_1024=input.get('guide_1024', 3.0), | |||||
| eta_prior=input.get('eta_prior', 0.0), | |||||
| eta_64=input.get('eta_64', 0.0), | |||||
| eta_256=input.get('eta_256', 0.0), | |||||
| eta_1024=input.get('eta_1024', 0.0)) | |||||
| imgs = [np.array(u)[..., ::-1] for u in imgs] | |||||
| return imgs | |||||
| @@ -0,0 +1,170 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| __all__ = ['Prior'] | |||||
| def sinusoidal_embedding(timesteps, dim): | |||||
| # check input | |||||
| half = dim // 2 | |||||
| timesteps = timesteps.float() | |||||
| # compute sinusoidal embedding | |||||
| sinusoid = torch.outer( | |||||
| timesteps, torch.pow(10000, | |||||
| -torch.arange(half).to(timesteps).div(half))) | |||||
| x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) | |||||
| if dim % 2 != 0: | |||||
| x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) | |||||
| return x | |||||
| class SelfAttention(nn.Module): | |||||
| def __init__(self, dim, num_heads): | |||||
| assert dim % num_heads == 0 | |||||
| super(SelfAttention, self).__init__() | |||||
| self.dim = dim | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = dim // num_heads | |||||
| self.scale = math.pow(self.head_dim, -0.25) | |||||
| # layers | |||||
| self.to_qkv = nn.Linear(dim, dim * 3) | |||||
| self.proj = nn.Linear(dim, dim) | |||||
| def forward(self, x, mask): | |||||
| b, l, n, c = *x.shape[:2], self.num_heads, self.head_dim | |||||
| # compute query, key, value | |||||
| q, k, v = self.to_qkv(x).view(b, l, n * 3, c).chunk(3, dim=2) | |||||
| # compute attention | |||||
| attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale) | |||||
| if mask is not None: | |||||
| attn = attn.masked_fill(mask[:, :, :l, :l] == 0, float('-inf')) | |||||
| attn = F.softmax(attn.float(), dim=-1).type(attn.dtype) | |||||
| # gather context | |||||
| x = torch.einsum('bnij,bjnc->binc', attn, v) | |||||
| x = x.reshape(b, l, -1) | |||||
| # output | |||||
| x = self.proj(x) | |||||
| return x | |||||
| class AttentionBlock(nn.Module): | |||||
| def __init__(self, dim, num_heads): | |||||
| super(AttentionBlock, self).__init__() | |||||
| self.dim = dim | |||||
| self.num_heads = num_heads | |||||
| # layers | |||||
| self.norm1 = nn.LayerNorm(dim) | |||||
| self.attn = SelfAttention(dim, num_heads) | |||||
| self.norm2 = nn.LayerNorm(dim) | |||||
| self.ffn = nn.Sequential( | |||||
| nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)) | |||||
| def forward(self, x, mask=None): | |||||
| x = x + self.attn(self.norm1(x), mask) | |||||
| x = x + self.ffn(self.norm2(x)) | |||||
| return x | |||||
| class Prior(nn.Module): | |||||
| def __init__(self, dim=2048, clip_dim=768, num_heads=32, num_layers=24): | |||||
| super(Prior, self).__init__() | |||||
| self.dim = dim | |||||
| self.clip_dim = clip_dim | |||||
| self.num_heads = num_heads | |||||
| self.num_layers = num_layers | |||||
| # embeddings | |||||
| self.text_embedding = nn.Sequential( | |||||
| nn.Linear(clip_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) | |||||
| self.time_embedding = nn.Sequential( | |||||
| nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim)) | |||||
| self.vision_embedding = nn.Sequential( | |||||
| nn.Linear(clip_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) | |||||
| self.eos_embedding = nn.Parameter(torch.zeros(1, 1, dim)) | |||||
| self.pos_embedding = nn.Parameter(torch.zeros(1, 4, dim)) | |||||
| # transformer | |||||
| self.blocks = nn.ModuleList( | |||||
| [AttentionBlock(dim, num_heads) for _ in range(num_layers)]) | |||||
| self.norm = nn.LayerNorm(dim) | |||||
| # head | |||||
| self.head = nn.Linear(dim, clip_dim) | |||||
| # causal attention mask | |||||
| self.register_buffer('attn_mask', torch.tril(torch.ones(1, 1, 4, 4))) | |||||
| # initialize weights | |||||
| self.init_weights() | |||||
| def forward(self, x, t, y): | |||||
| r"""x: [B, C]. | |||||
| t: [B]. | |||||
| y: [B, C]. | |||||
| """ | |||||
| b = x.size(0) | |||||
| # embeddings of shape [B, L + 4, C] | |||||
| u1 = sinusoidal_embedding(t, self.dim) | |||||
| u2 = [ | |||||
| self.text_embedding(y).unsqueeze(1), | |||||
| self.time_embedding(u1).unsqueeze(1), | |||||
| self.vision_embedding(x).unsqueeze(1), | |||||
| self.eos_embedding.repeat(b, 1, 1) | |||||
| ] | |||||
| x = self.pos_embedding + torch.cat(u2, dim=1) | |||||
| # transformer | |||||
| for block in self.blocks: | |||||
| x = block(x, self.attn_mask) | |||||
| x = self.norm(x) | |||||
| # head | |||||
| x = self.head(x[:, -1]) | |||||
| return x | |||||
| def init_weights(self): | |||||
| std = 0.02 / math.sqrt(2.0 * self.num_layers) | |||||
| for name, m in self.named_modules(): | |||||
| if name.endswith('attn.proj') or name.endswith('ffn.2'): | |||||
| # smaller std for output layers | |||||
| nn.init.normal_(m.weight, std=std) | |||||
| nn.init.zeros_(m.bias) | |||||
| elif isinstance(m, (nn.Linear, nn.Embedding)): | |||||
| nn.init.normal_(m.weight, std=0.02) | |||||
| if isinstance(m, nn.Linear) and m.bias is not None: | |||||
| nn.init.zeros_(m.bias) | |||||
| elif isinstance(m, nn.LayerNorm): | |||||
| nn.init.ones_(m.weight) | |||||
| nn.init.zeros_(m.bias) | |||||
| def param_groups(self): | |||||
| groups = [{ | |||||
| 'params': [ | |||||
| p for n, p in self.named_parameters() | |||||
| if 'norm' in n or n.endswith('bias') | |||||
| ], | |||||
| 'weight_decay': | |||||
| 0.0 | |||||
| }, { | |||||
| 'params': [ | |||||
| p for n, p in self.named_parameters() | |||||
| if not ('norm' in n or n.endswith('bias')) | |||||
| ] | |||||
| }] | |||||
| return groups | |||||
| @@ -0,0 +1,199 @@ | |||||
| # The implementation here is modified based on OpenAI CLIP, publicly available at https://github.com/openai/CLIP. | |||||
| import gzip | |||||
| import html | |||||
| from functools import lru_cache | |||||
| import ftfy | |||||
| import regex as re | |||||
| import torch | |||||
| from transformers import AutoTokenizer | |||||
| __all__ = ['CLIPTokenizer', 'XGLMTokenizer'] | |||||
| @lru_cache() | |||||
| def bytes_to_unicode(): | |||||
| """ | |||||
| Returns list of utf-8 byte and a corresponding list of unicode strings. | |||||
| The reversible bpe codes work on unicode strings. | |||||
| This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||||
| When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||||
| This is a signficant percentage of your normal, say, 32K bpe vocab. | |||||
| To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||||
| And avoids mapping to whitespace/control characters the bpe code barfs on. | |||||
| """ | |||||
| bs = list(range(ord('!'), | |||||
| ord('~') + 1)) + list(range( | |||||
| ord('¡'), | |||||
| ord('¬') + 1)) + list(range(ord('®'), | |||||
| ord('ÿ') + 1)) | |||||
| cs = bs[:] | |||||
| n = 0 | |||||
| for b in range(2**8): | |||||
| if b not in bs: | |||||
| bs.append(b) | |||||
| cs.append(2**8 + n) | |||||
| n += 1 | |||||
| cs = [chr(n) for n in cs] | |||||
| return dict(zip(bs, cs)) | |||||
| def get_pairs(word): | |||||
| """Return set of symbol pairs in a word. | |||||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||||
| """ | |||||
| pairs = set() | |||||
| prev_char = word[0] | |||||
| for char in word[1:]: | |||||
| pairs.add((prev_char, char)) | |||||
| prev_char = char | |||||
| return pairs | |||||
| def basic_clean(text): | |||||
| text = ftfy.fix_text(text) | |||||
| text = html.unescape(html.unescape(text)) | |||||
| return text.strip() | |||||
| def whitespace_clean(text): | |||||
| text = re.sub(r'\s+', ' ', text) | |||||
| text = text.strip() | |||||
| return text | |||||
| class SimpleTokenizer(object): | |||||
| def __init__(self, bpe_path): | |||||
| self.byte_encoder = bytes_to_unicode() | |||||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||||
| merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') | |||||
| merges = merges[1:49152 - 256 - 2 + 1] | |||||
| merges = [tuple(merge.split()) for merge in merges] | |||||
| vocab = list(bytes_to_unicode().values()) | |||||
| vocab = vocab + [v + '</w>' for v in vocab] | |||||
| for merge in merges: | |||||
| vocab.append(''.join(merge)) | |||||
| vocab.extend(['<|startoftext|>', '<|endoftext|>']) | |||||
| self.encoder = dict(zip(vocab, range(len(vocab)))) | |||||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||||
| self.bpe_ranks = dict(zip(merges, range(len(merges)))) | |||||
| self.cache = { | |||||
| '<|startoftext|>': '<|startoftext|>', | |||||
| '<|endoftext|>': '<|endoftext|>' | |||||
| } | |||||
| self.pat = re.compile( | |||||
| r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", | |||||
| re.IGNORECASE) | |||||
| def bpe(self, token): | |||||
| if token in self.cache: | |||||
| return self.cache[token] | |||||
| word = tuple(token[:-1]) + (token[-1] + '</w>', ) | |||||
| pairs = get_pairs(word) | |||||
| if not pairs: | |||||
| return token + '</w>' | |||||
| while True: | |||||
| bigram = min( | |||||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||||
| if bigram not in self.bpe_ranks: | |||||
| break | |||||
| first, second = bigram | |||||
| new_word = [] | |||||
| i = 0 | |||||
| while i < len(word): | |||||
| try: | |||||
| j = word.index(first, i) | |||||
| new_word.extend(word[i:j]) | |||||
| i = j | |||||
| except Exception: | |||||
| new_word.extend(word[i:]) | |||||
| break | |||||
| if word[i] == first and i < len(word) - 1 and word[ | |||||
| i + 1] == second: | |||||
| new_word.append(first + second) | |||||
| i += 2 | |||||
| else: | |||||
| new_word.append(word[i]) | |||||
| i += 1 | |||||
| new_word = tuple(new_word) | |||||
| word = new_word | |||||
| if len(word) == 1: | |||||
| break | |||||
| else: | |||||
| pairs = get_pairs(word) | |||||
| word = ' '.join(word) | |||||
| self.cache[token] = word | |||||
| return word | |||||
| def encode(self, text): | |||||
| bpe_tokens = [] | |||||
| text = whitespace_clean(basic_clean(text)).lower() | |||||
| for token in re.findall(self.pat, text): | |||||
| token = ''.join(self.byte_encoder[b] | |||||
| for b in token.encode('utf-8')) | |||||
| bpe_tokens.extend(self.encoder[bpe_token] | |||||
| for bpe_token in self.bpe(token).split(' ')) | |||||
| return bpe_tokens | |||||
| def decode(self, tokens): | |||||
| text = ''.join([self.decoder[token] for token in tokens]) | |||||
| text = bytearray([self.byte_decoder[c] for c in text]).decode( | |||||
| 'utf-8', errors='replace').replace('</w>', ' ') | |||||
| return text | |||||
| class CLIPTokenizer(object): | |||||
| r"""CLIP tokenizer, adapted from https://github.com/openai/CLIP. | |||||
| """ | |||||
| def __init__(self, bpe_path, length=77): | |||||
| self.bpe_path = bpe_path | |||||
| self.length = length | |||||
| # init tokenizer | |||||
| self.tokenizer = SimpleTokenizer(bpe_path=bpe_path) | |||||
| self.sos_token = self.tokenizer.encoder['<|startoftext|>'] | |||||
| self.eos_token = self.tokenizer.encoder['<|endoftext|>'] | |||||
| self.vocab_size = len(self.tokenizer.encoder) | |||||
| def __call__(self, sequence): | |||||
| if isinstance(sequence, str): | |||||
| return torch.LongTensor(self._tokenizer(sequence)) | |||||
| elif isinstance(sequence, list): | |||||
| return torch.LongTensor([self._tokenizer(u) for u in sequence]) | |||||
| else: | |||||
| raise TypeError( | |||||
| f'Expected the "sequence" to be a string or a list, but got {type(sequence)}' | |||||
| ) | |||||
| def _tokenizer(self, text): | |||||
| tokens = self.tokenizer.encode(text)[:self.length - 2] | |||||
| tokens = [self.sos_token] + tokens + [self.eos_token] | |||||
| tokens = tokens + [0] * (self.length - len(tokens)) | |||||
| return tokens | |||||
| class XGLMTokenizer(object): | |||||
| r"""A wrapper of HuggingFace's XGLM tokenizer. | |||||
| """ | |||||
| def __init__(self, model_dir, length=77, **kwargs): | |||||
| self.length = length | |||||
| self.tokenizer = AutoTokenizer.from_pretrained(model_dir, **kwargs) | |||||
| self.vocab_size = self.tokenizer.vocab_size | |||||
| def __call__(self, sequence, **kwargs): | |||||
| _kwargs = { | |||||
| 'return_tensors': 'pt', | |||||
| 'padding': 'max_length', | |||||
| 'truncation': True, | |||||
| 'max_length': self.length | |||||
| } | |||||
| _kwargs.update(**kwargs) | |||||
| tokens = self.tokenizer(sequence, **_kwargs) | |||||
| return tokens.input_ids, tokens.attention_mask | |||||
| @@ -0,0 +1,466 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| __all__ = ['Upsampler256', 'Upsampler1024'] | |||||
| def sinusoidal_embedding(timesteps, dim): | |||||
| # check input | |||||
| half = dim // 2 | |||||
| timesteps = timesteps.float() | |||||
| # compute sinusoidal embedding | |||||
| sinusoid = torch.outer( | |||||
| timesteps, torch.pow(10000, | |||||
| -torch.arange(half).to(timesteps).div(half))) | |||||
| x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) | |||||
| if dim % 2 != 0: | |||||
| x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) | |||||
| return x | |||||
| class Resample(nn.Module): | |||||
| def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): | |||||
| assert scale_factor in [0.5, 1.0, 2.0] | |||||
| super(Resample, self).__init__() | |||||
| self.in_dim = in_dim | |||||
| self.out_dim = out_dim | |||||
| self.scale_factor = scale_factor | |||||
| self.use_conv = use_conv | |||||
| # layers | |||||
| if scale_factor == 2.0: | |||||
| self.resample = nn.Sequential( | |||||
| nn.Upsample(scale_factor=scale_factor, mode='nearest'), | |||||
| nn.Conv2d(in_dim, out_dim, 3, padding=1) | |||||
| if use_conv else nn.Identity()) | |||||
| elif scale_factor == 0.5: | |||||
| self.resample = nn.Conv2d( | |||||
| in_dim, out_dim, 3, stride=2, | |||||
| padding=1) if use_conv else nn.AvgPool2d( | |||||
| kernel_size=2, stride=2) | |||||
| else: | |||||
| self.resample = nn.Identity() | |||||
| def forward(self, x): | |||||
| return self.resample(x) | |||||
| class ResidualBlock(nn.Module): | |||||
| def __init__(self, | |||||
| in_dim, | |||||
| embed_dim, | |||||
| out_dim, | |||||
| use_scale_shift_norm=True, | |||||
| scale_factor=1.0, | |||||
| dropout=0.0): | |||||
| super(ResidualBlock, self).__init__() | |||||
| self.in_dim = in_dim | |||||
| self.embed_dim = embed_dim | |||||
| self.out_dim = out_dim | |||||
| self.use_scale_shift_norm = use_scale_shift_norm | |||||
| self.scale_factor = scale_factor | |||||
| # layers | |||||
| self.layer1 = nn.Sequential( | |||||
| nn.GroupNorm(32, in_dim), nn.SiLU(), | |||||
| nn.Conv2d(in_dim, out_dim, 3, padding=1)) | |||||
| self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False) | |||||
| self.embedding = nn.Sequential( | |||||
| nn.SiLU(), | |||||
| nn.Linear(embed_dim, | |||||
| out_dim * 2 if use_scale_shift_norm else out_dim)) | |||||
| self.layer2 = nn.Sequential( | |||||
| nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), | |||||
| nn.Conv2d(out_dim, out_dim, 3, padding=1)) | |||||
| self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( | |||||
| in_dim, out_dim, 1) | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.layer2[-1].weight) | |||||
| def forward(self, x, e): | |||||
| identity = self.resample(x) | |||||
| x = self.layer1[-1](self.resample(self.layer1[:-1](x))) | |||||
| e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) | |||||
| if self.use_scale_shift_norm: | |||||
| scale, shift = e.chunk(2, dim=1) | |||||
| x = self.layer2[0](x) * (1 + scale) + shift | |||||
| x = self.layer2[1:](x) | |||||
| else: | |||||
| x = x + e | |||||
| x = self.layer2(x) | |||||
| x = x + self.shortcut(identity) | |||||
| return x | |||||
| class AttentionBlock(nn.Module): | |||||
| def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): | |||||
| # consider head_dim first, then num_heads | |||||
| num_heads = dim // head_dim if head_dim else num_heads | |||||
| head_dim = dim // num_heads | |||||
| assert num_heads * head_dim == dim | |||||
| super(AttentionBlock, self).__init__() | |||||
| self.dim = dim | |||||
| self.context_dim = context_dim | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = head_dim | |||||
| self.scale = math.pow(head_dim, -0.25) | |||||
| # layers | |||||
| self.norm = nn.GroupNorm(32, dim) | |||||
| self.to_qkv = nn.Conv2d(dim, dim * 3, 1) | |||||
| if context_dim is not None: | |||||
| self.context_kv = nn.Linear(context_dim, dim * 2) | |||||
| self.proj = nn.Conv2d(dim, dim, 1) | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.proj.weight) | |||||
| def forward(self, x, context=None): | |||||
| r"""x: [B, C, H, W]. | |||||
| context: [B, L, C] or None. | |||||
| """ | |||||
| identity = x | |||||
| b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim | |||||
| # compute query, key, value | |||||
| x = self.norm(x) | |||||
| q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) | |||||
| if context is not None: | |||||
| ck, cv = self.context_kv(context).reshape(b, -1, n * 2, | |||||
| d).permute(0, 2, 3, | |||||
| 1).chunk( | |||||
| 2, dim=1) | |||||
| k = torch.cat([ck, k], dim=-1) | |||||
| v = torch.cat([cv, v], dim=-1) | |||||
| # compute attention | |||||
| attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) | |||||
| attn = F.softmax(attn, dim=-1) | |||||
| # gather context | |||||
| x = torch.matmul(v, attn.transpose(-1, -2)) | |||||
| x = x.reshape(b, c, h, w) | |||||
| # output | |||||
| x = self.proj(x) | |||||
| return x + identity | |||||
| class Upsampler256(nn.Module): | |||||
| def __init__(self, | |||||
| in_dim=6, | |||||
| dim=320, | |||||
| y_dim=768, | |||||
| context_dim=512, | |||||
| out_dim=3, | |||||
| dim_mult=[1, 2, 3, 4], | |||||
| num_heads=None, | |||||
| head_dim=64, | |||||
| num_res_blocks=3, | |||||
| attn_scales=[1 / 8], | |||||
| resblock_resample=True, | |||||
| use_scale_shift_norm=True, | |||||
| dropout=0.1): | |||||
| embed_dim = dim * 4 | |||||
| super(Upsampler256, self).__init__() | |||||
| self.in_dim = in_dim | |||||
| self.dim = dim | |||||
| self.y_dim = y_dim | |||||
| self.context_dim = context_dim | |||||
| self.embed_dim = embed_dim | |||||
| self.out_dim = out_dim | |||||
| self.dim_mult = dim_mult | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = head_dim | |||||
| self.num_res_blocks = num_res_blocks | |||||
| self.attn_scales = attn_scales | |||||
| self.resblock_resample = resblock_resample | |||||
| self.use_scale_shift_norm = use_scale_shift_norm | |||||
| # params | |||||
| enc_dims = [dim * u for u in [1] + dim_mult] | |||||
| dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] | |||||
| shortcut_dims = [] | |||||
| scale = 1.0 | |||||
| # embeddings | |||||
| self.time_embedding = nn.Sequential( | |||||
| nn.Linear(dim, embed_dim), nn.SiLU(), | |||||
| nn.Linear(embed_dim, embed_dim)) | |||||
| self.y_embedding = nn.Sequential( | |||||
| nn.Linear(y_dim, embed_dim), nn.SiLU(), | |||||
| nn.Linear(embed_dim, embed_dim)) | |||||
| self.context_embedding = nn.Sequential( | |||||
| nn.Linear(y_dim, embed_dim), nn.SiLU(), | |||||
| nn.Linear(embed_dim, context_dim * 4)) | |||||
| # encoder | |||||
| self.encoder = nn.ModuleList( | |||||
| [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) | |||||
| shortcut_dims.append(dim) | |||||
| for i, (in_dim, | |||||
| out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): | |||||
| for j in range(num_res_blocks): | |||||
| # residual (+attention) blocks | |||||
| block = nn.ModuleList([ | |||||
| ResidualBlock(in_dim, embed_dim, out_dim, | |||||
| use_scale_shift_norm, 1.0, dropout) | |||||
| ]) | |||||
| if scale in attn_scales: | |||||
| block.append( | |||||
| AttentionBlock(out_dim, context_dim, num_heads, | |||||
| head_dim)) | |||||
| in_dim = out_dim | |||||
| self.encoder.append(block) | |||||
| shortcut_dims.append(out_dim) | |||||
| # downsample | |||||
| if i != len(dim_mult) - 1 and j == num_res_blocks - 1: | |||||
| if resblock_resample: | |||||
| downsample = ResidualBlock(out_dim, embed_dim, out_dim, | |||||
| use_scale_shift_norm, 0.5, | |||||
| dropout) | |||||
| else: | |||||
| downsample = Resample( | |||||
| out_dim, out_dim, 0.5, use_conv=True) | |||||
| shortcut_dims.append(out_dim) | |||||
| scale /= 2.0 | |||||
| self.encoder.append(downsample) | |||||
| # middle | |||||
| self.middle = nn.ModuleList([ | |||||
| ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, | |||||
| 1.0, dropout), | |||||
| AttentionBlock(out_dim, context_dim, num_heads, head_dim), | |||||
| ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, | |||||
| 1.0, dropout) | |||||
| ]) | |||||
| # decoder | |||||
| self.decoder = nn.ModuleList() | |||||
| for i, (in_dim, | |||||
| out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): | |||||
| for j in range(num_res_blocks + 1): | |||||
| # residual (+attention) blocks | |||||
| block = nn.ModuleList([ | |||||
| ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, | |||||
| out_dim, use_scale_shift_norm, 1.0, dropout) | |||||
| ]) | |||||
| if scale in attn_scales: | |||||
| block.append( | |||||
| AttentionBlock(out_dim, context_dim, num_heads, | |||||
| head_dim)) | |||||
| in_dim = out_dim | |||||
| # upsample | |||||
| if i != len(dim_mult) - 1 and j == num_res_blocks: | |||||
| if resblock_resample: | |||||
| upsample = ResidualBlock(out_dim, embed_dim, out_dim, | |||||
| use_scale_shift_norm, 2.0, | |||||
| dropout) | |||||
| else: | |||||
| upsample = Resample( | |||||
| out_dim, out_dim, 2.0, use_conv=True) | |||||
| scale *= 2.0 | |||||
| block.append(upsample) | |||||
| self.decoder.append(block) | |||||
| # head | |||||
| self.head = nn.Sequential( | |||||
| nn.GroupNorm(32, out_dim), nn.SiLU(), | |||||
| nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.head[-1].weight) | |||||
| def forward(self, x, t, y, concat): | |||||
| # embeddings | |||||
| x = torch.cat([x, concat], dim=1) | |||||
| e = self.time_embedding(sinusoidal_embedding( | |||||
| t, self.dim)) + self.y_embedding(y) | |||||
| context = self.context_embedding(y).view(-1, 4, self.context_dim) | |||||
| # encoder | |||||
| xs = [] | |||||
| for block in self.encoder: | |||||
| x = self._forward_single(block, x, e, context) | |||||
| xs.append(x) | |||||
| # middle | |||||
| for block in self.middle: | |||||
| x = self._forward_single(block, x, e, context) | |||||
| # decoder | |||||
| for block in self.decoder: | |||||
| x = torch.cat([x, xs.pop()], dim=1) | |||||
| x = self._forward_single(block, x, e, context) | |||||
| # head | |||||
| x = self.head(x) | |||||
| return x | |||||
| def _forward_single(self, module, x, e, context): | |||||
| if isinstance(module, ResidualBlock): | |||||
| x = module(x, e) | |||||
| elif isinstance(module, AttentionBlock): | |||||
| x = module(x, context) | |||||
| elif isinstance(module, nn.ModuleList): | |||||
| for block in module: | |||||
| x = self._forward_single(block, x, e, context) | |||||
| else: | |||||
| x = module(x) | |||||
| return x | |||||
| class Upsampler1024(nn.Module): | |||||
| def __init__(self, | |||||
| in_dim=6, | |||||
| dim=192, | |||||
| y_dim=768, | |||||
| out_dim=3, | |||||
| dim_mult=[1, 1, 2, 2, 4, 4], | |||||
| num_res_blocks=2, | |||||
| resblock_resample=True, | |||||
| use_scale_shift_norm=True, | |||||
| dropout=0.0): | |||||
| embed_dim = dim * 4 | |||||
| super(Upsampler1024, self).__init__() | |||||
| self.in_dim = in_dim | |||||
| self.dim = dim | |||||
| self.y_dim = y_dim | |||||
| self.out_dim = out_dim | |||||
| self.dim_mult = dim_mult | |||||
| self.num_res_blocks = num_res_blocks | |||||
| self.resblock_resample = resblock_resample | |||||
| self.use_scale_shift_norm = use_scale_shift_norm | |||||
| # params | |||||
| enc_dims = [dim * u for u in [1] + dim_mult] | |||||
| dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] | |||||
| shortcut_dims = [] | |||||
| scale = 1.0 | |||||
| # embedding | |||||
| self.time_embedding = nn.Sequential( | |||||
| nn.Linear(dim, embed_dim), nn.SiLU(), | |||||
| nn.Linear(embed_dim, embed_dim)) | |||||
| self.y_embedding = nn.Sequential( | |||||
| nn.Linear(y_dim, embed_dim), nn.SiLU(), | |||||
| nn.Linear(embed_dim, embed_dim)) | |||||
| # encoder | |||||
| self.encoder = nn.ModuleList( | |||||
| [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) | |||||
| shortcut_dims.append(dim) | |||||
| for i, (in_dim, | |||||
| out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): | |||||
| for j in range(num_res_blocks): | |||||
| # residual block | |||||
| block = nn.ModuleList([ | |||||
| ResidualBlock(in_dim, embed_dim, out_dim, | |||||
| use_scale_shift_norm, 1.0, dropout) | |||||
| ]) | |||||
| shortcut_dims.append(out_dim) | |||||
| in_dim = out_dim | |||||
| self.encoder.append(block) | |||||
| # downsample | |||||
| if i != len(dim_mult) - 1 and j == num_res_blocks - 1: | |||||
| if resblock_resample: | |||||
| downsample = ResidualBlock(out_dim, embed_dim, out_dim, | |||||
| use_scale_shift_norm, 0.5, | |||||
| dropout) | |||||
| else: | |||||
| downsample = Resample( | |||||
| out_dim, out_dim, 0.5, use_conv=True) | |||||
| shortcut_dims.append(out_dim) | |||||
| scale /= 2.0 | |||||
| self.encoder.append(downsample) | |||||
| # middle | |||||
| self.middle = nn.ModuleList([ | |||||
| ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, | |||||
| 1.0, dropout), | |||||
| ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, | |||||
| 1.0, dropout) | |||||
| ]) | |||||
| # decoder | |||||
| self.decoder = nn.ModuleList() | |||||
| for i, (in_dim, | |||||
| out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): | |||||
| for j in range(num_res_blocks + 1): | |||||
| # residual block | |||||
| block = nn.ModuleList([ | |||||
| ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, | |||||
| out_dim, use_scale_shift_norm, 1.0, dropout) | |||||
| ]) | |||||
| in_dim = out_dim | |||||
| # upsample | |||||
| if i != len(dim_mult) - 1 and j == num_res_blocks: | |||||
| if resblock_resample: | |||||
| upsample = ResidualBlock(out_dim, embed_dim, out_dim, | |||||
| use_scale_shift_norm, 2.0, | |||||
| dropout) | |||||
| else: | |||||
| upsample = Resample( | |||||
| out_dim, out_dim, 2.0, use_conv=True) | |||||
| scale *= 2.0 | |||||
| block.append(upsample) | |||||
| self.decoder.append(block) | |||||
| # head | |||||
| self.head = nn.Sequential( | |||||
| nn.GroupNorm(32, out_dim), nn.SiLU(), | |||||
| nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.head[-1].weight) | |||||
| def forward(self, x, t, y, concat): | |||||
| # embedding | |||||
| x = torch.cat([x, concat], dim=1) | |||||
| e = self.time_embedding(sinusoidal_embedding( | |||||
| t, self.dim)) + self.y_embedding(y) | |||||
| # encoder | |||||
| xs = [] | |||||
| for block in self.encoder: | |||||
| x = self._forward_single(block, x, e) | |||||
| xs.append(x) | |||||
| # middle | |||||
| for block in self.middle: | |||||
| x = self._forward_single(block, x, e) | |||||
| # decoder | |||||
| for block in self.decoder: | |||||
| x = torch.cat([x, xs.pop()], dim=1) | |||||
| x = self._forward_single(block, x, e) | |||||
| # head | |||||
| x = self.head(x) | |||||
| return x | |||||
| def _forward_single(self, module, x, e): | |||||
| if isinstance(module, ResidualBlock): | |||||
| x = module(x, e) | |||||
| elif isinstance(module, nn.ModuleList): | |||||
| for block in module: | |||||
| x = self._forward_single(block, x, e) | |||||
| else: | |||||
| x = module(x) | |||||
| return x | |||||
| @@ -0,0 +1,205 @@ | |||||
| # The implementation here is modified based on HuggingFace XGLM, publicly available | |||||
| # at https://github.com/huggingface/transformers. | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| __all__ = ['XGLM'] | |||||
| def sinusoidal_embedding(seq_len, dim, pad_token=None): | |||||
| half = dim // 2 | |||||
| sinusoid = torch.outer( | |||||
| torch.arange(seq_len, dtype=torch.float32), | |||||
| torch.pow(10000, | |||||
| -torch.arange(half, dtype=torch.float32).div(half - 1))) | |||||
| x = torch.cat([torch.sin(sinusoid), torch.cos(sinusoid)], dim=1) | |||||
| if dim % 2 == 1: | |||||
| x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) | |||||
| if pad_token is not None: | |||||
| x[pad_token, :] = 0 | |||||
| return x | |||||
| class SinusoidalEmbedding(nn.Module): | |||||
| def __init__(self, seq_len, dim, pad_token): | |||||
| super(SinusoidalEmbedding, self).__init__() | |||||
| self.seq_len = seq_len | |||||
| self.dim = dim | |||||
| self.pad_token = pad_token | |||||
| self.register_buffer('weight', | |||||
| sinusoidal_embedding(seq_len + 2, dim, pad_token)) | |||||
| def forward(self, tokens): | |||||
| mask = tokens.ne(self.pad_token).long() | |||||
| indices = torch.cumsum(mask, dim=1) * mask + self.pad_token | |||||
| pos_embeds = self.weight.index_select(0, indices.view(-1)).view( | |||||
| *tokens.shape, -1) | |||||
| return pos_embeds | |||||
| class GELU(nn.Module): | |||||
| def forward(self, x): | |||||
| return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||||
| class SelfAttention(nn.Module): | |||||
| def __init__(self, dim, num_heads, dropout=0.1): | |||||
| assert dim % num_heads == 0 | |||||
| super(SelfAttention, self).__init__() | |||||
| self.dim = dim | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = dim // num_heads | |||||
| self.scale = 1.0 / math.sqrt(self.head_dim) | |||||
| # layers | |||||
| self.q = nn.Linear(dim, dim) | |||||
| self.k = nn.Linear(dim, dim) | |||||
| self.v = nn.Linear(dim, dim) | |||||
| self.o = nn.Linear(dim, dim) | |||||
| self.dropout = nn.Dropout(dropout) | |||||
| def forward(self, x, mask=None): | |||||
| r"""x: [B, L, C]. | |||||
| mask: [B, *, L, L] or None. | |||||
| """ | |||||
| b, l, n, c = *x.shape[:2], self.num_heads, self.head_dim | |||||
| # compute query, key, value | |||||
| q = self.q(x).view(b, l, n, c) | |||||
| k = self.k(x).view(b, l, n, c) | |||||
| v = self.v(x).view(b, l, n, c) | |||||
| # compute attention | |||||
| attn = self.scale * torch.einsum('binc,bjnc->bnij', q, k) | |||||
| if mask is not None: | |||||
| attn = attn.masked_fill(mask == 0, float('-inf')) | |||||
| attn = F.softmax(attn, dim=-1) | |||||
| attn = self.dropout(attn) | |||||
| # gather context | |||||
| x = torch.einsum('bnij,bjnc->binc', attn, v) | |||||
| x = x.reshape(b, l, -1) | |||||
| # output | |||||
| x = self.o(x) | |||||
| x = self.dropout(x) | |||||
| return x | |||||
| class AttentionBlock(nn.Module): | |||||
| def __init__(self, dim, ffn_dim, ffn_act, num_heads, dropout=0.1): | |||||
| assert ffn_act in ['gelu', 'relu'] | |||||
| super(AttentionBlock, self).__init__() | |||||
| self.dim = dim | |||||
| self.ffn_dim = ffn_dim | |||||
| self.ffn_act = ffn_act | |||||
| self.num_heads = num_heads | |||||
| # layers | |||||
| self.norm1 = nn.LayerNorm(dim) | |||||
| self.attn = SelfAttention(dim, num_heads, dropout) | |||||
| self.norm2 = nn.LayerNorm(dim) | |||||
| self.ffn = nn.Sequential( | |||||
| nn.Linear(dim, ffn_dim), | |||||
| GELU() if ffn_act == 'gelu' else nn.ReLU(inplace=True), | |||||
| nn.Linear(ffn_dim, dim), nn.Dropout(dropout)) | |||||
| def forward(self, x, mask=None): | |||||
| x = x + self.attn(self.norm1(x), mask) | |||||
| x = x + self.ffn(self.norm2(x)) | |||||
| return x | |||||
| class XGLM(nn.Module): | |||||
| r"""A multilingual GPT model with an embedding head. | |||||
| """ | |||||
| def __init__(self, | |||||
| vocab_size=256008, | |||||
| max_seq_len=2048, | |||||
| dim=1024, | |||||
| ffn_dim=4096, | |||||
| ffn_act='gelu', | |||||
| embed_dim=768, | |||||
| num_heads=16, | |||||
| num_layers=24, | |||||
| pad_token=1, | |||||
| dropout=0.1): | |||||
| super(XGLM, self).__init__() | |||||
| self.vocab_size = vocab_size | |||||
| self.max_seq_len = max_seq_len | |||||
| self.dim = dim | |||||
| self.ffn_dim = ffn_dim | |||||
| self.ffn_act = ffn_act | |||||
| self.embed_dim = embed_dim | |||||
| self.num_heads = num_heads | |||||
| self.num_layers = num_layers | |||||
| self.pad_token = pad_token | |||||
| self.scale = math.sqrt(dim) # rescale token embedings | |||||
| # layers | |||||
| self.token_embedding = nn.Embedding(vocab_size, dim, pad_token) | |||||
| self.pos_embedding = SinusoidalEmbedding(max_seq_len, dim, pad_token) | |||||
| self.eos_embedding = nn.Parameter(torch.randn(1, 1, dim)) | |||||
| self.dropout = nn.Dropout(dropout) | |||||
| self.blocks = nn.ModuleList([ | |||||
| AttentionBlock(dim, ffn_dim, ffn_act, num_heads, dropout) | |||||
| for _ in range(num_layers) | |||||
| ]) | |||||
| self.norm = nn.LayerNorm(dim) | |||||
| self.head = nn.Linear(dim, embed_dim, bias=False) | |||||
| # causal attention mask | |||||
| self.register_buffer( | |||||
| 'attn_mask', | |||||
| torch.tril(torch.ones(1, 1, 1 + max_seq_len, 1 + max_seq_len))) | |||||
| # init weights | |||||
| self.apply(self.init_weights) | |||||
| def forward(self, tokens, mask=None): | |||||
| r"""tokens: [B, L]. | |||||
| mask: [B, L]. | |||||
| """ | |||||
| b, seq_len = tokens.size(0), 1 + tokens.size(1) | |||||
| # embeddings | |||||
| x = self.scale * self.token_embedding(tokens) | |||||
| x = torch.cat([x, self.eos_embedding.repeat(b, 1, 1)], dim=1) | |||||
| # x = x + self.pos_embedding(tokens) | |||||
| x = self.dropout(x) | |||||
| # attention mask | |||||
| if mask is None: | |||||
| mask = self.attn_mask[:, :, :seq_len, :seq_len].repeat(b, 1, 1, 1) | |||||
| else: | |||||
| mask = self.attn_mask[:, :, :seq_len, :seq_len] * torch.cat( | |||||
| [mask, torch.zeros_like(mask[:, :1])], dim=1).view( | |||||
| b, 1, 1, seq_len) | |||||
| # transformer | |||||
| for block in self.blocks: | |||||
| x = block(x, mask) | |||||
| x = self.norm(x) | |||||
| # head | |||||
| logits = self.head(x[:, -1]) | |||||
| return logits | |||||
| def init_weights(self, m): | |||||
| if isinstance(m, nn.Linear): | |||||
| nn.init.normal_(m.weight, std=0.02) | |||||
| if m.bias is not None: | |||||
| nn.init.zeros_(m.bias) | |||||
| elif isinstance(m, nn.Embedding): | |||||
| nn.init.normal_(m.weight, std=0.02) | |||||
| if m.padding_idx is not None: | |||||
| nn.init.zeros_(m.weight[m.padding_idx]) | |||||
| @@ -3,7 +3,8 @@ from typing import Any, Dict, Optional | |||||
| import torch | import torch | ||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.multi_modal import OfaForTextToImageSynthesis | |||||
| from modelscope.models.multi_modal import ( | |||||
| MultiStageDiffusionForTextToImageSynthesis, OfaForTextToImageSynthesis) | |||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.pipelines.base import Input, Model, Pipeline | from modelscope.pipelines.base import Input, Model, Pipeline | ||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| @@ -48,7 +49,9 @@ class TextToImageSynthesisPipeline(Pipeline): | |||||
| return input | return input | ||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | ||||
| if isinstance(self.model, OfaForTextToImageSynthesis): | |||||
| if isinstance(self.model, | |||||
| (OfaForTextToImageSynthesis, | |||||
| MultiStageDiffusionForTextToImageSynthesis)): | |||||
| return self.model(input) | return self.model(input) | ||||
| return self.model.generate(input) | return self.model.generate(input) | ||||
| @@ -0,0 +1,40 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| import numpy as np | |||||
| import torch | |||||
| from modelscope.models import Model | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class MultiStageDiffusionTest(unittest.TestCase): | |||||
| model_id = 'damo/cv_diffusion_text-to-image-synthesis' | |||||
| test_text = {'text': 'Photograph of a baby chicken wearing sunglasses'} | |||||
| @unittest.skip( | |||||
| 'skip test since the pretrained model is not publicly available') | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| pipe_line_text_to_image_synthesis = pipeline( | |||||
| task=Tasks.text_to_image_synthesis, model=model) | |||||
| img = pipe_line_text_to_image_synthesis( | |||||
| self.test_text)[OutputKeys.OUTPUT_IMG] | |||||
| print(np.sum(np.abs(img))) | |||||
| @unittest.skip( | |||||
| 'skip test since the pretrained model is not publicly available') | |||||
| def test_run_with_model_name(self): | |||||
| pipe_line_text_to_image_synthesis = pipeline( | |||||
| task=Tasks.text_to_image_synthesis, model=self.model_id) | |||||
| img = pipe_line_text_to_image_synthesis( | |||||
| self.test_text)[OutputKeys.OUTPUT_IMG] | |||||
| print(np.sum(np.abs(img))) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||