
vit_seg_modeling.py

# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

from . import vit_seg_configs as configs
from .vit_seg_modeling_resnet_skip import ResNetV2


logger = logging.getLogger(__name__)


ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
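
# Shape walk-through for Attention.forward (added note, not in the original file):
# hidden_states (B, N, hidden) is projected to Q/K/V and split into num_heads heads
# of size hidden / num_heads, giving (B, num_heads, N, head_size); the score matrix
# Q @ K^T / sqrt(head_size) is (B, num_heads, N, N); the re-merged context and the
# projected output are (B, N, hidden). `weights` is only returned when vis=True.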


class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:   # ResNet
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features
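
# Worked example for Embeddings (added note; assumes the standard ViT-B_16 values
# from vit_seg_configs): with img_size=224 and a 16x16 patch, the Conv2d produces a
# 14x14 grid, i.e. n_patches = 196 tokens of dimension hidden_size. In the hybrid
# (R50) branch with grid (14, 14), patch_size works out to (1, 1), so the patch
# embedding becomes a 1x1 convolution over the ResNetV2 feature map.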


class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
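
# Note (added): Block is a pre-norm transformer layer: LayerNorm is applied before
# the attention and MLP sub-blocks and the residual is added after each. load_from
# maps the original ViT (JAX) checkpoint tensors onto the PyTorch parameters; the
# dense kernels are stored as (in, out) there, hence the .t() transposes.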


class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights


class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)
        return encoded, attn_weights, features


class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm2d(out_channels)

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x
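
# Note (added): each DecoderBlock upsamples its input by 2x with bilinear
# interpolation, optionally concatenates a same-resolution skip feature map along
# the channel dimension, then applies two 3x3 Conv-BN-ReLU layers.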


class SegmentationHead(nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderCup(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            skip_channels = self.config.skip_channels
            for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, hidden, h, w)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x
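
# Note (added): DecoderCup.forward folds the token sequence (B, n_patch, hidden)
# back into a 2D map (B, hidden, h, w) with h = w = sqrt(n_patch), so n_patch must
# be a perfect square (e.g. 196 -> 14x14); conv_more then reduces hidden_size to
# 512 channels before the cascade of DecoderBlocks.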


class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits

    def load_from(self, weights):
        with torch.no_grad():
            res_weight = weights
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])

            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1] - 1 == posemb_new.size()[1]:
                posemb = posemb[:, 1:]
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder whole
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)
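
# Note (added): when the pretrained position-embedding grid does not match the
# current n_patches, load_from drops the class-token embedding and resizes the
# grid with scipy.ndimage.zoom (order=1, i.e. linear interpolation) before copying
# it into position_embeddings.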


CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}
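
# Illustrative usage sketch (added; not part of the original file). The package
# path and the config values below are examples only -- n_classes, n_skip and
# patches.grid are normally set by the training script for the chosen dataset:
#
#   from networks.vit_seg_modeling import VisionTransformer, CONFIGS
#   import torch
#
#   config = CONFIGS['R50-ViT-B_16']
#   config.n_classes = 9                  # dataset-dependent
#   config.n_skip = 3                     # how many ResNet skip connections to use
#   config.patches.grid = (14, 14)        # 224 / 16 patches per side for img_size=224
#
#   model = VisionTransformer(config, img_size=224, num_classes=config.n_classes)
#   x = torch.randn(1, 1, 224, 224)       # grayscale input; forward() repeats it to 3 channels
#   logits = model(x)                     # shape (1, config.n_classes, 224, 224)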