
optim.py 27 kB

# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
#     Guowei Yang <471184555@qq.com>
#     Guoye Yang <498731903@qq.com>
#     Wenyang Zhou <576825820@qq.com>
#     Meng-Hao Guo <guomenghao1997@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
#
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np


class Optimizer(object):
    """ Basic class of Optimizer.

    Example::

        optimizer = nn.SGD(model.parameters(), lr)
        optimizer.step(loss)
    """
    def __init__(self, params, lr, param_sync_iter=10000):
        self.param_groups = []
        self.lr = lr
        self.param_sync_iter = param_sync_iter

        assert len(params) > 0, "Length of parameters should not be zero"
        if not isinstance(params[0], dict):
            params = [{'params': params}]
        for pg in params:
            assert isinstance(pg, dict)
            self.param_groups.append(pg)
        self.n_step = 0
        # __zero_grad is a flag for quickly determining whether the grad is zero,
        # so we can omit 0+x
        self.__zero_grad = True
        self._grad_map = {}
    def add_param_group(self, group):
        self.param_groups.append(group)

    def clip_grad_norm(self, max_norm:float, norm_type:int=2):
        r"""Clips gradient norm of this optimizer.
        The norm is computed over all gradients together.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (int): 1-norm or 2-norm

        Example::

            a = jt.ones(2)
            opt = jt.optim.SGD([a], 0.1)

            loss = a*a
            opt.zero_grad()
            opt.backward(loss)

            print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83
            opt.clip_grad_norm(0.01, 2)
            print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01

            opt.step()
        """
        if self.__zero_grad: return
        grads = []
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                grads.append(g.flatten())
        if len(grads) == 0: return
        total_norm = jt.norm(jt.concat(grads), norm_type)
        clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0)
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                g.update(g * clip_coef)

    @property
    def defaults(self):
        exclude = set(("defaults", "pre_step", "step"))
        return { k:v for k, v in self.__dict__.items()
            if k[0] != '_' and k not in exclude and not callable(v) }
    def state_dict(self):
        state = {"defaults": self.defaults}
        return state

    def load_state_dict(self, state):
        def dfs(x):
            if isinstance(x, list):
                for i in range(len(x)):
                    x[i] = dfs(x[i])
            elif isinstance(x, dict):
                for k in x:
                    x[k] = dfs(x[k])
            elif isinstance(x, np.ndarray):
                return jt.array(x).stop_grad()
            elif isinstance(x, jt.Var):
                return x.stop_grad()
            return x
        exclude = set(("param_groups", "params"))
        for k, v in state["defaults"].items():
            if k not in exclude:
                setattr(self, k, dfs(v))
        param_groups = dfs(state["defaults"].get('param_groups', None))
        if param_groups is not None:
            exclude = set(("params",))
            for i in range(len(param_groups)):
                for k, v in param_groups[i].items():
                    if k not in exclude:
                        self.param_groups[i][k] = v

    def zero_grad(self):
        self.__zero_grad = True
    def backward(self, loss, retain_graph=False):
        '''
        optimizer.backward(loss) is used to accumulate gradients over multiple steps;
        it can be used as follows:

        Original source code::

            n_iter = 10000
            batch_size = 100
            ...
            for i in range(n_iter):
                ...
                loss = calc_loss()
                optimizer.step(loss)

        Accumulation version::

            n_iter = 10000
            batch_size = 100
            accumulation_steps = 10
            n_iter *= accumulation_steps
            batch_size //= accumulation_steps
            ...
            for i in range(n_iter):
                ...
                loss = calc_loss()
                # if loss is a mean over the batch, we need to divide it by accumulation_steps
                optimizer.backward(loss / accumulation_steps)
                if (i+1) % accumulation_steps == 0:
                    optimizer.step()
        '''
        # clean prev grads
        params = []
        params_has_grad = []
        for pg in self.param_groups:
            for p in pg['params']:
                params.append(p)
                if not p.is_stop_grad():
                    params_has_grad.append(p)

        # sync prev params
        jt.sync(params_has_grad)

        # get gradient
        grads = jt.grad(loss, params_has_grad, retain_graph)

        # sync grads and model if in mpi
        if jt.in_mpi:
            dep = []
            def add_dep(v):
                nonlocal dep
                v._add_dependency(dep)
                dep = [v]

            for g in grads:
                g.assign(g.mpi_all_reduce("mean"))
                add_dep(g._input(0))
            if self.n_step % self.param_sync_iter == 0:
                for p in params:
                    p.assign(p.mpi_broadcast())
                    add_dep(p)
        self.n_step += 1

        # set up grads in param_groups
        pid = 0
        for pg in self.param_groups:
            if "grads" not in pg:
                pg["grads"] = [ jt.zeros_like(p).stop_grad().stop_fuse() for p in pg['params'] ]
            pg_grads = pg["grads"]
            for i, p in enumerate(pg['params']):
                if not p.is_stop_grad():
                    # accumulate grad and stop grad of grad
                    g = grads[pid].stop_grad()
                    if not self.__zero_grad:
                        g = g + pg_grads[i]
                    pg_grads[i].update(g)
                    pid += 1
        self.__zero_grad = False
    def pre_step(self, loss, retain_graph=False):
        """ Work that should be done before the update, such as calculating gradients, mpi sync, and so on.

        Example::

            class MyOptimizer(Optimizer):
                def step(self, loss):
                    self.pre_step(loss)
                    ...
                    self.post_step()
        """
        if loss is not None:
            self.backward(loss, retain_graph)
        jt.flags.node_order = 1

    def post_step(self):
        """ Work that should be done after the update, such as zeroing grads, and so on.

        Example::

            class MyOptimizer(Optimizer):
                def step(self, loss):
                    self.pre_step(loss)
                    ...
                    self.post_step()
        """
        jt.flags.node_order = 0
        self.zero_grad()
    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                p.update(p - g * lr)
        self.post_step()

    def _build_grad_map(self):
        _grad_map = {}
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                _grad_map[id(p)] = g
        self._grad_map = _grad_map

    def find_grad(self, v:jt.Var) -> jt.Var:
        if id(v) not in self._grad_map:
            self._build_grad_map()
            if id(v) not in self._grad_map:
                raise RuntimeError("This variable is not managed by this optimizer")
        return self._grad_map[id(v)]
def opt_grad(v:jt.Var, opt:Optimizer):
    ''' Get the gradient of a variable managed by the given optimizer.

    Example::

        model = Model()
        optimizer = SGD(model.parameters())
        ...
        optimizer.backward(loss)

        for p in model.parameters():
            grad = p.opt_grad(optimizer)
    '''
    return opt.find_grad(v)

jt.Var.opt_grad = opt_grad
class SGD(Optimizer):
    """ SGD Optimizer.

    Example::

        optimizer = nn.SGD(model.parameters(), lr, momentum=0.9)
        optimizer.step(loss)
    """
    def __init__(self, params, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False):
        super().__init__(params, lr)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

        # initialize required arguments
        for pg in self.param_groups:
            values = pg["values"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        jt.flags.node_order = 1
        for pg in self.param_groups:
            # get arguments from each param_groups
            lr = pg.get("lr", self.lr)
            momentum = pg.get("momentum", self.momentum)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            dampening = pg.get("dampening", self.dampening)
            nesterov = pg.get("nesterov", self.nesterov)

            # optimize main body
            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                if p.is_stop_grad(): continue
                dp = p * weight_decay + g
                v.update(momentum * v + dp * (1 - dampening))
                if nesterov:
                    p.update(p - (dp + momentum * v) * lr)
                else:
                    p.update(p - v * lr)
        self.post_step()
class RMSprop(Optimizer):
    """ RMSprop Optimizer.

    Args:
        params(list): parameters of model.
        lr(float): learning rate.
        eps(float): term added to the denominator to avoid division by zero, default 1e-8.
        alpha(float): smoothing constant, default 0.99.

    Example::

        optimizer = nn.RMSprop(model.parameters(), lr)
        optimizer.step(loss)
    """
    def __init__(self, params, lr=1e-2, eps=1e-8, alpha=0.99):
        super().__init__(params, lr)
        self.eps = eps
        self.alpha = alpha

        # initialize required arguments for each param_groups
        for pg in self.param_groups:
            values = pg["values"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        for pg in self.param_groups:
            # get arguments from each param_groups
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            alpha = pg.get("alpha", self.alpha)
            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                if p.is_stop_grad(): continue
                v.update(alpha * v + (1-alpha) * g * g)
                p.update(p - lr * g / (jt.sqrt(v) + eps))
        self.post_step()
class Adam(Optimizer):
    """ Adam Optimizer.

    Example::

        optimizer = nn.Adam(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
        optimizer.step(loss)
    """
    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
        super().__init__(params, lr)
        self.eps = eps
        self.betas = betas
        self.weight_decay = weight_decay
        # assert weight_decay==0, "weight_decay is not supported yet"

        # initialize required arguments for each param_groups
        for pg in self.param_groups:
            values = pg["values"] = []
            m = pg["m"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
                m.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        m = group["m"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
            m.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        jt.flags.node_order = 1
        for pg in self.param_groups:
            # get arguments from each param_groups
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            b0, b1 = pg.get("betas", self.betas)
            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                if p.is_stop_grad(): continue
                g = p * weight_decay + g
                m.update(b0 * m + (1-b0) * g)
                v.update(b1 * v + (1-b1) * g * g)
                step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
                p.update(p - m * step_size / (jt.sqrt(v) + eps))
        self.post_step()
class AdamW(Optimizer):
    """ AdamW Optimizer.

    Example::

        optimizer = nn.AdamW(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
        optimizer.step(loss)
    """
    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
        super().__init__(params, lr)
        self.eps = eps
        self.betas = betas
        self.weight_decay = weight_decay
        # assert weight_decay==0, "weight_decay is not supported yet"

        # initialize required arguments for each param_groups
        for pg in self.param_groups:
            values = pg["values"] = []
            m = pg["m"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
                m.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        m = group["m"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
            m.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        for pg in self.param_groups:
            # get arguments from each param_groups
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            b0, b1 = pg.get("betas", self.betas)
            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                if p.is_stop_grad(): continue
                p.update(p * (1 - lr * weight_decay))
                bias_correction1 = 1 - b0 ** n
                bias_correction2 = 1 - b1 ** n
                m.update(b0 * m + (1-b0) * g)      # exp_avg
                v.update(b1 * v + (1-b1) * g * g)  # exp_avg_sq
                denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps
                step_size = lr / bias_correction1
                p.update(p - step_size * m / denom)
        self.post_step()
class Adan(Optimizer):
    """ Adan Optimizer.

    Adan was proposed in
    `Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models`,
    arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677

    Adan is an efficient optimizer for most DNN frameworks:

    - About half the computational load of SOTA optimizers
    - Robust to training settings and batch size
    - Easy to plug and play

    Arguments:
        params (iterable): iterable of parameters to optimize or
            dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for
            first- and second-order moments. (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay
            (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip
            global grad norm (default: 0.0, no clipping)
    """
    def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99),
                 eps=1e-8, weight_decay=0.0, max_grad_norm=0.0):
        super().__init__(params, lr)
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.max_grad_norm = max_grad_norm

        for pg in self.param_groups:
            pg["m"] = []
            pg["v"] = []
            pg["d"] = []
            pg["pre_grad"] = []
            for p in pg["params"]:
                pg["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        group["m"] = []
        group["v"] = []
        group["d"] = []
        group["pre_grad"] = []
        for p in group["params"]:
            group["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            betas = pg.get("betas", self.betas)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            max_grad_norm = pg.get("max_grad_norm", self.max_grad_norm)
            if max_grad_norm > 0: self.clip_grad_norm(max_grad_norm)
            beta1, beta2, beta3 = betas
            bias_correction1 = 1 - beta1 ** n
            bias_correction2 = 1 - beta2 ** n
            bias_correction3_sqrt = jt.sqrt(1 - beta3 ** n)
            step_size_diff = lr * beta2 * bias_correction3_sqrt / bias_correction2
            step_size = lr * bias_correction3_sqrt / bias_correction1
            eps_bias_sqrt = eps * bias_correction3_sqrt
            for p, g, m, v, d, pre_g in zip(pg["params"],
                                            pg["grads"],
                                            pg["m"],
                                            pg["v"],
                                            pg["d"],
                                            pg["pre_grad"]):
                if p.is_stop_grad(): continue
                if self.n_step > 0:
                    pre_g.update(g - pre_g)                        # update pre_g as grad_diff
                m.update(beta1 * m + (1 - beta1) * g)
                d.update(beta2 * d + (1 - beta2) * pre_g)          # use pre_g as grad_diff
                pre_g.update(jt.multiply(pre_g, beta2) + g)        # update pre_g as update (g + beta2 * grad_diff)
                v.update(beta3 * v + (1 - beta3) * pre_g * pre_g)  # use pre_g as update
                p.update(p - (step_size * m + step_size_diff * d) / (jt.sqrt(v) + eps_bias_sqrt))
                p.update(p / (1 + lr * weight_decay))
                pre_g.update(g)                                    # update pre_g for the next iteration
        self.post_step()
class AdanBelief(Optimizer):
    """ AdanBelief Optimizer: Adan with an AdaBelief-style second moment.

    Adan was proposed in
    `Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models`,
    arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677

    Adan is an efficient optimizer for most DNN frameworks:

    - About half the computational load of SOTA optimizers
    - Robust to training settings and batch size
    - Easy to plug and play

    Arguments:
        params (iterable): iterable of parameters to optimize or
            dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for
            first- and second-order moments. (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay
            (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip
            global grad norm (default: 0.0, no clipping)
    """
    def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99),
                 eps=1e-8, weight_decay=0.0, max_grad_norm=0.0):
        super().__init__(params, lr)
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.max_grad_norm = max_grad_norm

        for pg in self.param_groups:
            pg["m"] = []
            pg["v"] = []
            pg["d"] = []
            pg["pre_grad"] = []
            for p in pg["params"]:
                pg["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        group["m"] = []
        group["v"] = []
        group["d"] = []
        group["pre_grad"] = []
        for p in group["params"]:
            group["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            betas = pg.get("betas", self.betas)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            max_grad_norm = pg.get("max_grad_norm", self.max_grad_norm)
            if max_grad_norm > 0: self.clip_grad_norm(max_grad_norm)
            beta1, beta2, beta3 = betas
            bias_correction1 = 1 - beta1 ** n
            bias_correction2 = 1 - beta2 ** n
            bias_correction3_sqrt = jt.sqrt(1 - beta3 ** n)
            step_size_diff = lr * beta2 * bias_correction3_sqrt / bias_correction2
            step_size = lr * bias_correction3_sqrt / bias_correction1
            eps_bias_sqrt = eps * bias_correction3_sqrt
            for p, g, m, v, d, pre_g in zip(pg["params"],
                                            pg["grads"],
                                            pg["m"],
                                            pg["v"],
                                            pg["d"],
                                            pg["pre_grad"]):
                if p.is_stop_grad(): continue
                if self.n_step > 0:
                    pre_g.update(g - pre_g)                        # update pre_g as grad_diff
                m.update(beta1 * m + (1 - beta1) * g)
                d.update(beta2 * d + (1 - beta2) * pre_g)          # use pre_g as grad_diff
                pre_g.update(jt.multiply(pre_g, beta2) + g)        # update pre_g as update (g + beta2 * grad_diff)
                v.update(beta3 * v + (1 - beta3) * (pre_g - m) * (pre_g - m))  # AdaBelief-style: use the belief (pre_g - m) for the second moment
                p.update(p - (step_size * m + step_size_diff * d) / (jt.sqrt(v) + eps_bias_sqrt))
                p.update(p / (1 + lr * weight_decay))              # AdanBelief best result 0.7358 (300 epoch basic)
                pre_g.update(g)                                    # update pre_g for the next iteration
        self.post_step()
class LRScheduler:
    def __init__(self, optimizer, last_epoch=-1):
        assert isinstance(optimizer, Optimizer)
        self.optimizer = optimizer

        if last_epoch == -1:
            for gp in optimizer.param_groups:
                gp.setdefault('initial_lr', gp.get('lr', optimizer.lr))
        else:
            for gp in optimizer.param_groups:
                assert 'initial_lr' in gp

        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.last_epoch = last_epoch
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step()

    def get_lr(self):
        raise NotImplementedError

    def get_last_lr(self):
        return self._last_lr

    def step(self, epoch=None):
        self._step_count += 1
        if epoch is None:
            self.last_epoch += 1
            values = self.get_lr()
        else:
            self.last_epoch = epoch
            values = self.get_lr()

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]


class LambdaLR(LRScheduler):
    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        super(LambdaLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * lmbda(self.last_epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]

First, all image-encoder layers of the CLIP model (OpenAI's officially pretrained ViT-B/32 version) are frozen, and the model is then trained with the AdanBelief optimizer. AdanBelief is a fusion of the Adan and AdaBelief optimizers: it incorporates the "Belief" idea into Adan to improve the generalization performance of the trained model. A minimal sketch of this setup follows.
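The sketch below only illustrates the intended usage under stated assumptions: the `load_clip_vit_b32` loader, the `visual.` name prefix for image layers, the `train_loader`, and `compute_clip_loss` are hypothetical placeholders and not part of this repository; only AdanBelief (defined in optim.py above) and the Optimizer behavior of skipping stop-grad parameters come from the file itself.

    import jittor as jt
    from jittor.optim import AdanBelief   # AdanBelief as defined in the optim.py above (not in stock Jittor)

    # Hypothetical loader: replace with whatever returns the pretrained
    # ViT-B/32 CLIP model as a jittor Module in your setup.
    model = load_clip_vit_b32()

    # Freeze every image-encoder layer. Here we assume image layers live under
    # the "visual." prefix, as in the original OpenAI CLIP implementation.
    for name, p in model.named_parameters():
        if name.startswith("visual."):
            p.stop_grad()   # stop-grad params are skipped by the optimizers above (is_stop_grad check)

    # Train the remaining parameters with AdanBelief.
    optimizer = AdanBelief(model.parameters(), lr=1e-3,
                           betas=(0.98, 0.92, 0.99), weight_decay=0.02)

    for images, texts in train_loader:            # train_loader assumed to exist
        loss = compute_clip_loss(model, images, texts)   # placeholder contrastive loss
        optimizer.step(loss)                      # pre_step -> backward -> parameter update -> post_step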
