# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
#     Guowei Yang <471184555@qq.com>
#     Guoye Yang <498731903@qq.com>
#     Wenyang Zhou <576825820@qq.com>
#     Meng-Hao Guo <guomenghao1997@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
#
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np


class Optimizer(object):
    """ Basic class of Optimizer.

    Example::

        optimizer = nn.SGD(model.parameters(), lr)
        optimizer.step(loss)
    """
    def __init__(self, params, lr, param_sync_iter=10000):
        self.param_groups = []
        self.lr = lr
        self.param_sync_iter = param_sync_iter

        assert len(params) > 0, "Length of parameters should not be zero"
        if not isinstance(params[0], dict):
            params = [{'params': params}]
        for pg in params:
            assert isinstance(pg, dict)
            self.param_groups.append(pg)
        self.n_step = 0
        # __zero_grad is a flag for quickly checking whether the stored
        # gradients are all zero, so the 0+x accumulation can be skipped.
        self.__zero_grad = True
        self._grad_map = {}

    def add_param_group(self, group):
        self.param_groups.append(group)

    def clip_grad_norm(self, max_norm:float, norm_type:int=2):
        r"""Clips gradient norm of this optimizer.
        The norm is computed over all gradients together.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (int): 1-norm or 2-norm

        Example::

            a = jt.ones(2)
            opt = jt.optim.SGD([a], 0.1)

            loss = a*a
            opt.zero_grad()
            opt.backward(loss)

            print(opt.param_groups[0]['grads'][0].norm())  # output: 2.83
            opt.clip_grad_norm(0.01, 2)
            print(opt.param_groups[0]['grads'][0].norm())  # output: 0.01

            opt.step()
        """
        if self.__zero_grad: return
        grads = []
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                grads.append(g.flatten())
        if len(grads) == 0: return
        total_norm = jt.norm(jt.concat(grads), norm_type)
        clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0)
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                g.update(g*clip_coef)

    @property
    def defaults(self):
        exclude = set(("defaults", "pre_step", "step"))
        return { k:v for k, v in self.__dict__.items()
            if k[0] != '_' and k not in exclude and not callable(v) }

    def state_dict(self):
        state = {"defaults": self.defaults}
        return state

    def load_state_dict(self, state):
        def dfs(x):
            if isinstance(x, list):
                for i in range(len(x)):
                    x[i] = dfs(x[i])
            elif isinstance(x, dict):
                for k in x:
                    x[k] = dfs(x[k])
            elif isinstance(x, np.ndarray):
                return jt.array(x).stop_grad()
            elif isinstance(x, jt.Var):
                return x.stop_grad()
            return x
        exclude = set(("param_groups", "params"))
        for k, v in state["defaults"].items():
            if k not in exclude:
                setattr(self, k, dfs(v))
        param_groups = dfs(state["defaults"].get('param_groups', None))
        if param_groups is not None:
            exclude = set(("params",))
            for i in range(len(param_groups)):
                for k, v in param_groups[i].items():
                    if k not in exclude:
                        self.param_groups[i][k] = v

    def zero_grad(self):
        self.__zero_grad = True

    def backward(self, loss, retain_graph=False):
        '''
        optimizer.backward(loss) accumulates gradients over multiple steps.
        It can be used as follows:

        Original source code::

            n_iter = 10000
            batch_size = 100
            ...
            for i in range(n_iter):
                ...
                loss = calc_loss()
                optimizer.step(loss)

        Accumulation version::

            n_iter = 10000
            batch_size = 100
            accumulation_steps = 10
            n_iter *= accumulation_steps
            batch_size //= accumulation_steps
            ...
            for i in range(n_iter):
                ...
                loss = calc_loss()
                # if loss is a mean over the batch, divide it by accumulation_steps
                optimizer.backward(loss / accumulation_steps)
                if (i+1) % accumulation_steps == 0:
                    optimizer.step()
        '''
        # clean prev grads
        params = []
        params_has_grad = []
        for pg in self.param_groups:
            for p in pg['params']:
                params.append(p)
                if not p.is_stop_grad():
                    params_has_grad.append(p)

        # sync prev params
        jt.sync(params_has_grad)

        # get gradient
        grads = jt.grad(loss, params_has_grad, retain_graph)

        # sync grads and model if in mpi
        if jt.in_mpi:
            dep = []
            def add_dep(v):
                nonlocal dep
                v._add_dependency(dep)
                dep = [v]

            for g in grads:
                g.assign(g.mpi_all_reduce("mean"))
                add_dep(g._input(0))
            if self.n_step % self.param_sync_iter == 0:
                for p in params:
                    p.assign(p.mpi_broadcast())
                    add_dep(p)
        self.n_step += 1

        # set up grads in param_groups
        pid = 0
        for pg in self.param_groups:
            if "grads" not in pg:
                pg["grads"] = [ jt.zeros_like(p).stop_grad().stop_fuse() for p in pg['params'] ]
            pg_grads = pg["grads"]
            for i, p in enumerate(pg['params']):
                if not p.is_stop_grad():
                    # accumulate grad and stop grad of grad
                    g = grads[pid].stop_grad()
                    if not self.__zero_grad:
                        g = g + pg_grads[i]
                    pg_grads[i].update(g)
                    pid += 1
        self.__zero_grad = False

    def pre_step(self, loss, retain_graph=False):
        """ Things that should be done before step, such as computing gradients, mpi sync, and so on.

        Example::

            class MyOptimizer(Optimizer):
                def step(self, loss):
                    self.pre_step(loss)
                    ...
                    self.post_step()
        """
        if loss is not None:
            self.backward(loss, retain_graph)
        jt.flags.node_order = 1

    def post_step(self):
        """ Things that should be done after step, such as zeroing gradients, and so on.

        Example::

            class MyOptimizer(Optimizer):
                def step(self, loss):
                    self.pre_step(loss)
                    ...
                    self.post_step()
        """
        jt.flags.node_order = 0
        self.zero_grad()

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                p.update(p - g * lr)
        self.post_step()

    def _build_grad_map(self):
        _grad_map = {}
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                _grad_map[id(p)] = g
        self._grad_map = _grad_map

    def find_grad(self, v:jt.Var) -> jt.Var:
        if id(v) not in self._grad_map:
            self._build_grad_map()
            if id(v) not in self._grad_map:
                raise RuntimeError("This variable is not managed by this optimizer")
        return self._grad_map[id(v)]


def opt_grad(v:jt.Var, opt:Optimizer):
    ''' Get the gradient of a certain variable held by an optimizer.

    Example::

        model = Model()
        optimizer = SGD(model.parameters())
        ...
        optimizer.backward(loss)

        for p in model.parameters():
            grad = p.opt_grad(optimizer)
    '''
    return opt.find_grad(v)

jt.Var.opt_grad = opt_grad
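
# A minimal end-to-end sketch of the base Optimizer API defined above
# (plain gradient descent, p <- p - lr * g). The values are illustrative
# assumptions, not part of the original file.
#
#   w = jt.ones(3)
#   opt = Optimizer([w], lr=0.1)
#   loss = (w * w).sum()
#   opt.backward(loss)          # store gradients in param_groups[0]["grads"]
#   opt.clip_grad_norm(1.0, 2)  # optional: rescale so the global 2-norm is <= 1.0
#   g = w.opt_grad(opt)         # inspect the gradient held by the optimizer
#   opt.step()                  # apply the update and zero the stored gradients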


class SGD(Optimizer):
    """ SGD Optimizer.

    Example::

        optimizer = nn.SGD(model.parameters(), lr, momentum=0.9)
        optimizer.step(loss)
    """
    def __init__(self, params, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False):
        super().__init__(params, lr)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

        # initialize required arguments
        for pg in self.param_groups:
            values = pg["values"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        jt.flags.node_order = 1
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            momentum = pg.get("momentum", self.momentum)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            dampening = pg.get("dampening", self.dampening)
            nesterov = pg.get("nesterov", self.nesterov)

            # optimize main body
            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                if p.is_stop_grad(): continue
                dp = p * weight_decay + g
                v.update(momentum * v + dp * (1 - dampening))
                if nesterov:
                    p.update(p - (dp + momentum * v) * lr)
                else:
                    p.update(p - v * lr)
        self.post_step()
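
# A usage sketch for SGD with momentum and a per-group learning rate. The
# model, its sub-modules, and the loss helper below are assumptions standing
# in for any Jittor module; only the optimizer calls come from this file.
#
#   opt = SGD([
#       {"params": model.backbone.parameters(), "lr": 0.01},
#       {"params": model.head.parameters()},   # uses the default lr below
#   ], lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
#   loss = compute_loss(model(x), y)
#   opt.step(loss)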


class RMSprop(Optimizer):
    """ RMSprop Optimizer.

    Args:
        params(list): parameters of model.
        lr(float): learning rate.
        eps(float): term added to the denominator to avoid division by zero, default 1e-8.
        alpha(float): smoothing constant, default 0.99.

    Example::

        optimizer = nn.RMSprop(model.parameters(), lr)
        optimizer.step(loss)
    """
    def __init__(self, params, lr=1e-2, eps=1e-8, alpha=0.99):
        super().__init__(params, lr)
        self.eps = eps
        self.alpha = alpha

        # initialize required arguments for each param_group
        for pg in self.param_groups:
            values = pg["values"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            alpha = pg.get("alpha", self.alpha)
            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                if p.is_stop_grad(): continue
                v.update(alpha * v + (1-alpha) * g * g)
                p.update(p - lr * g / (jt.sqrt(v) + eps))
        self.post_step()
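
# For reference, the loop in RMSprop.step implements, per parameter p with
# gradient g and running second moment v:
#
#   v <- alpha * v + (1 - alpha) * g^2
#   p <- p - lr * g / (sqrt(v) + eps)
#
# i.e. each step is rescaled by a running estimate of the gradient magnitude.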


class Adam(Optimizer):
    """ Adam Optimizer.

    Example::

        optimizer = nn.Adam(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
        optimizer.step(loss)
    """
    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
        super().__init__(params, lr)
        self.eps = eps
        self.betas = betas
        self.weight_decay = weight_decay
        # assert weight_decay==0, "weight_decay is not supported yet"

        # initialize required arguments for each param_group
        for pg in self.param_groups:
            values = pg["values"] = []
            m = pg["m"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
                m.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        m = group["m"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
            m.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        jt.flags.node_order = 1
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            b0, b1 = pg.get("betas", self.betas)
            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                if p.is_stop_grad(): continue
                g = p * weight_decay + g
                m.update(b0 * m + (1-b0) * g)
                v.update(b1 * v + (1-b1) * g * g)
                step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
                p.update(p - m * step_size / (jt.sqrt(v) + eps))
        self.post_step()
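
# For reference, Adam.step above folds the bias corrections into one step size:
#
#   m <- b0 * m + (1 - b0) * g            (first moment)
#   v <- b1 * v + (1 - b1) * g^2          (second moment)
#   step_size = lr * sqrt(1 - b1^n) / (1 - b0^n)
#   p <- p - step_size * m / (sqrt(v) + eps)
#
# Note that weight_decay here is classic L2 regularization folded into the
# gradient (g <- p * weight_decay + g); the decoupled variant is AdamW below.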


class AdamW(Optimizer):
    """ AdamW Optimizer.

    Example::

        optimizer = nn.AdamW(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
        optimizer.step(loss)
    """
    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
        super().__init__(params, lr)
        self.eps = eps
        self.betas = betas
        self.weight_decay = weight_decay
        # assert weight_decay==0, "weight_decay is not supported yet"

        # initialize required arguments for each param_group
        for pg in self.param_groups:
            values = pg["values"] = []
            m = pg["m"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
                m.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        m = group["m"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
            m.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            b0, b1 = pg.get("betas", self.betas)
            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                if p.is_stop_grad(): continue
                p.update(p * (1 - lr * weight_decay))
                bias_correction1 = 1 - b0 ** n
                bias_correction2 = 1 - b1 ** n
                m.update(b0 * m + (1-b0) * g)  # exp_avg
                v.update(b1 * v + (1-b1) * g * g)  # exp_avg_sq
                denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps
                step_size = lr / bias_correction1
                p.update(p - step_size * m / denom)
        self.post_step()
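
# AdamW differs from Adam above only in how weight decay is applied: the
# parameter is shrunk directly (p <- p * (1 - lr * weight_decay)) before the
# update, instead of adding weight_decay * p to the gradient. A minimal usage
# sketch; the model and hyper-parameter values are illustrative assumptions:
#
#   opt = AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.999), weight_decay=0.01)
#   for batch in loader:
#       loss = compute_loss(model, batch)
#       opt.step(loss)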


class Adan(Optimizer):
    """ Adan Optimizer.

    Adan was proposed in
    `Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models`,
    arXiv preprint arXiv:2208.06677, 2022. https://arxiv.org/abs/2208.06677

    Adan is an efficient optimizer for most DNN frameworks:

    - roughly half the computational load of SOTA optimizers
    - robust to the training setting and batch size
    - easy to plug and play

    Arguments:
        params (iterable): iterable of parameters to optimize or
            dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for
            first- and second-order moments. (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay
            (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip
            global grad norm (default: 0.0, no clipping)
    """
    def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99),
                 eps=1e-8, weight_decay=0.0, max_grad_norm=0.0):
        super().__init__(params, lr)
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.max_grad_norm = max_grad_norm

        for pg in self.param_groups:
            pg["m"] = []
            pg["v"] = []
            pg["d"] = []
            pg["pre_grad"] = []
            for p in pg["params"]:
                pg["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        group["m"] = []
        group["v"] = []
        group["d"] = []
        group["pre_grad"] = []
        for p in group["params"]:
            group["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            betas = pg.get("betas", self.betas)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            max_grad_norm = pg.get("max_grad_norm", self.max_grad_norm)
            if max_grad_norm > 0: self.clip_grad_norm(max_grad_norm)

            beta1, beta2, beta3 = betas
            bias_correction1 = 1 - beta1 ** n
            bias_correction2 = 1 - beta2 ** n
            bias_correction3_sqrt = jt.sqrt(1 - beta3 ** n)
            step_size_diff = lr * beta2 * bias_correction3_sqrt / bias_correction2
            step_size = lr * bias_correction3_sqrt / bias_correction1
            eps_bias_sqrt = eps * bias_correction3_sqrt

            for p, g, m, v, d, pre_g in zip(pg["params"], pg["grads"],
                                            pg["m"], pg["v"], pg["d"], pg["pre_grad"]):
                if p.is_stop_grad(): continue
                if self.n_step > 0:
                    pre_g.update(g - pre_g)  # update pre_g as grad_diff
                m.update(beta1 * m + (1 - beta1) * g)
                d.update(beta2 * d + (1 - beta2) * pre_g)  # use pre_g as grad_diff
                pre_g.update(jt.multiply(pre_g, beta2) + g)  # update pre_g as the combined update (g + beta2 * grad_diff)
                v.update(beta3 * v + (1 - beta3) * pre_g * pre_g)  # use pre_g as the combined update
                p.update(p - (step_size * m + step_size_diff * d) / (jt.sqrt(v) + eps_bias_sqrt))
                p.update(p / (1 + lr * weight_decay))
                pre_g.update(g)  # store the current gradient for the next iteration
        self.post_step()
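
# For reference, Adan.step above reuses the `pre_grad` buffer for three roles
# inside one iteration (previous gradient, gradient difference, and combined
# update). Written out, with g_prev the gradient stored on the previous call
# (initialized to zeros) and u the combined update, it computes:
#
#   diff = g - g_prev
#   m <- beta1 * m + (1 - beta1) * g
#   d <- beta2 * d + (1 - beta2) * diff
#   u = g + beta2 * diff
#   v <- beta3 * v + (1 - beta3) * u^2
#   p <- (p - (step_size * m + step_size_diff * d) / (sqrt(v) + eps')) / (1 + lr * weight_decay)
#
# where step_size, step_size_diff and eps' carry the bias corrections computed
# before the loop.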


class AdanBelief(Optimizer):
    """ AdanBelief Optimizer.

    AdanBelief is a variant of Adan that replaces the raw second moment with an
    AdaBelief-style estimate, i.e. the spread of the combined update around its
    first-moment estimate.

    Adan was proposed in
    `Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models`,
    arXiv preprint arXiv:2208.06677, 2022. https://arxiv.org/abs/2208.06677

    Adan is an efficient optimizer for most DNN frameworks:

    - roughly half the computational load of SOTA optimizers
    - robust to the training setting and batch size
    - easy to plug and play

    Arguments:
        params (iterable): iterable of parameters to optimize or
            dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for
            first- and second-order moments. (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay
            (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip
            global grad norm (default: 0.0, no clipping)
    """
    def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99),
                 eps=1e-8, weight_decay=0.0, max_grad_norm=0.0):
        super().__init__(params, lr)
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.max_grad_norm = max_grad_norm

        for pg in self.param_groups:
            pg["m"] = []
            pg["v"] = []
            pg["d"] = []
            pg["pre_grad"] = []
            for p in pg["params"]:
                pg["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        group["m"] = []
        group["v"] = []
        group["d"] = []
        group["pre_grad"] = []
        for p in group["params"]:
            group["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            betas = pg.get("betas", self.betas)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            max_grad_norm = pg.get("max_grad_norm", self.max_grad_norm)
            if max_grad_norm > 0: self.clip_grad_norm(max_grad_norm)

            beta1, beta2, beta3 = betas
            bias_correction1 = 1 - beta1 ** n
            bias_correction2 = 1 - beta2 ** n
            bias_correction3_sqrt = jt.sqrt(1 - beta3 ** n)
            step_size_diff = lr * beta2 * bias_correction3_sqrt / bias_correction2
            step_size = lr * bias_correction3_sqrt / bias_correction1
            eps_bias_sqrt = eps * bias_correction3_sqrt

            for p, g, m, v, d, pre_g in zip(pg["params"], pg["grads"],
                                            pg["m"], pg["v"], pg["d"], pg["pre_grad"]):
                if p.is_stop_grad(): continue
                if self.n_step > 0:
                    pre_g.update(g - pre_g)  # update pre_g as grad_diff
                m.update(beta1 * m + (1 - beta1) * g)
                d.update(beta2 * d + (1 - beta2) * pre_g)  # use pre_g as grad_diff
                pre_g.update(jt.multiply(pre_g, beta2) + g)  # update pre_g as the combined update (g + beta2 * grad_diff)
                v.update(beta3 * v + (1 - beta3) * (pre_g - m) * (pre_g - m))  # belief-style second moment
                p.update(p - (step_size * m + step_size_diff * d) / (jt.sqrt(v) + eps_bias_sqrt))
                p.update(p / (1 + lr * weight_decay))  # AdanBelief best result: 0.7358 (300 epoch basic)
                pre_g.update(g)  # store the current gradient for the next iteration
        self.post_step()
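
# For reference, AdanBelief.step is identical to Adan.step except for the
# second-moment update (u denotes the combined update g + beta2 * grad_diff):
#
#   Adan:        v <- beta3 * v + (1 - beta3) * u^2
#   AdanBelief:  v <- beta3 * v + (1 - beta3) * (u - m)^2
#
# i.e. it tracks the spread of the update around its first-moment estimate m
# instead of the raw second moment.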


class LRScheduler:
    def __init__(self, optimizer, last_epoch=-1):
        assert isinstance(optimizer, Optimizer)
        self.optimizer = optimizer

        if last_epoch == -1:
            for gp in optimizer.param_groups:
                gp.setdefault('initial_lr', gp.get('lr', optimizer.lr))
        else:
            for gp in optimizer.param_groups:
                assert 'initial_lr' in gp

        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.last_epoch = last_epoch
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step()

    def get_lr(self):
        raise NotImplementedError

    def get_last_lr(self):
        return self._last_lr

    def step(self, epoch=None):
        self._step_count += 1
        if epoch is None:
            self.last_epoch += 1
        else:
            self.last_epoch = epoch
        values = self.get_lr()

        for param_group, lr in zip(self.optimizer.param_groups, values):
            param_group['lr'] = lr
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]


class LambdaLR(LRScheduler):
    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        if not isinstance(lr_lambda, (list, tuple)):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        super(LambdaLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * lmbda(self.last_epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
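
# A usage sketch tying a scheduler to an optimizer. The model, data loader and
# the decay schedule (halving the lr every 30 epochs) are illustrative
# assumptions; only SGD and LambdaLR come from this file.
#
#   opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
#   scheduler = LambdaLR(opt, lr_lambda=lambda epoch: 0.5 ** (epoch // 30))
#   for epoch in range(90):
#       for batch in loader:
#           loss = compute_loss(model, batch)
#           opt.step(loss)
#       scheduler.step()   # sets each param_group["lr"] to base_lr * lr_lambda(epoch)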