You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resMeta.py 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import torch
  16. from torch import nn
  17. from torch.nn import functional as F
  18. from torch.nn import init
  19. from torch.autograd import Variable
  20. from torchvision.models import resnet50, resnet34
  21. import math
  22. import os
  23. import numpy as np
  24. from .MetaModules import *
  25. class Bottleneck(nn.Module):
  26. expansion = 4
  27. def __init__(self, inplanes, planes, stride=1, downsample=None):
  28. super(Bottleneck, self).__init__()
  29. self.conv1 = MetaConv2d(inplanes, planes, kernel_size=1, bias=False)
  30. self.bn1 = MetaBatchNorm2d(planes)
  31. self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
  32. self.bn2 = MetaBatchNorm2d(planes)
  33. self.conv3 = MetaConv2d(planes, planes * 4, kernel_size=1, bias=False)
  34. self.bn3 = MetaBatchNorm2d(planes * 4)
  35. self.relu = nn.ReLU(inplace=True)
  36. self.downsample = downsample
  37. self.stride = stride
  38. def forward(self, x):
  39. residual = x
  40. out = self.conv1(x)
  41. out = self.bn1(out)
  42. out = self.relu(out)
  43. out = self.conv2(out)
  44. out = self.bn2(out)
  45. out = self.relu(out)
  46. out = self.conv3(out)
  47. out = self.bn3(out)
  48. if self.downsample is not None:
  49. residual = self.downsample(x)
  50. out += residual
  51. out = self.relu(out)
  52. return out
  53. class BasicBlock(nn.Module):
  54. expansion = 1
  55. def __init__(self, inplanes, planes, stride=1, downsample=None):
  56. super(BasicBlock, self).__init__()
  57. # Both self.conv1 and self.downsample layers downsample the input when stride != 1
  58. self.conv1 = MetaConv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
  59. self.bn1 = MetaBatchNorm2d(planes)
  60. self.relu = nn.ReLU(inplace=True)
  61. self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
  62. self.bn2 = MetaBatchNorm2d(planes)
  63. self.downsample = downsample
  64. self.stride = stride
  65. def forward(self, x):
  66. identity = x
  67. out = self.conv1(x)
  68. out = self.bn1(out)
  69. out = self.relu(out)
  70. out = self.conv2(out)
  71. out = self.bn2(out)
  72. if self.downsample is not None:
  73. identity = self.downsample(x)
  74. out += identity
  75. out = self.relu(out)
  76. return out
  77. class MetaResNetBase(MetaModule):
  78. def __init__(self, layers, block=Bottleneck):
  79. super(MetaResNetBase, self).__init__()
  80. self.inplanes = 64
  81. self.conv1 = MetaConv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
  82. self.bn1 = MetaBatchNorm2d(64)
  83. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  84. self.layer1 = self._make_layer(block, 64, layers[0])
  85. self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  86. self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  87. self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  88. def _make_layer(self, block, planes, blocks, stride=1):
  89. downsample = None
  90. if stride != 1 or self.inplanes != planes * block.expansion:
  91. downsample = nn.Sequential(
  92. MetaConv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
  93. MetaBatchNorm2d(planes * block.expansion),
  94. )
  95. layers = [
  96. block(self.inplanes, planes, stride, downsample)
  97. ]
  98. self.inplanes = planes * block.expansion
  99. for i in range(1, blocks):
  100. layers.append(block(self.inplanes, planes))
  101. return nn.Sequential(*layers)
  102. def forward(self, x, MTE=False):
  103. x = self.conv1(x)
  104. x = self.bn1(x)
  105. x = self.maxpool(x)
  106. x = self.layer1(x)
  107. x = self.layer2(x)
  108. x = self.layer3(x)
  109. x = self.layer4(x)
  110. return x
  111. class MetaResNet(MetaModule):
  112. def __init_with_imagenet(self, baseModel):
  113. model = resnet50(pretrained=False)
  114. del model.fc
  115. baseModel.copyWeight(model.state_dict())
  116. def getBase(self):
  117. baseModel = MetaResNetBase([3, 4, 6, 3])
  118. self.__init_with_imagenet(baseModel)
  119. return baseModel
  120. def __init__(self, num_features=0, dropout=0, cut_at_pooling=False, norm=True, num_classes=[0,0,0], BNNeck=False):
  121. super(MetaResNet, self).__init__()
  122. self.num_features = num_features
  123. self.dropout = dropout
  124. self.cut_at_pooling = cut_at_pooling
  125. self.num_classes1 = num_classes[0]
  126. self.num_classes2 = num_classes[1]
  127. self.num_classes3 = num_classes[2]
  128. self.has_embedding = num_features > 0
  129. self.norm = norm
  130. self.BNNeck = BNNeck
  131. if self.dropout > 0:
  132. self.drop = nn.Dropout(self.dropout)
  133. # Construct base (pretrained) resnet
  134. self.base = self.getBase()
  135. self.base.layer4[0].conv2.stride = (1, 1)
  136. self.base.layer4[0].downsample[0].stride = (1, 1)
  137. self.gap = nn.AdaptiveAvgPool2d(1)
  138. out_planes = 2048
  139. if self.has_embedding:
  140. self.feat = MetaLinear(out_planes, self.num_features)
  141. init.kaiming_normal_(self.feat.weight, mode='fan_out')
  142. init.constant_(self.feat.bias, 0)
  143. else:
  144. # Change the num_features to CNN output channels
  145. self.num_features = out_planes
  146. self.feat_bn = MixUpBatchNorm1d(self.num_features)
  147. init.constant_(self.feat_bn.weight, 1)
  148. init.constant_(self.feat_bn.bias, 0)
  149. def forward(self, x, MTE='', save_index=0):
  150. x= self.base(x)
  151. x = self.gap(x)
  152. x = x.view(x.size(0), -1)
  153. if self.cut_at_pooling:
  154. return x
  155. if self.has_embedding:
  156. bn_x = self.feat_bn(self.feat(x))
  157. else:
  158. bn_x = self.feat_bn(x, MTE, save_index)
  159. tri_features = x
  160. if self.training is False:
  161. bn_x = F.normalize(bn_x)
  162. return bn_x
  163. if isinstance(bn_x, list):
  164. output = []
  165. for bnfeature in bn_x:
  166. if self.norm:
  167. bnfeature = F.normalize(bnfeature)
  168. output.append(bnfeature)
  169. if self.BNNeck:
  170. return output, tri_features
  171. else:
  172. return output
  173. if self.norm:
  174. bn_x = F.normalize(bn_x)
  175. elif self.has_embedding:
  176. bn_x = F.relu(bn_x)
  177. if self.dropout > 0:
  178. bn_x = self.drop(bn_x)
  179. if self.BNNeck:
  180. return bn_x, tri_features
  181. else:
  182. return bn_x