
MetaModules.py 17 kB

# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import math

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
from torch.autograd import Variable
def to_var(x, requires_grad=True):
    # Wrap a tensor as a (possibly CUDA) Variable that can receive gradients.
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad)
class MetaModule(nn.Module):
    # adopted from: Adrien Ecoffet https://github.com/AdrienLE
    # Weights of the Meta* layers are stored as buffers (exposed via named_leaves),
    # so set_param can overwrite them with graph-connected tensors during meta-updates.

    def params(self):
        for name, param in self.named_params(self):
            yield param

    def named_leaves(self):
        return []

    def named_submodules(self):
        return []

    def named_params(self, curr_module=None, memo=None, prefix=''):
        if memo is None:
            memo = set()

        if hasattr(curr_module, 'named_leaves'):
            for name, p in curr_module.named_leaves():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p
        else:
            for name, p in curr_module._parameters.items():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p

        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in self.named_params(module, memo, submodule_prefix):
                yield name, p
    def update_params(self, lr_inner, source_params=None,
                      solver='sgd', beta1=0.9, beta2=0.999, weight_decay=5e-4):
        if solver == 'sgd':
            for tgt, src in zip(self.named_params(self), source_params):
                name_t, param_t = tgt
                grad = src if src is not None else 0
                tmp = param_t - lr_inner * grad
                self.set_param(self, name_t, tmp)
        elif solver == 'adam':
            for tgt, gradVal in zip(self.named_params(self), source_params):
                name_t, param_t = tgt
                exp_avg, exp_avg_sq = torch.zeros_like(param_t.data), \
                                      torch.zeros_like(param_t.data)
                bias_correction1 = 1 - beta1
                bias_correction2 = 1 - beta2
                gradVal.add_(param_t, alpha=weight_decay)
                exp_avg.mul_(beta1).add_(gradVal, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(gradVal, gradVal, value=1 - beta2)
                exp_avg_sq.add_(1e-8)  # to avoid possible nan in backward
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(1e-8)
                step_size = lr_inner / bias_correction1
                newParam = param_t.addcdiv(exp_avg, denom, value=-step_size)
                self.set_param(self, name_t, newParam)
    def setParams(self, params):
        for tgt, param in zip(self.named_params(self), params):
            name_t, _ = tgt
            self.set_param(self, name_t, param)

    def set_param(self, curr_mod, name, param):
        if '.' in name:
            n = name.split('.')
            module_name = n[0]
            rest = '.'.join(n[1:])
            for name, mod in curr_mod.named_children():
                if module_name == name:
                    self.set_param(mod, rest, param)
                    break
        else:
            setattr(curr_mod, name, param)

    def setBN(self, inPart, name, param):
        if '.' in name:
            part = name.split('.')
            self.setBN(getattr(inPart, part[0]), '.'.join(part[1:]), param)
        else:
            setattr(inPart, name, param)
    def detach_params(self):
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())

    def copyModel(self, newModel, same_var=False):
        # copy meta model to meta model
        tarName = list(map(lambda v: v, newModel.state_dict().keys()))

        # requires_grad
        partName, partW = list(map(lambda v: v[0], newModel.named_params(newModel))), list(
            map(lambda v: v[1], newModel.named_params(newModel)))  # new model's weight

        metaName, metaW = list(map(lambda v: v[0], self.named_params(self))), list(
            map(lambda v: v[1], self.named_params(self)))
        bnNames = list(set(tarName) - set(partName))

        # copy vars
        for name, param in zip(metaName, partW):
            if not same_var:
                param = to_var(param.data.clone(), requires_grad=True)
            self.set_param(self, name, param)

        # copy training mean var
        tarName = newModel.state_dict()
        for name in bnNames:
            param = to_var(tarName[name], requires_grad=False)
            self.setBN(self, name, param)
    def copyWeight(self, modelW):
        # copy state_dict to buffers
        curName = list(map(lambda v: v[0], self.named_params(self)))
        tarNames = set()
        for name in modelW.keys():
            # print(name)
            if name.startswith("module"):
                tarNames.add(".".join(name.split(".")[1:]))
            else:
                tarNames.add(name)

        # bnNames = list(tarNames - set(curName))
        for tgt in self.named_params(self):
            name_t, param_t = tgt
            # print(name_t)
            module_name_t = 'module.' + name_t
            if name_t in modelW:
                param = to_var(modelW[name_t], requires_grad=True)
                self.set_param(self, name_t, param)
            elif module_name_t in modelW:
                param = to_var(modelW['module.' + name_t], requires_grad=True)
                self.set_param(self, name_t, param)
            else:
                continue
    def load_param(self, path):
        modelW = torch.load(path, map_location=torch.device('cpu'))['state_dict']
        print("=> Loaded M3L ReID model '{}'".format(path))

        # copy state_dict to buffers
        curName = list(map(lambda v: v[0], self.named_params(self)))
        tarNames = set()
        for name in modelW.keys():
            # print(name)
            if name.startswith("module"):
                tarNames.add(".".join(name.split(".")[1:]))
            else:
                tarNames.add(name)

        # for a BN-based resMeta, bnNames only contains the running mean/var buffers
        bnNames = list(tarNames - set(curName))
        for tgt in self.named_params(self):
            name_t, param_t = tgt
            # print(name_t)
            module_name_t = 'module.' + name_t
            if name_t in modelW:
                param = to_var(modelW[name_t], requires_grad=True)
                self.set_param(self, name_t, param)
            elif module_name_t in modelW:
                param = to_var(modelW['module.' + name_t], requires_grad=True)
                self.set_param(self, name_t, param)
            else:
                continue

        for name in bnNames:
            try:
                param = to_var(modelW[name], requires_grad=False)
            except KeyError:
                param = to_var(modelW['module.' + name], requires_grad=False)
            self.setBN(self, name, param)
class MetaConv2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Conv2d(*args, **kwargs)

        self.in_channels = ignore.in_channels
        self.out_channels = ignore.out_channels
        self.stride = ignore.stride
        self.padding = ignore.padding
        self.dilation = ignore.dilation
        self.groups = ignore.groups
        self.kernel_size = ignore.kernel_size

        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))

        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('bias', None)

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]
class MetaBatchNorm2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.BatchNorm2d(*args, **kwargs)

        self.num_features = ignore.num_features
        self.eps = ignore.eps
        self.momentum = ignore.momentum
        self.affine = ignore.affine
        self.track_running_stats = ignore.track_running_stats

        if self.affine:
            self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_var', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.LongTensor([0]).squeeze())
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)

    def forward(self, x):
        val2 = self.weight.sum()
        res = F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                           self.training or not self.track_running_stats, self.momentum, self.eps)
        return res

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]
class MetaBatchNorm1d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.BatchNorm1d(*args, **kwargs)

        self.num_features = ignore.num_features
        self.eps = ignore.eps
        self.momentum = ignore.momentum
        self.affine = ignore.affine
        self.track_running_stats = ignore.track_running_stats

        if self.affine:
            self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_var', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.LongTensor([0]).squeeze())
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)

    def forward(self, x):
        # during meta-test, the training flag below
        # ("self.training or not self.track_running_stats") is expected to be False
        return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                            self.training or not self.track_running_stats, self.momentum, self.eps)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]
class MetaInstanceNorm2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.InstanceNorm2d(*args, **kwargs)

        self.num_features = ignore.num_features
        self.eps = ignore.eps
        self.momentum = ignore.momentum
        self.affine = ignore.affine
        self.track_running_stats = ignore.track_running_stats

        if self.affine:
            self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('weight', None)
            self.register_buffer('bias', None)

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_var', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.LongTensor([0]).squeeze())
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.affine:
            init.constant_(self.weight, 1)
            init.constant_(self.bias, 0)

    def forward(self, x):
        res = F.instance_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                              self.training or not self.track_running_stats, self.momentum, self.eps)
        return res

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]
class MixUpBatchNorm1d(MetaBatchNorm1d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MixUpBatchNorm1d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

        self.register_buffer('meta_mean1', torch.zeros(self.num_features))
        self.register_buffer('meta_var1', torch.zeros(self.num_features))
        self.register_buffer('meta_mean2', torch.zeros(self.num_features))
        self.register_buffer('meta_var2', torch.zeros(self.num_features))
        self.device_count = torch.cuda.device_count()

    def forward(self, input, MTE='', save_index=0):
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            if MTE == 'sample':
                from torch.distributions.normal import Normal
                Distri1 = Normal(self.meta_mean1, self.meta_var1)
                Distri2 = Normal(self.meta_mean2, self.meta_var2)
                sample1 = Distri1.sample([input.size(0), ])
                sample2 = Distri2.sample([input.size(0), ])
                lam = np.random.beta(1., 1.)
                inputmix1 = lam * sample1 + (1 - lam) * input
                inputmix2 = lam * sample2 + (1 - lam) * input

                mean1 = inputmix1.mean(dim=0)
                var1 = inputmix1.var(dim=0, unbiased=False)
                mean2 = inputmix2.mean(dim=0)
                var2 = inputmix2.var(dim=0, unbiased=False)

                output1 = (inputmix1 - mean1[None, :]) / (torch.sqrt(var1[None, :] + self.eps))
                output2 = (inputmix2 - mean2[None, :]) / (torch.sqrt(var2[None, :] + self.eps))
                if self.affine:
                    output1 = output1 * self.weight[None, :] + self.bias[None, :]
                    output2 = output2 * self.weight[None, :] + self.bias[None, :]
                return [output1, output2]
            else:
                mean = input.mean(dim=0)
                # use biased var in train
                var = input.var(dim=0, unbiased=False)
                n = input.numel() / input.size(1)
                with torch.no_grad():
                    running_mean = exponential_average_factor * mean \
                        + (1 - exponential_average_factor) * self.running_mean
                    # update running_var with unbiased var
                    running_var = exponential_average_factor * var * n / (n - 1) \
                        + (1 - exponential_average_factor) * self.running_var
                    self.running_mean.copy_(running_mean)
                    self.running_var.copy_(running_var)

                    if save_index == 1:
                        self.meta_mean1.copy_(mean)
                        self.meta_var1.copy_(var)
                    elif save_index == 2:
                        self.meta_mean2.copy_(mean)
                        self.meta_var2.copy_(var)
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :]) / (torch.sqrt(var[None, :] + self.eps))
        if self.affine:
            input = input * self.weight[None, :] + self.bias[None, :]

        return input
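
For context, the sketch below shows one way these meta layers might be driven through a single first-order inner-loop update: gradients are taken with respect to the buffered weights returned by params(), and update_params writes the updated tensors back through set_param. The ToyMetaNet class, its layer sizes, and the random batch are hypothetical placeholders for illustration, not part of the original file.

import torch
import torch.nn.functional as F

class ToyMetaNet(MetaModule):
    # Hypothetical example network assembled from the meta layers above.
    def __init__(self, num_classes=2):
        super().__init__()
        self.conv = MetaConv2d(3, num_classes, kernel_size=3, padding=1)
        self.bn = MetaBatchNorm2d(num_classes)

    def forward(self, x):
        x = F.relu(self.bn(self.conv(x)))
        return F.adaptive_avg_pool2d(x, 1).flatten(1)  # (N, num_classes) logits

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # to_var puts the buffered weights on CUDA when available
net = ToyMetaNet()
images = torch.randn(4, 3, 32, 32, device=device)
labels = torch.randint(0, 2, (4,), device=device)

# One inner-loop step: differentiate the loss w.r.t. the buffered weights,
# then let update_params write the SGD-updated tensors back via set_param.
loss = F.cross_entropy(net(images), labels)
grads = torch.autograd.grad(loss, list(net.params()), create_graph=True)
net.update_params(lr_inner=0.01, source_params=grads, solver='sgd')

Because the updated weights remain connected to the graph (create_graph=True), a subsequent meta-loss computed with the updated network can still backpropagate to the original weights, which is the purpose of storing them as replaceable buffers rather than nn.Parameter objects.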