|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- # Copyright 2021 The KubeEdge Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from __future__ import absolute_import
-
- import torch
- from torch import nn
- from torch.nn import functional as F
- from torch.nn import init
- from torch.autograd import Variable
- from torchvision.models import resnet50, resnet34
- import math
- import os
- import numpy as np
- from .MetaModules import *
-
-
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 reduce -> 3x3 -> 1x1 expand) built from Meta* layers.

    The 3x3 conv carries the stride; output channels are ``planes * expansion``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 reduce, 3x3 spatial (strided), 1x1 expand by `expansion`.
        self.conv1 = MetaConv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = MetaBatchNorm2d(planes)
        self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=stride,
                                padding=1, bias=False)
        self.bn2 = MetaBatchNorm2d(planes)
        self.conv3 = MetaConv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = MetaBatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Projection shortcut only when shape/channels change (downsample given).
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += shortcut
        return self.relu(y)
-
class BasicBlock(nn.Module):
    """ResNet basic block (two 3x3 convs) built from Meta* layers.

    The first conv (and the optional downsample path) carries the stride.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # conv1 downsamples when stride != 1; conv2 always keeps resolution.
        self.conv1 = MetaConv2d(inplanes, planes, kernel_size=3, stride=stride,
                                padding=1, bias=False)
        self.bn1 = MetaBatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1,
                                padding=1, bias=False)
        self.bn2 = MetaBatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut is the raw input unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        y += shortcut
        return self.relu(y)
-
class MetaResNetBase(MetaModule):
    """ResNet trunk (7x7 stem + 4 residual stages) built from Meta* layers.

    Produces the raw layer4 feature map; pooling and heads live in the caller.

    Args:
        layers: number of blocks per stage, e.g. [3, 4, 6, 3] for ResNet-50.
        block: residual block class (Bottleneck or BasicBlock).
    """

    def __init__(self, layers, block=Bottleneck):
        super(MetaResNetBase, self).__init__()
        self.inplanes = 64
        self.conv1 = MetaConv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = MetaBatchNorm2d(64)
        # ReLU is parameter-free, so adding it does not change the state_dict
        # and stays compatible with weights copied from torchvision's resnet.
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage: a (possibly strided/projected) block + identity blocks."""
        downsample = None
        # Projection shortcut when the stage changes resolution or width.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                MetaConv2d(self.inplanes, planes * block.expansion,
                           kernel_size=1, stride=stride, bias=False),
                MetaBatchNorm2d(planes * block.expansion),
            )

        layers = [
            block(self.inplanes, planes, stride, downsample)
        ]

        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, MTE=False):
        # `MTE` is accepted for interface compatibility but unused here.
        x = self.conv1(x)
        x = self.bn1(x)
        # Bug fix: standard ResNet applies ReLU between the stem BN and maxpool
        # (conv -> bn -> relu -> maxpool); the activation was missing here.
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x
-
-
class MetaResNet(MetaModule):
    """ResNet-50-shaped feature extractor with a (mixup) BN neck.

    Wraps a MetaResNetBase trunk, applies global average pooling, an optional
    embedding layer, and a MixUpBatchNorm1d "BNNeck". The forward contract
    changes with mode: in eval it returns L2-normalized BN features; in
    training it may return a list of features (when feat_bn yields one) and,
    with BNNeck=True, also the pre-BN features for a triplet-style loss.
    """

    def __init_with_imagenet(self, baseModel):
        # Build a torchvision resnet50 purely as a weight template and copy its
        # state dict into the Meta* trunk.
        # NOTE(review): pretrained=False, so despite this method's name no
        # ImageNet weights are actually loaded -- confirm this is intended.
        model = resnet50(pretrained=False)
        # Drop the classifier so only trunk weights are copied.
        del model.fc
        baseModel.copyWeight(model.state_dict())

    def getBase(self):
        # [3, 4, 6, 3] Bottleneck stages == ResNet-50 layout.
        baseModel = MetaResNetBase([3, 4, 6, 3])
        self.__init_with_imagenet(baseModel)
        return baseModel

    def __init__(self, num_features=0, dropout=0, cut_at_pooling=False, norm=True, num_classes=[0,0,0], BNNeck=False):
        # num_features: embedding size; 0 means use raw 2048-dim pooled features.
        # num_classes: three class counts -- presumably one per training domain;
        #   stored but not used in this file (verify against callers).
        # NOTE(review): mutable default num_classes=[0,0,0] is only read here,
        # never mutated, so the shared-default pitfall does not bite.
        super(MetaResNet, self).__init__()
        self.num_features = num_features
        self.dropout = dropout
        self.cut_at_pooling = cut_at_pooling
        self.num_classes1 = num_classes[0]
        self.num_classes2 = num_classes[1]
        self.num_classes3 = num_classes[2]
        self.has_embedding = num_features > 0
        self.norm = norm
        self.BNNeck = BNNeck
        if self.dropout > 0:
            self.drop = nn.Dropout(self.dropout)
        # Construct base (pretrained) resnet
        self.base = self.getBase()
        # Set the last stage to stride 1 (conv2 of the first layer4 block and
        # its projection shortcut), keeping a larger final feature map.
        self.base.layer4[0].conv2.stride = (1, 1)
        self.base.layer4[0].downsample[0].stride = (1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        out_planes = 2048  # Bottleneck output width: 512 * expansion(4)
        if self.has_embedding:
            self.feat = MetaLinear(out_planes, self.num_features)
            init.kaiming_normal_(self.feat.weight, mode='fan_out')
            init.constant_(self.feat.bias, 0)
        else:
            # Change the num_features to CNN output channels
            self.num_features = out_planes

        # BNNeck: mixup-capable 1D batch norm over the final features.
        self.feat_bn = MixUpBatchNorm1d(self.num_features)
        init.constant_(self.feat_bn.weight, 1)
        init.constant_(self.feat_bn.bias, 0)

    def forward(self, x, MTE='', save_index=0):
        # MTE / save_index are forwarded to feat_bn only on the no-embedding
        # path; their semantics live in MixUpBatchNorm1d (see MetaModules).
        x= self.base(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1)  # flatten pooled map to (batch, channels)

        if self.cut_at_pooling:
            # Early exit: raw pooled features, no BN/normalization.
            return x

        if self.has_embedding:
            bn_x = self.feat_bn(self.feat(x))
        else:
            bn_x = self.feat_bn(x, MTE, save_index)
        tri_features = x  # pre-BN features, used for the BNNeck return

        if self.training is False:
            # Eval: single L2-normalized feature vector.
            bn_x = F.normalize(bn_x)
            return bn_x

        # Training: feat_bn may return a list (e.g. mixup branches); normalize
        # each element independently when self.norm is set.
        if isinstance(bn_x, list):
            output = []
            for bnfeature in bn_x:
                if self.norm:
                    bnfeature = F.normalize(bnfeature)
                output.append(bnfeature)
            if self.BNNeck:
                return output, tri_features
            else:
                return output

        if self.norm:
            bn_x = F.normalize(bn_x)
        elif self.has_embedding:
            # Only ReLU the embedding when not normalizing.
            bn_x = F.relu(bn_x)

        if self.dropout > 0:
            bn_x = self.drop(bn_x)

        if self.BNNeck:
            return bn_x, tri_features
        else:
            return bn_x
-
-
|