# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
from torch.autograd import Variable
import math
import numpy as np


def to_var(x, requires_grad=True):
    """Move a tensor to GPU when available and wrap it for autograd.

    ``Variable`` is a thin alias for ``Tensor`` in current PyTorch; it is kept
    here for compatibility with the original code.
    """
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad)


class MetaModule(nn.Module):
    # Adapted from Adrien Ecoffet, https://github.com/AdrienLE
    def params(self):
        for name, param in self.named_params(self):
            yield param

    def named_leaves(self):
        return []

    def named_submodules(self):
        return []

    def named_params(self, curr_module=None, memo=None, prefix=''):
        """Recursively yield (name, tensor) pairs for all learnable leaves."""
        if memo is None:
            memo = set()

        if hasattr(curr_module, 'named_leaves'):
            for name, p in curr_module.named_leaves():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p
        else:
            for name, p in curr_module._parameters.items():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p

        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in self.named_params(module, memo, submodule_prefix):
                yield name, p

    def update_params(self, lr_inner, source_params=None,
                      solver='sgd', beta1=0.9, beta2=0.999, weight_decay=5e-4):
        """Apply one differentiable inner-loop update using supplied gradients."""
        if solver == 'sgd':
            for tgt, src in zip(self.named_params(self), source_params):
                name_t, param_t = tgt
                grad = src if src is not None else 0
                tmp = param_t - lr_inner * grad
                self.set_param(self, name_t, tmp)
        elif solver == 'adam':
            # Single-step Adam-style update: the first and second moments are
            # re-initialised to zero on every call rather than accumulated.
            for tgt, gradVal in zip(self.named_params(self), source_params):
                name_t, param_t = tgt
                exp_avg = torch.zeros_like(param_t.data)
                exp_avg_sq = torch.zeros_like(param_t.data)
                bias_correction1 = 1 - beta1
                bias_correction2 = 1 - beta2
                gradVal.add_(param_t, alpha=weight_decay)
                exp_avg.mul_(beta1).add_(gradVal, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(gradVal, gradVal, value=1 - beta2)
                exp_avg_sq.add_(1e-8)  # avoid a possible NaN in backward
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(1e-8)
                step_size = lr_inner / bias_correction1
                newParam = param_t.addcdiv(exp_avg, denom, value=-step_size)
                self.set_param(self, name_t, newParam)

    def setParams(self, params):
        for tgt, param in zip(self.named_params(self), params):
            name_t, _ = tgt
            self.set_param(self, name_t, param)

    def set_param(self, curr_mod, name, param):
        """Walk the module tree along a dotted name and overwrite the leaf tensor."""
        if '.' in name:
            n = name.split('.')
            module_name = n[0]
            rest = '.'.join(n[1:])
            for child_name, mod in curr_mod.named_children():
                if module_name == child_name:
                    self.set_param(mod, rest, param)
                    break
        else:
            setattr(curr_mod, name, param)

    def setBN(self, inPart, name, param):
        """Overwrite a (possibly nested) buffer such as a BN running statistic."""
        if '.' in name:
            part = name.split('.')
            self.setBN(getattr(inPart, part[0]), '.'.join(part[1:]), param)
        else:
            setattr(inPart, name, param)

    def detach_params(self):
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())

    def copyModel(self, newModel, same_var=False):
        # Copy one meta model into another: learnable leaves are copied as
        # autograd variables, the remaining state_dict entries (BN running
        # statistics) are copied as plain buffers.
        tarName = list(newModel.state_dict().keys())

        # learnable leaves of the source model
        partName = [n for n, _ in newModel.named_params(newModel)]
        partW = [p for _, p in newModel.named_params(newModel)]

        metaName = [n for n, _ in self.named_params(self)]
        bnNames = list(set(tarName) - set(partName))

        # copy learnable variables
        for name, param in zip(metaName, partW):
            if not same_var:
                param = to_var(param.data.clone(), requires_grad=True)
            self.set_param(self, name, param)

        # copy running mean/var and other buffers
        tarName = newModel.state_dict()
        for name in bnNames:
            param = to_var(tarName[name], requires_grad=False)
            self.setBN(self, name, param)

    def copyWeight(self, modelW):
        # Copy a plain state_dict into the meta buffers; keys saved from a
        # DataParallel model carry a leading "module." prefix, which is
        # handled transparently below.
        curName = [n for n, _ in self.named_params(self)]
        tarNames = set()
        for name in modelW.keys():
            if name.startswith("module"):
                tarNames.add(".".join(name.split(".")[1:]))
            else:
                tarNames.add(name)
        # bnNames = list(tarNames - set(curName))
        for name_t, param_t in self.named_params(self):
            module_name_t = 'module.' + name_t
            if name_t in modelW:
                param = to_var(modelW[name_t], requires_grad=True)
                self.set_param(self, name_t, param)
            elif module_name_t in modelW:
                param = to_var(modelW[module_name_t], requires_grad=True)
                self.set_param(self, name_t, param)

    def load_param(self, path):
        modelW = torch.load(path, map_location=torch.device('cpu'))['state_dict']
        print("=> Loaded M3L ReID model '{}'".format(path))

        # copy the checkpoint state_dict into the meta buffers
        curName = [n for n, _ in self.named_params(self)]
        tarNames = set()
        for name in modelW.keys():
            if name.startswith("module"):
                tarNames.add(".".join(name.split(".")[1:]))
            else:
                tarNames.add(name)
        # in resMeta, bnNames only contains the BN running mean/var
        bnNames = list(tarNames - set(curName))

        for name_t, param_t in self.named_params(self):
            module_name_t = 'module.' + name_t
            if name_t in modelW:
                param = to_var(modelW[name_t], requires_grad=True)
                self.set_param(self, name_t, param)
            elif module_name_t in modelW:
                param = to_var(modelW[module_name_t], requires_grad=True)
                self.set_param(self, name_t, param)

        for name in bnNames:
            try:
                param = to_var(modelW[name], requires_grad=False)
            except KeyError:
                param = to_var(modelW['module.' + name], requires_grad=False)
            self.setBN(self, name, param)


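# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original class: loading an ordinary
# checkpoint into a MetaModule via copyWeight.  The `meta_net` and
# `checkpoint_path` names are placeholders; any concrete MetaModule subclass
# (such as the layers defined below) can be used.
# ---------------------------------------------------------------------------
def _demo_copy_weight(meta_net, checkpoint_path):
    state = torch.load(checkpoint_path, map_location='cpu')['state_dict']
    # copyWeight transparently strips the "module." prefix that DataParallel
    # checkpoints carry, so the same call works for both kinds of checkpoint.
    meta_net.copyWeight(state)

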
class MetaConv2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Conv2d(*args, **kwargs)
        self.in_channels = ignore.in_channels
        self.out_channels = ignore.out_channels
        self.stride = ignore.stride
        self.padding = ignore.padding
        self.dilation = ignore.dilation
        self.groups = ignore.groups
        self.kernel_size = ignore.kernel_size

        # Weights are registered as buffers (not nn.Parameters) so that
        # set_param can later replace them with arbitrary autograd tensors.
        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))

        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('bias', None)

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]


class MetaBatchNorm2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.BatchNorm2d(*args, **kwargs)

        self.num_features = ignore.num_features
        self.eps = ignore.eps
        self.momentum = ignore.momentum
        self.affine = ignore.affine
        self.track_running_stats = ignore.track_running_stats

        if self.affine:
            self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_var', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.LongTensor([0]).squeeze())
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)

    def forward(self, x):
        return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                            self.training or not self.track_running_stats, self.momentum, self.eps)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]


class MetaBatchNorm1d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.BatchNorm1d(*args, **kwargs)

        self.num_features = ignore.num_features
        self.eps = ignore.eps
        self.momentum = ignore.momentum
        self.affine = ignore.affine
        self.track_running_stats = ignore.track_running_stats

        if self.affine:
            self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_var', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.LongTensor([0]).squeeze())
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)

    def forward(self, x):
        # Note: at meta-test time the training flag below
        # (`self.training or not self.track_running_stats`) should evaluate to
        # False so that the running statistics are used.
        return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                            self.training or not self.track_running_stats, self.momentum, self.eps)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]


class MetaInstanceNorm2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.InstanceNorm2d(*args, **kwargs)

        self.num_features = ignore.num_features
        self.eps = ignore.eps
        self.momentum = ignore.momentum
        self.affine = ignore.affine
        self.track_running_stats = ignore.track_running_stats

        if self.affine:
            self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('weight', None)
            self.register_buffer('bias', None)

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_var', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.LongTensor([0]).squeeze())
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.affine:
            init.constant_(self.weight, 1)
            init.constant_(self.bias, 0)

    def forward(self, x):
        return F.instance_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                               self.training or not self.track_running_stats, self.momentum, self.eps)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]


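# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a minimal MAML-style
# inner step built from the layers above.  The toy network, input shape and
# learning rate are placeholders chosen only for demonstration.
# ---------------------------------------------------------------------------
class _ToyMetaNet(MetaModule):
    def __init__(self):
        super().__init__()
        self.conv = MetaConv2d(3, 8, kernel_size=3, padding=1)
        self.bn = MetaBatchNorm2d(8)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))


def _demo_inner_update():
    net = _ToyMetaNet()
    x = to_var(torch.randn(4, 3, 16, 16), requires_grad=False)
    loss = net(x).mean()
    # create_graph=True keeps the inner step differentiable, so a later
    # meta-loss can backpropagate through the updated weights.
    grads = torch.autograd.grad(loss, list(net.params()), create_graph=True)
    net.update_params(lr_inner=0.01, source_params=grads, solver='sgd')
    return net

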
class MixUpBatchNorm1d(MetaBatchNorm1d):
    """BatchNorm1d variant that additionally stores two sets of batch
    statistics (one per source domain, selected via ``save_index``) and, when
    ``MTE == 'sample'``, mixes statistics sampled from them into the current
    batch."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MixUpBatchNorm1d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

        self.register_buffer('meta_mean1', torch.zeros(self.num_features))
        self.register_buffer('meta_var1', torch.zeros(self.num_features))
        self.register_buffer('meta_mean2', torch.zeros(self.num_features))
        self.register_buffer('meta_var2', torch.zeros(self.num_features))
        self.device_count = torch.cuda.device_count()

    def forward(self, input, MTE='', save_index=0):
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            if MTE == 'sample':
                from torch.distributions.normal import Normal
                # As written, the stored variances (not standard deviations)
                # are used as the scale of the sampling distributions.
                Distri1 = Normal(self.meta_mean1, self.meta_var1)
                Distri2 = Normal(self.meta_mean2, self.meta_var2)
                sample1 = Distri1.sample([input.size(0), ])
                sample2 = Distri2.sample([input.size(0), ])
                lam = np.random.beta(1., 1.)
                inputmix1 = lam * sample1 + (1 - lam) * input
                inputmix2 = lam * sample2 + (1 - lam) * input

                mean1 = inputmix1.mean(dim=0)
                var1 = inputmix1.var(dim=0, unbiased=False)
                mean2 = inputmix2.mean(dim=0)
                var2 = inputmix2.var(dim=0, unbiased=False)

                output1 = (inputmix1 - mean1[None, :]) / torch.sqrt(var1[None, :] + self.eps)
                output2 = (inputmix2 - mean2[None, :]) / torch.sqrt(var2[None, :] + self.eps)
                if self.affine:
                    output1 = output1 * self.weight[None, :] + self.bias[None, :]
                    output2 = output2 * self.weight[None, :] + self.bias[None, :]
                return [output1, output2]

            else:
                mean = input.mean(dim=0)
                # use biased var in train
                var = input.var(dim=0, unbiased=False)
                n = input.numel() / input.size(1)

                with torch.no_grad():
                    running_mean = exponential_average_factor * mean \
                        + (1 - exponential_average_factor) * self.running_mean
                    # update running_var with unbiased var
                    running_var = exponential_average_factor * var * n / (n - 1) \
                        + (1 - exponential_average_factor) * self.running_var
                    self.running_mean.copy_(running_mean)
                    self.running_var.copy_(running_var)
                    # remember this batch's statistics as the per-domain meta stats
                    if save_index == 1:
                        self.meta_mean1.copy_(mean)
                        self.meta_var1.copy_(var)
                    elif save_index == 2:
                        self.meta_mean2.copy_(mean)
                        self.meta_var2.copy_(var)

        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :]) / torch.sqrt(var[None, :] + self.eps)
        if self.affine:
            input = input * self.weight[None, :] + self.bias[None, :]

        return input
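

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: the expected calling
# pattern for MixUpBatchNorm1d.  Feature dimension and batch sizes are
# placeholder values.
# ---------------------------------------------------------------------------
def _demo_mixup_bn():
    bn = MixUpBatchNorm1d(128)
    bn.train()
    # two passes over features from two different source domains populate the
    # meta statistics (meta_mean1/var1 and meta_mean2/var2)
    bn(torch.randn(32, 128), save_index=1)
    bn(torch.randn(32, 128), save_index=2)
    # with MTE='sample', statistics sampled from the stored distributions are
    # mixed into the current batch and two normalised outputs are returned
    out1, out2 = bn(torch.randn(32, 128), MTE='sample')
    return out1, out2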