# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

from . import vit_seg_configs as configs
from .vit_seg_modeling_resnet_skip import ResNetV2


logger = logging.getLogger(__name__)


ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        # (B, n_patches, hidden) -> (B, num_heads, n_patches, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # scaled dot-product attention
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
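# Shape sketch for Attention (comments only; assumes the ViT-B_16 config defined
# at the bottom of this file: hidden_size=768 and 12 heads, so head_size=64):
#
#   attn = Attention(CONFIGS['ViT-B_16'], vis=True)
#   x = torch.zeros(2, 196, 768)   # (B, n_patches, hidden); 196 = (224 / 16) ** 2
#   out, weights = attn(x)         # out: (2, 196, 768)
#                                  # weights: (2, 12, 196, 196), one map per head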
class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Embeddings(nn.Module):
    """Construct the embeddings from patch and position embeddings."""
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:  # ResNet hybrid backbone
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features


class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
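# Block is a standard pre-norm transformer layer; forward() above computes
#
#   y = x + Attention(LayerNorm(x))
#   z = y + Mlp(LayerNorm(y))
#
# and returns z together with the (optional) attention weights for visualization.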
class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights


class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)
        return encoded, attn_weights, features


class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm2d(out_channels)

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SegmentationHead(nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        up = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, up)
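# Decoder pipeline sketch for DecoderCup below (comments only; numbers assume
# the R50-ViT-B_16 config with a 224x224 input, i.e. a 14x14 token grid and
# decoder_channels=(256, 128, 64, 16)):
#
#   tokens (B, 196, 768) -> reshape -> (B, 768, 14, 14) -> conv_more -> (B, 512, 14, 14)
#   then four DecoderBlocks, each upsampling 2x and concatenating a skip where available:
#   (B, 256, 28, 28) -> (B, 128, 56, 56) -> (B, 64, 112, 112) -> (B, 16, 224, 224)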
class DecoderCup(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            skip_channels = self.config.skip_channels
            for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, hidden, h, w)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x


class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        if x.size()[1] == 1:  # repeat a single-channel input across RGB
            x = x.repeat(1, 3, 1, 1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits

    def load_from(self, weights):
        with torch.no_grad():
            res_weight = weights
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])

            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1] - 1 == posemb_new.size()[1]:
                posemb = posemb[:, 1:]  # drop the class token
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np: scipy returns a numpy array
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder whole
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)
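# Usage sketch for load_from (comments only; the checkpoint path below is an
# assumption -- point it at wherever the official pre-trained .npz file lives):
#
#   model = VisionTransformer(CONFIGS['ViT-B_16'], img_size=224)
#   weights = np.load('path/to/ViT-B_16.npz')
#   model.load_from(weights)
#
# np.load on the official checkpoints returns a dict-like NpzFile keyed by the
# Flax parameter paths used above (e.g. "Transformer/encoderblock_0/...").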
CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}
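

if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming the configs shipped with this
    # package (run via `python -m <package>.vit_seg_modeling` so the relative
    # imports resolve). Builds the hybrid R50-ViT-B_16 variant for 224x224
    # inputs and pushes a random tensor through it; no pretrained weights are
    # loaded, so the output is untrained logits.
    config = CONFIGS['R50-ViT-B_16']
    config.patches.grid = (14, 14)  # 224 / 16 = 14 patch positions per side
    config.n_classes = 9            # illustrative class count; set per dataset
    model = VisionTransformer(config, img_size=224, num_classes=config.n_classes)
    x = torch.randn(1, 1, 224, 224)  # single-channel input; repeated to 3 channels in forward()
    logits = model(x)
    print(logits.shape)  # expected: torch.Size([1, 9, 224, 224])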