You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

reparameterization.ipynb 29 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "id": "d7cbe5ee",
  6. "metadata": {},
  7. "source": [
  8. "# Reparameterization"
  9. ]
  10. },
  11. {
  12. "cell_type": "markdown",
  13. "id": "13393b70",
  14. "metadata": {},
  15. "source": [
  16. "## YOLOv7 reparameterization"
  17. ]
  18. },
  19. {
  20. "cell_type": "code",
  21. "execution_count": null,
  22. "id": "bf53becf",
  23. "metadata": {},
  24. "outputs": [],
  25. "source": [
  26. "# import\n",
  27. "from copy import deepcopy\n",
  28. "from models.yolo import Model\n",
  29. "import torch\n",
  30. "from utils.torch_utils import select_device, is_parallel\n",
  31. "\n",
  32. "device = select_device('0', batch_size=1)\n",
  33. "# model trained by cfg/training/*.yaml\n",
  34. "ckpt = torch.load('cfg/training/yolov7.pt', map_location=device)\n",
  35. "# reparameterized model in cfg/deploy/*.yaml\n",
  36. "model = Model('cfg/deploy/yolov7.yaml', ch=3, nc=80).to(device)\n",
  37. "\n",
  38. "# copy intersect weights\n",
  39. "state_dict = ckpt['model'].float().state_dict()\n",
  40. "exclude = []\n",
  41. "intersect_state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict() and not any(x in k for x in exclude) and v.shape == model.state_dict()[k].shape}\n",
  42. "model.load_state_dict(intersect_state_dict, strict=False)\n",
  43. "model.names = ckpt['model'].names\n",
  44. "model.nc = ckpt['model'].nc\n",
  45. "\n",
  46. "# reparametrized YOLOR\n",
  47. "for i in range(255):\n",
  48. " model.state_dict()['model.105.m.0.weight'].data[i, :, :, :] *= state_dict['model.105.im.0.implicit'].data[:, i, : :].squeeze()\n",
  49. " model.state_dict()['model.105.m.1.weight'].data[i, :, :, :] *= state_dict['model.105.im.1.implicit'].data[:, i, : :].squeeze()\n",
  50. " model.state_dict()['model.105.m.2.weight'].data[i, :, :, :] *= state_dict['model.105.im.2.implicit'].data[:, i, : :].squeeze()\n",
  51. "model.state_dict()['model.105.m.0.bias'].data += state_dict['model.105.m.0.weight'].mul(state_dict['model.105.ia.0.implicit']).sum(1).squeeze()\n",
  52. "model.state_dict()['model.105.m.1.bias'].data += state_dict['model.105.m.1.weight'].mul(state_dict['model.105.ia.1.implicit']).sum(1).squeeze()\n",
  53. "model.state_dict()['model.105.m.2.bias'].data += state_dict['model.105.m.2.weight'].mul(state_dict['model.105.ia.2.implicit']).sum(1).squeeze()\n",
  54. "model.state_dict()['model.105.m.0.bias'].data *= state_dict['model.105.im.0.implicit'].data.squeeze()\n",
  55. "model.state_dict()['model.105.m.1.bias'].data *= state_dict['model.105.im.1.implicit'].data.squeeze()\n",
  56. "model.state_dict()['model.105.m.2.bias'].data *= state_dict['model.105.im.2.implicit'].data.squeeze()\n",
  57. "\n",
  58. "# model to be saved\n",
  59. "ckpt = {'model': deepcopy(model.module if is_parallel(model) else model).half(),\n",
  60. " 'optimizer': None,\n",
  61. " 'training_results': None,\n",
  62. " 'epoch': -1}\n",
  63. "\n",
  64. "# save reparameterized model\n",
  65. "torch.save(ckpt, 'cfg/deploy/yolov7.pt')\n"
  66. ]
  67. },
  68. {
  69. "cell_type": "markdown",
  70. "id": "5b396a53",
  71. "metadata": {},
  72. "source": [
  73. "## YOLOv7x reparameterization"
  74. ]
  75. },
  76. {
  77. "cell_type": "code",
  78. "execution_count": null,
  79. "id": "9d54d17f",
  80. "metadata": {},
  81. "outputs": [],
  82. "source": [
  83. "# import\n",
  84. "from copy import deepcopy\n",
  85. "from models.yolo import Model\n",
  86. "import torch\n",
  87. "from utils.torch_utils import select_device, is_parallel\n",
  88. "\n",
  89. "device = select_device('0', batch_size=1)\n",
  90. "# model trained by cfg/training/*.yaml\n",
  91. "ckpt = torch.load('cfg/training/yolov7x.pt', map_location=device)\n",
  92. "# reparameterized model in cfg/deploy/*.yaml\n",
  93. "model = Model('cfg/deploy/yolov7x.yaml', ch=3, nc=80).to(device)\n",
  94. "\n",
  95. "# copy intersect weights\n",
  96. "state_dict = ckpt['model'].float().state_dict()\n",
  97. "exclude = []\n",
  98. "intersect_state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict() and not any(x in k for x in exclude) and v.shape == model.state_dict()[k].shape}\n",
  99. "model.load_state_dict(intersect_state_dict, strict=False)\n",
  100. "model.names = ckpt['model'].names\n",
  101. "model.nc = ckpt['model'].nc\n",
  102. "\n",
  103. "# reparametrized YOLOR\n",
  104. "for i in range(255):\n",
  105. " model.state_dict()['model.121.m.0.weight'].data[i, :, :, :] *= state_dict['model.121.im.0.implicit'].data[:, i, : :].squeeze()\n",
  106. " model.state_dict()['model.121.m.1.weight'].data[i, :, :, :] *= state_dict['model.121.im.1.implicit'].data[:, i, : :].squeeze()\n",
  107. " model.state_dict()['model.121.m.2.weight'].data[i, :, :, :] *= state_dict['model.121.im.2.implicit'].data[:, i, : :].squeeze()\n",
  108. "model.state_dict()['model.121.m.0.bias'].data += state_dict['model.121.m.0.weight'].mul(state_dict['model.121.ia.0.implicit']).sum(1).squeeze()\n",
  109. "model.state_dict()['model.121.m.1.bias'].data += state_dict['model.121.m.1.weight'].mul(state_dict['model.121.ia.1.implicit']).sum(1).squeeze()\n",
  110. "model.state_dict()['model.121.m.2.bias'].data += state_dict['model.121.m.2.weight'].mul(state_dict['model.121.ia.2.implicit']).sum(1).squeeze()\n",
  111. "model.state_dict()['model.121.m.0.bias'].data *= state_dict['model.121.im.0.implicit'].data.squeeze()\n",
  112. "model.state_dict()['model.121.m.1.bias'].data *= state_dict['model.121.im.1.implicit'].data.squeeze()\n",
  113. "model.state_dict()['model.121.m.2.bias'].data *= state_dict['model.121.im.2.implicit'].data.squeeze()\n",
  114. "\n",
  115. "# model to be saved\n",
  116. "ckpt = {'model': deepcopy(model.module if is_parallel(model) else model).half(),\n",
  117. " 'optimizer': None,\n",
  118. " 'training_results': None,\n",
  119. " 'epoch': -1}\n",
  120. "\n",
  121. "# save reparameterized model\n",
  122. "torch.save(ckpt, 'cfg/deploy/yolov7x.pt')\n"
  123. ]
  124. },
  125. {
  126. "cell_type": "markdown",
  127. "id": "11a9108e",
  128. "metadata": {},
  129. "source": [
  130. "## YOLOv7-W6 reparameterization"
  131. ]
  132. },
  133. {
  134. "cell_type": "code",
  135. "execution_count": null,
  136. "id": "d032c629",
  137. "metadata": {},
  138. "outputs": [],
  139. "source": [
  140. "# import\n",
  141. "from copy import deepcopy\n",
  142. "from models.yolo import Model\n",
  143. "import torch\n",
  144. "from utils.torch_utils import select_device, is_parallel\n",
  145. "\n",
  146. "device = select_device('0', batch_size=1)\n",
  147. "# model trained by cfg/training/*.yaml\n",
  148. "ckpt = torch.load('cfg/training/yolov7-w6.pt', map_location=device)\n",
  149. "# reparameterized model in cfg/deploy/*.yaml\n",
  150. "model = Model('cfg/deploy/yolov7-w6.yaml', ch=3, nc=80).to(device)\n",
  151. "\n",
  152. "# copy intersect weights\n",
  153. "state_dict = ckpt['model'].float().state_dict()\n",
  154. "exclude = []\n",
  155. "intersect_state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict() and not any(x in k for x in exclude) and v.shape == model.state_dict()[k].shape}\n",
  156. "model.load_state_dict(intersect_state_dict, strict=False)\n",
  157. "model.names = ckpt['model'].names\n",
  158. "model.nc = ckpt['model'].nc\n",
  159. "\n",
  160. "idx = 118\n",
  161. "idx2 = 122\n",
  162. "\n",
  163. "# copy weights of lead head\n",
  164. "model.state_dict()['model.{}.m.0.weight'.format(idx)].data -= model.state_dict()['model.{}.m.0.weight'.format(idx)].data\n",
  165. "model.state_dict()['model.{}.m.1.weight'.format(idx)].data -= model.state_dict()['model.{}.m.1.weight'.format(idx)].data\n",
  166. "model.state_dict()['model.{}.m.2.weight'.format(idx)].data -= model.state_dict()['model.{}.m.2.weight'.format(idx)].data\n",
  167. "model.state_dict()['model.{}.m.3.weight'.format(idx)].data -= model.state_dict()['model.{}.m.3.weight'.format(idx)].data\n",
  168. "model.state_dict()['model.{}.m.0.weight'.format(idx)].data += state_dict['model.{}.m.0.weight'.format(idx2)].data\n",
  169. "model.state_dict()['model.{}.m.1.weight'.format(idx)].data += state_dict['model.{}.m.1.weight'.format(idx2)].data\n",
  170. "model.state_dict()['model.{}.m.2.weight'.format(idx)].data += state_dict['model.{}.m.2.weight'.format(idx2)].data\n",
  171. "model.state_dict()['model.{}.m.3.weight'.format(idx)].data += state_dict['model.{}.m.3.weight'.format(idx2)].data\n",
  172. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data -= model.state_dict()['model.{}.m.0.bias'.format(idx)].data\n",
  173. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data -= model.state_dict()['model.{}.m.1.bias'.format(idx)].data\n",
  174. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data -= model.state_dict()['model.{}.m.2.bias'.format(idx)].data\n",
  175. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data -= model.state_dict()['model.{}.m.3.bias'.format(idx)].data\n",
  176. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data += state_dict['model.{}.m.0.bias'.format(idx2)].data\n",
  177. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data += state_dict['model.{}.m.1.bias'.format(idx2)].data\n",
  178. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data += state_dict['model.{}.m.2.bias'.format(idx2)].data\n",
  179. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
  180. "\n",
  181. "# reparametrized YOLOR\n",
  182. "for i in range(255):\n",
  183. " model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  184. " model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  185. " model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  186. " model.state_dict()['model.{}.m.3.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.3.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  187. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data += state_dict['model.{}.m.0.weight'.format(idx2)].mul(state_dict['model.{}.ia.0.implicit'.format(idx2)]).sum(1).squeeze()\n",
  188. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data += state_dict['model.{}.m.1.weight'.format(idx2)].mul(state_dict['model.{}.ia.1.implicit'.format(idx2)]).sum(1).squeeze()\n",
  189. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data += state_dict['model.{}.m.2.weight'.format(idx2)].mul(state_dict['model.{}.ia.2.implicit'.format(idx2)]).sum(1).squeeze()\n",
  190. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.weight'.format(idx2)].mul(state_dict['model.{}.ia.3.implicit'.format(idx2)]).sum(1).squeeze()\n",
  191. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data *= state_dict['model.{}.im.0.implicit'.format(idx2)].data.squeeze()\n",
  192. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data *= state_dict['model.{}.im.1.implicit'.format(idx2)].data.squeeze()\n",
  193. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data *= state_dict['model.{}.im.2.implicit'.format(idx2)].data.squeeze()\n",
  194. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data *= state_dict['model.{}.im.3.implicit'.format(idx2)].data.squeeze()\n",
  195. "\n",
  196. "# model to be saved\n",
  197. "ckpt = {'model': deepcopy(model.module if is_parallel(model) else model).half(),\n",
  198. " 'optimizer': None,\n",
  199. " 'training_results': None,\n",
  200. " 'epoch': -1}\n",
  201. "\n",
  202. "# save reparameterized model\n",
  203. "torch.save(ckpt, 'cfg/deploy/yolov7-w6.pt')\n"
  204. ]
  205. },
  206. {
  207. "cell_type": "markdown",
  208. "id": "5f093d43",
  209. "metadata": {},
  210. "source": [
  211. "## YOLOv7-E6 reparameterization"
  212. ]
  213. },
  214. {
  215. "cell_type": "code",
  216. "execution_count": null,
  217. "id": "aa2b2142",
  218. "metadata": {},
  219. "outputs": [],
  220. "source": [
  221. "# import\n",
  222. "from copy import deepcopy\n",
  223. "from models.yolo import Model\n",
  224. "import torch\n",
  225. "from utils.torch_utils import select_device, is_parallel\n",
  226. "\n",
  227. "device = select_device('0', batch_size=1)\n",
  228. "# model trained by cfg/training/*.yaml\n",
  229. "ckpt = torch.load('cfg/training/yolov7-e6.pt', map_location=device)\n",
  230. "# reparameterized model in cfg/deploy/*.yaml\n",
  231. "model = Model('cfg/deploy/yolov7-e6.yaml', ch=3, nc=80).to(device)\n",
  232. "\n",
  233. "# copy intersect weights\n",
  234. "state_dict = ckpt['model'].float().state_dict()\n",
  235. "exclude = []\n",
  236. "intersect_state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict() and not any(x in k for x in exclude) and v.shape == model.state_dict()[k].shape}\n",
  237. "model.load_state_dict(intersect_state_dict, strict=False)\n",
  238. "model.names = ckpt['model'].names\n",
  239. "model.nc = ckpt['model'].nc\n",
  240. "\n",
  241. "idx = 140\n",
  242. "idx2 = 144\n",
  243. "\n",
  244. "# copy weights of lead head\n",
  245. "model.state_dict()['model.{}.m.0.weight'.format(idx)].data -= model.state_dict()['model.{}.m.0.weight'.format(idx)].data\n",
  246. "model.state_dict()['model.{}.m.1.weight'.format(idx)].data -= model.state_dict()['model.{}.m.1.weight'.format(idx)].data\n",
  247. "model.state_dict()['model.{}.m.2.weight'.format(idx)].data -= model.state_dict()['model.{}.m.2.weight'.format(idx)].data\n",
  248. "model.state_dict()['model.{}.m.3.weight'.format(idx)].data -= model.state_dict()['model.{}.m.3.weight'.format(idx)].data\n",
  249. "model.state_dict()['model.{}.m.0.weight'.format(idx)].data += state_dict['model.{}.m.0.weight'.format(idx2)].data\n",
  250. "model.state_dict()['model.{}.m.1.weight'.format(idx)].data += state_dict['model.{}.m.1.weight'.format(idx2)].data\n",
  251. "model.state_dict()['model.{}.m.2.weight'.format(idx)].data += state_dict['model.{}.m.2.weight'.format(idx2)].data\n",
  252. "model.state_dict()['model.{}.m.3.weight'.format(idx)].data += state_dict['model.{}.m.3.weight'.format(idx2)].data\n",
  253. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data -= model.state_dict()['model.{}.m.0.bias'.format(idx)].data\n",
  254. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data -= model.state_dict()['model.{}.m.1.bias'.format(idx)].data\n",
  255. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data -= model.state_dict()['model.{}.m.2.bias'.format(idx)].data\n",
  256. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data -= model.state_dict()['model.{}.m.3.bias'.format(idx)].data\n",
  257. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data += state_dict['model.{}.m.0.bias'.format(idx2)].data\n",
  258. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data += state_dict['model.{}.m.1.bias'.format(idx2)].data\n",
  259. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data += state_dict['model.{}.m.2.bias'.format(idx2)].data\n",
  260. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
  261. "\n",
  262. "# reparametrized YOLOR\n",
  263. "for i in range(255):\n",
  264. " model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  265. " model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  266. " model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  267. " model.state_dict()['model.{}.m.3.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.3.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  268. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data += state_dict['model.{}.m.0.weight'.format(idx2)].mul(state_dict['model.{}.ia.0.implicit'.format(idx2)]).sum(1).squeeze()\n",
  269. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data += state_dict['model.{}.m.1.weight'.format(idx2)].mul(state_dict['model.{}.ia.1.implicit'.format(idx2)]).sum(1).squeeze()\n",
  270. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data += state_dict['model.{}.m.2.weight'.format(idx2)].mul(state_dict['model.{}.ia.2.implicit'.format(idx2)]).sum(1).squeeze()\n",
  271. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.weight'.format(idx2)].mul(state_dict['model.{}.ia.3.implicit'.format(idx2)]).sum(1).squeeze()\n",
  272. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data *= state_dict['model.{}.im.0.implicit'.format(idx2)].data.squeeze()\n",
  273. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data *= state_dict['model.{}.im.1.implicit'.format(idx2)].data.squeeze()\n",
  274. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data *= state_dict['model.{}.im.2.implicit'.format(idx2)].data.squeeze()\n",
  275. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data *= state_dict['model.{}.im.3.implicit'.format(idx2)].data.squeeze()\n",
  276. "\n",
  277. "# model to be saved\n",
  278. "ckpt = {'model': deepcopy(model.module if is_parallel(model) else model).half(),\n",
  279. " 'optimizer': None,\n",
  280. " 'training_results': None,\n",
  281. " 'epoch': -1}\n",
  282. "\n",
  283. "# save reparameterized model\n",
  284. "torch.save(ckpt, 'cfg/deploy/yolov7-e6.pt')\n"
  285. ]
  286. },
  287. {
  288. "cell_type": "markdown",
  289. "id": "a3bccf89",
  290. "metadata": {},
  291. "source": [
  292. "## YOLOv7-D6 reparameterization"
  293. ]
  294. },
  295. {
  296. "cell_type": "code",
  297. "execution_count": null,
  298. "id": "e5216b70",
  299. "metadata": {},
  300. "outputs": [],
  301. "source": [
  302. "# import\n",
  303. "from copy import deepcopy\n",
  304. "from models.yolo import Model\n",
  305. "import torch\n",
  306. "from utils.torch_utils import select_device, is_parallel\n",
  307. "\n",
  308. "device = select_device('0', batch_size=1)\n",
  309. "# model trained by cfg/training/*.yaml\n",
  310. "ckpt = torch.load('cfg/training/yolov7-d6.pt', map_location=device)\n",
  311. "# reparameterized model in cfg/deploy/*.yaml\n",
  312. "model = Model('cfg/deploy/yolov7-d6.yaml', ch=3, nc=80).to(device)\n",
  313. "\n",
  314. "# copy intersect weights\n",
  315. "state_dict = ckpt['model'].float().state_dict()\n",
  316. "exclude = []\n",
  317. "intersect_state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict() and not any(x in k for x in exclude) and v.shape == model.state_dict()[k].shape}\n",
  318. "model.load_state_dict(intersect_state_dict, strict=False)\n",
  319. "model.names = ckpt['model'].names\n",
  320. "model.nc = ckpt['model'].nc\n",
  321. "\n",
  322. "idx = 162\n",
  323. "idx2 = 166\n",
  324. "\n",
  325. "# copy weights of lead head\n",
  326. "model.state_dict()['model.{}.m.0.weight'.format(idx)].data -= model.state_dict()['model.{}.m.0.weight'.format(idx)].data\n",
  327. "model.state_dict()['model.{}.m.1.weight'.format(idx)].data -= model.state_dict()['model.{}.m.1.weight'.format(idx)].data\n",
  328. "model.state_dict()['model.{}.m.2.weight'.format(idx)].data -= model.state_dict()['model.{}.m.2.weight'.format(idx)].data\n",
  329. "model.state_dict()['model.{}.m.3.weight'.format(idx)].data -= model.state_dict()['model.{}.m.3.weight'.format(idx)].data\n",
  330. "model.state_dict()['model.{}.m.0.weight'.format(idx)].data += state_dict['model.{}.m.0.weight'.format(idx2)].data\n",
  331. "model.state_dict()['model.{}.m.1.weight'.format(idx)].data += state_dict['model.{}.m.1.weight'.format(idx2)].data\n",
  332. "model.state_dict()['model.{}.m.2.weight'.format(idx)].data += state_dict['model.{}.m.2.weight'.format(idx2)].data\n",
  333. "model.state_dict()['model.{}.m.3.weight'.format(idx)].data += state_dict['model.{}.m.3.weight'.format(idx2)].data\n",
  334. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data -= model.state_dict()['model.{}.m.0.bias'.format(idx)].data\n",
  335. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data -= model.state_dict()['model.{}.m.1.bias'.format(idx)].data\n",
  336. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data -= model.state_dict()['model.{}.m.2.bias'.format(idx)].data\n",
  337. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data -= model.state_dict()['model.{}.m.3.bias'.format(idx)].data\n",
  338. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data += state_dict['model.{}.m.0.bias'.format(idx2)].data\n",
  339. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data += state_dict['model.{}.m.1.bias'.format(idx2)].data\n",
  340. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data += state_dict['model.{}.m.2.bias'.format(idx2)].data\n",
  341. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
  342. "\n",
  343. "# reparametrized YOLOR\n",
  344. "for i in range(255):\n",
  345. " model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  346. " model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  347. " model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  348. " model.state_dict()['model.{}.m.3.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.3.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  349. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data += state_dict['model.{}.m.0.weight'.format(idx2)].mul(state_dict['model.{}.ia.0.implicit'.format(idx2)]).sum(1).squeeze()\n",
  350. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data += state_dict['model.{}.m.1.weight'.format(idx2)].mul(state_dict['model.{}.ia.1.implicit'.format(idx2)]).sum(1).squeeze()\n",
  351. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data += state_dict['model.{}.m.2.weight'.format(idx2)].mul(state_dict['model.{}.ia.2.implicit'.format(idx2)]).sum(1).squeeze()\n",
  352. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.weight'.format(idx2)].mul(state_dict['model.{}.ia.3.implicit'.format(idx2)]).sum(1).squeeze()\n",
  353. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data *= state_dict['model.{}.im.0.implicit'.format(idx2)].data.squeeze()\n",
  354. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data *= state_dict['model.{}.im.1.implicit'.format(idx2)].data.squeeze()\n",
  355. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data *= state_dict['model.{}.im.2.implicit'.format(idx2)].data.squeeze()\n",
  356. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data *= state_dict['model.{}.im.3.implicit'.format(idx2)].data.squeeze()\n",
  357. "\n",
  358. "# model to be saved\n",
  359. "ckpt = {'model': deepcopy(model.module if is_parallel(model) else model).half(),\n",
  360. " 'optimizer': None,\n",
  361. " 'training_results': None,\n",
  362. " 'epoch': -1}\n",
  363. "\n",
  364. "# save reparameterized model\n",
  365. "torch.save(ckpt, 'cfg/deploy/yolov7-d6.pt')\n"
  366. ]
  367. },
  368. {
  369. "cell_type": "markdown",
  370. "id": "334c273b",
  371. "metadata": {},
  372. "source": [
  373. "## YOLOv7-E6E reparameterization"
  374. ]
  375. },
  376. {
  377. "cell_type": "code",
  378. "execution_count": null,
  379. "id": "635fd8d2",
  380. "metadata": {},
  381. "outputs": [],
  382. "source": [
  383. "# import\n",
  384. "from copy import deepcopy\n",
  385. "from models.yolo import Model\n",
  386. "import torch\n",
  387. "from utils.torch_utils import select_device, is_parallel\n",
  388. "\n",
  389. "device = select_device('0', batch_size=1)\n",
  390. "# model trained by cfg/training/*.yaml\n",
  391. "ckpt = torch.load('cfg/training/yolov7-e6e.pt', map_location=device)\n",
  392. "# reparameterized model in cfg/deploy/*.yaml\n",
  393. "model = Model('cfg/deploy/yolov7-e6e.yaml', ch=3, nc=80).to(device)\n",
  394. "\n",
  395. "# copy intersect weights\n",
  396. "state_dict = ckpt['model'].float().state_dict()\n",
  397. "exclude = []\n",
  398. "intersect_state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict() and not any(x in k for x in exclude) and v.shape == model.state_dict()[k].shape}\n",
  399. "model.load_state_dict(intersect_state_dict, strict=False)\n",
  400. "model.names = ckpt['model'].names\n",
  401. "model.nc = ckpt['model'].nc\n",
  402. "\n",
  403. "idx = 261\n",
  404. "idx2 = 265\n",
  405. "\n",
  406. "# copy weights of lead head\n",
  407. "model.state_dict()['model.{}.m.0.weight'.format(idx)].data -= model.state_dict()['model.{}.m.0.weight'.format(idx)].data\n",
  408. "model.state_dict()['model.{}.m.1.weight'.format(idx)].data -= model.state_dict()['model.{}.m.1.weight'.format(idx)].data\n",
  409. "model.state_dict()['model.{}.m.2.weight'.format(idx)].data -= model.state_dict()['model.{}.m.2.weight'.format(idx)].data\n",
  410. "model.state_dict()['model.{}.m.3.weight'.format(idx)].data -= model.state_dict()['model.{}.m.3.weight'.format(idx)].data\n",
  411. "model.state_dict()['model.{}.m.0.weight'.format(idx)].data += state_dict['model.{}.m.0.weight'.format(idx2)].data\n",
  412. "model.state_dict()['model.{}.m.1.weight'.format(idx)].data += state_dict['model.{}.m.1.weight'.format(idx2)].data\n",
  413. "model.state_dict()['model.{}.m.2.weight'.format(idx)].data += state_dict['model.{}.m.2.weight'.format(idx2)].data\n",
  414. "model.state_dict()['model.{}.m.3.weight'.format(idx)].data += state_dict['model.{}.m.3.weight'.format(idx2)].data\n",
  415. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data -= model.state_dict()['model.{}.m.0.bias'.format(idx)].data\n",
  416. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data -= model.state_dict()['model.{}.m.1.bias'.format(idx)].data\n",
  417. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data -= model.state_dict()['model.{}.m.2.bias'.format(idx)].data\n",
  418. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data -= model.state_dict()['model.{}.m.3.bias'.format(idx)].data\n",
  419. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data += state_dict['model.{}.m.0.bias'.format(idx2)].data\n",
  420. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data += state_dict['model.{}.m.1.bias'.format(idx2)].data\n",
  421. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data += state_dict['model.{}.m.2.bias'.format(idx2)].data\n",
  422. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
  423. "\n",
  424. "# reparametrized YOLOR\n",
  425. "for i in range(255):\n",
  426. " model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  427. " model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  428. " model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  429. " model.state_dict()['model.{}.m.3.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.3.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
  430. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data += state_dict['model.{}.m.0.weight'.format(idx2)].mul(state_dict['model.{}.ia.0.implicit'.format(idx2)]).sum(1).squeeze()\n",
  431. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data += state_dict['model.{}.m.1.weight'.format(idx2)].mul(state_dict['model.{}.ia.1.implicit'.format(idx2)]).sum(1).squeeze()\n",
  432. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data += state_dict['model.{}.m.2.weight'.format(idx2)].mul(state_dict['model.{}.ia.2.implicit'.format(idx2)]).sum(1).squeeze()\n",
  433. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.weight'.format(idx2)].mul(state_dict['model.{}.ia.3.implicit'.format(idx2)]).sum(1).squeeze()\n",
  434. "model.state_dict()['model.{}.m.0.bias'.format(idx)].data *= state_dict['model.{}.im.0.implicit'.format(idx2)].data.squeeze()\n",
  435. "model.state_dict()['model.{}.m.1.bias'.format(idx)].data *= state_dict['model.{}.im.1.implicit'.format(idx2)].data.squeeze()\n",
  436. "model.state_dict()['model.{}.m.2.bias'.format(idx)].data *= state_dict['model.{}.im.2.implicit'.format(idx2)].data.squeeze()\n",
  437. "model.state_dict()['model.{}.m.3.bias'.format(idx)].data *= state_dict['model.{}.im.3.implicit'.format(idx2)].data.squeeze()\n",
  438. "\n",
  439. "# model to be saved\n",
  440. "ckpt = {'model': deepcopy(model.module if is_parallel(model) else model).half(),\n",
  441. " 'optimizer': None,\n",
  442. " 'training_results': None,\n",
  443. " 'epoch': -1}\n",
  444. "\n",
  445. "# save reparameterized model\n",
  446. "torch.save(ckpt, 'cfg/deploy/yolov7-e6e.pt')\n"
  447. ]
  448. },
  449. {
  450. "cell_type": "code",
  451. "execution_count": null,
  452. "id": "63a62625",
  453. "metadata": {},
  454. "outputs": [],
  455. "source": []
  456. }
  457. ],
  458. "metadata": {
  459. "kernelspec": {
  460. "display_name": "Python 3 (ipykernel)",
  461. "language": "python",
  462. "name": "python3"
  463. },
  464. "language_info": {
  465. "codemirror_mode": {
  466. "name": "ipython",
  467. "version": 3
  468. },
  469. "file_extension": ".py",
  470. "mimetype": "text/x-python",
  471. "name": "python",
  472. "nbconvert_exporter": "python",
  473. "pygments_lexer": "ipython3",
  474. "version": "3.8.10"
  475. }
  476. },
  477. "nbformat": 4,
  478. "nbformat_minor": 5
  479. }

随着人工智能和大数据的发展,各行各业对自动化工具都有着一定的需求。在当下疫情防控期间,本项目使用 MindSpore 实现 YOLO 模型进行目标检测及语义分割,对视频或图片均可进行口罩佩戴检测和行人社交距离检测,从而对公共场所的疫情防控实行自动化管理。