# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
#     Guowei Yang <471184555@qq.com>
#     Guoye Yang <498731903@qq.com>
#     Wenyang Zhou <576825820@qq.com>
#     Meng-Hao Guo
#     Dun Liang.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np


class Optimizer(object):
    """ Basic class of Optimizer.

    Example::

        optimizer = nn.SGD(model.parameters(), lr)
        optimizer.step(loss)
    """
    def __init__(self, params, lr, param_sync_iter=10000):
        self.param_groups = []
        self.lr = lr
        self.param_sync_iter = param_sync_iter

        assert len(params) > 0, "Length of parameters should not be zero"
        if not isinstance(params[0], dict):
            params = [{'params': params}]
        for pg in params:
            assert isinstance(pg, dict)
            self.param_groups.append(pg)
        self.n_step = 0
        # __zero_grad is a flag used to quickly determine whether the
        # gradients are zero, so the 0+x accumulation can be omitted.
        self.__zero_grad = True
        self._grad_map = {}

    def add_param_group(self, group):
        self.param_groups.append(group)

    def clip_grad_norm(self, max_norm:float, norm_type:int=2):
        r"""Clips gradient norm of this optimizer.
        The norm is computed over all gradients together.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (int): 1-norm or 2-norm

        Example::

            a = jt.ones(2)
            opt = jt.optim.SGD([a], 0.1)

            loss = a*a
            opt.zero_grad()
            opt.backward(loss)

            print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83
            opt.clip_grad_norm(0.01, 2)
            print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01

            opt.step()
        """
        if self.__zero_grad: return
        grads = []
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                grads.append(g.flatten())
        if len(grads) == 0: return
        total_norm = jt.norm(jt.concat(grads), norm_type)
        clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0)
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                g.update(g*clip_coef)

    @property
    def defaults(self):
        exclude = set(("defaults", "pre_step", "step"))
        return { k:v for k, v in self.__dict__.items()
            if k[0] != '_' and k not in exclude and not callable(v) }

    def state_dict(self):
        state = {"defaults": self.defaults}
        return state

    def load_state_dict(self, state):
        def dfs(x):
            if isinstance(x, list):
                for i in range(len(x)):
                    x[i] = dfs(x[i])
            elif isinstance(x, dict):
                for k in x:
                    x[k] = dfs(x[k])
            elif isinstance(x, np.ndarray):
                return jt.array(x).stop_grad()
            elif isinstance(x, jt.Var):
                return x.stop_grad()
            return x
        exclude = set(("param_groups", "params"))
        for k, v in state["defaults"].items():
            if k not in exclude:
                setattr(self, k, dfs(v))
        param_groups = dfs(state["defaults"].get('param_groups', None))
        if param_groups is not None:
            exclude = set(("params",))
            for i in range(len(param_groups)):
                for k, v in param_groups[i].items():
                    if k not in exclude:
                        self.param_groups[i][k] = v

    def zero_grad(self):
        self.__zero_grad = True

    def backward(self, loss, retain_graph=False):
        '''
        optimizer.backward(loss) is used to accumulate gradients over multiple
        steps. It can be used as follows:

        Origin source code ::

            n_iter = 10000
            batch_size = 100
            ...
            for i in range(n_iter):
                ...
                loss = calc_loss()
                optimizer.step(loss)

        Accumulation version ::

            n_iter = 10000
            batch_size = 100
            accumulation_steps = 10
            n_iter *= accumulation_steps
            batch_size //= accumulation_steps
            ...
            for i in range(n_iter):
                ...
                loss = calc_loss()
                # if loss is a mean over the batch, divide it by accumulation_steps
                optimizer.backward(loss / accumulation_steps)
                if (i+1) % accumulation_steps == 0:
                    optimizer.step()
        '''
        # clean prev grads
        params = []
        params_has_grad = []
        for pg in self.param_groups:
            for p in pg['params']:
                params.append(p)
                if not p.is_stop_grad():
                    params_has_grad.append(p)

        # sync prev params
        jt.sync(params_has_grad)

        # get gradient
        grads = jt.grad(loss, params_has_grad, retain_graph)

        # sync grads and model if in mpi
        if jt.in_mpi:
            dep = []
            def add_dep(v):
                nonlocal dep
                v._add_dependency(dep)
                dep = [v]

            for g in grads:
                g.assign(g.mpi_all_reduce("mean"))
                add_dep(g._input(0))
            if self.n_step % self.param_sync_iter == 0:
                for p in params:
                    p.assign(p.mpi_broadcast())
                    add_dep(p)
        self.n_step += 1

        # set up grads in param_groups
        pid = 0
        for pg in self.param_groups:
            if "grads" not in pg:
                pg["grads"] = [ jt.zeros_like(p).stop_grad().stop_fuse() for p in pg['params'] ]
            pg_grads = pg["grads"]
            for i, p in enumerate(pg['params']):
                if not p.is_stop_grad():
                    # accumulate grad and stop grad of grad
                    g = grads[pid].stop_grad()
                    if not self.__zero_grad:
                        g = g + pg_grads[i]
                    pg_grads[i].update(g)
                    pid += 1
        self.__zero_grad = False

    def pre_step(self, loss, retain_graph=False):
        """ Things that should be done before step, such as calculating
        gradients, MPI sync, and so on.

        Example::

            class MyOptimizer(Optimizer):
                def step(self, loss):
                    self.pre_step(loss)
                    ...
                    self.post_step()
        """
        if loss is not None:
            self.backward(loss, retain_graph)
        jt.flags.node_order = 1

    def post_step(self):
        """ Things that should be done after step, such as zeroing the
        gradients, and so on.

        Example::

            class MyOptimizer(Optimizer):
                def step(self, loss):
                    self.pre_step(loss)
                    ...
                    self.post_step()
        """
        jt.flags.node_order = 0
        self.zero_grad()

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                p.update(p - g * lr)
        self.post_step()

    def _build_grad_map(self):
        _grad_map = {}
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                _grad_map[id(p)] = g
        self._grad_map = _grad_map

    def find_grad(self, v:jt.Var) -> jt.Var:
        if id(v) not in self._grad_map:
            self._build_grad_map()
            if id(v) not in self._grad_map:
                raise RuntimeError("This variable is not managed by this optimizer")
        return self._grad_map[id(v)]


def opt_grad(v:jt.Var, opt:Optimizer):
    ''' Get the gradient of a certain variable managed by the optimizer.

    Example::

        model = Model()
        optimizer = SGD(model.parameters())
        ...
        optimizer.backward(loss)

        for p in model.parameters():
            grad = p.opt_grad(optimizer)
    '''
    return opt.find_grad(v)

jt.Var.opt_grad = opt_grad
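
# A minimal sketch of a custom optimizer built on the pre_step/post_step
# contract documented above (kept as a comment so that importing this module
# has no side effects). The class name `MyPlainSGD` is hypothetical; it only
# relies on the param_groups / grads layout set up by Optimizer.backward():
#
#   class MyPlainSGD(Optimizer):
#       def step(self, loss=None, retain_graph=False):
#           self.pre_step(loss, retain_graph)   # computes grads, handles MPI sync
#           for pg in self.param_groups:
#               lr = pg.get("lr", self.lr)
#               for p, g in zip(pg["params"], pg["grads"]):
#                   if p.is_stop_grad(): continue
#                   p.update(p - g * lr)        # plain gradient-descent update
#           self.post_step()                    # resets node_order and zeroes grads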

class SGD(Optimizer):
    """ SGD Optimizer.

    Example::

        optimizer = nn.SGD(model.parameters(), lr, momentum=0.9)
        optimizer.step(loss)
    """
    def __init__(self, params, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False):
        super().__init__(params, lr)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

        # initialize required arguments
        for pg in self.param_groups:
            values = pg["values"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        jt.flags.node_order = 1
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            momentum = pg.get("momentum", self.momentum)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            dampening = pg.get("dampening", self.dampening)
            nesterov = pg.get("nesterov", self.nesterov)

            # optimize main body
            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                if p.is_stop_grad(): continue
                dp = p * weight_decay + g
                v.update(momentum * v + dp * (1 - dampening))
                if nesterov:
                    p.update(p - (dp + momentum * v) * lr)
                else:
                    p.update(p - v * lr)
        self.post_step()
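
# A usage sketch for per-group hyperparameters (comment only). Since
# Optimizer.__init__ accepts a list of dicts and SGD.step() reads "lr",
# "momentum", "weight_decay", etc. from each group with pg.get(...), individual
# parameter groups can override the defaults. `backbone` and `head` are
# hypothetical modules here:
#
#   optimizer = SGD([
#       {'params': backbone.parameters(), 'lr': 0.01},
#       {'params': head.parameters(),     'lr': 0.1, 'weight_decay': 1e-4},
#   ], lr=0.1, momentum=0.9)
#   optimizer.step(loss)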

class RMSprop(Optimizer):
    """ RMSprop Optimizer.

    Args:
        params (list): parameters of the model.
        lr (float): learning rate.
        eps (float): term added to the denominator to avoid division by zero, default 1e-8.
        alpha (float): smoothing constant, default 0.99.

    Example::

        optimizer = nn.RMSprop(model.parameters(), lr)
        optimizer.step(loss)
    """
    def __init__(self, params, lr=1e-2, eps=1e-8, alpha=0.99):
        super().__init__(params, lr)
        self.eps = eps
        self.alpha = alpha

        # initialize required arguments for each param_group
        for pg in self.param_groups:
            values = pg["values"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            alpha = pg.get("alpha", self.alpha)
            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                if p.is_stop_grad(): continue
                v.update(alpha * v + (1-alpha) * g * g)
                p.update(p - lr * g / (jt.sqrt(v) + eps))
        self.post_step()


class Adam(Optimizer):
    """ Adam Optimizer.

    Example::

        optimizer = nn.Adam(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
        optimizer.step(loss)
    """
    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
        super().__init__(params, lr)
        self.eps = eps
        self.betas = betas
        self.weight_decay = weight_decay
        # assert weight_decay==0, "weight_decay is not supported yet"

        # initialize required arguments for each param_group
        for pg in self.param_groups:
            values = pg["values"] = []
            m = pg["m"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
                m.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        m = group["m"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
            m.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        jt.flags.node_order = 1
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            b0, b1 = pg.get("betas", self.betas)
            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                if p.is_stop_grad(): continue
                g = p * weight_decay + g
                m.update(b0 * m + (1-b0) * g)
                v.update(b1 * v + (1-b1) * g * g)
                step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
                p.update(p - m * step_size / (jt.sqrt(v) + eps))
        self.post_step()
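
# Note on weight decay (descriptive only): Adam above folds the decay into the
# gradient before the moment updates,
#       g <- g + weight_decay * p,
# i.e. a classic L2 penalty, while AdamW below applies decoupled decay directly
# to the parameters,
#       p <- p * (1 - lr * weight_decay),
# and then performs the bias-corrected Adam update
#       m <- b0*m + (1-b0)*g,   v <- b1*v + (1-b1)*g*g,
#       p <- p - lr/(1-b0^n) * m / (sqrt(v)/sqrt(1-b1^n) + eps).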

class AdamW(Optimizer):
    """ AdamW Optimizer.

    Example::

        optimizer = nn.AdamW(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
        optimizer.step(loss)
    """
    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
        super().__init__(params, lr)
        self.eps = eps
        self.betas = betas
        self.weight_decay = weight_decay
        # assert weight_decay==0, "weight_decay is not supported yet"

        # initialize required arguments for each param_group
        for pg in self.param_groups:
            values = pg["values"] = []
            m = pg["m"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
                m.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        m = group["m"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
            m.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            b0, b1 = pg.get("betas", self.betas)
            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                if p.is_stop_grad(): continue
                p.update(p * (1 - lr * weight_decay))
                bias_correction1 = 1 - b0 ** n
                bias_correction2 = 1 - b1 ** n
                m.update(b0 * m + (1-b0) * g)      # exp_avg
                v.update(b1 * v + (1-b1) * g * g)  # exp_avg_sq
                denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps
                step_size = lr / bias_correction1
                p.update(p - step_size * m / denom)
        self.post_step()

class Adan(Optimizer):
    """ Adan Optimizer.

    Adan was proposed in
    `Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models`,
    arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677

    Adan is an efficient optimizer for most DNN frameworks:

    - About half the computational cost of SOTA optimizers
    - Robust to training settings and batch size
    - Easy to plug in and play

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for the first-,
            difference- and second-order moments. (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip the global grad norm (default: 0.0, no clip)
    """
    def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-8, weight_decay=0.0, max_grad_norm=0.0):
        super().__init__(params, lr)
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.max_grad_norm = max_grad_norm

        for pg in self.param_groups:
            pg["m"] = []
            pg["v"] = []
            pg["d"] = []
            pg["pre_grad"] = []
            for p in pg["params"]:
                pg["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        group["m"] = []
        group["v"] = []
        group["d"] = []
        group["pre_grad"] = []
        for p in group["params"]:
            group["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            betas = pg.get("betas", self.betas)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            max_grad_norm = pg.get("max_grad_norm", self.max_grad_norm)
            if max_grad_norm > 0:
                self.clip_grad_norm(max_grad_norm)
            beta1, beta2, beta3 = betas
            bias_correction1 = 1 - beta1 ** n
            bias_correction2 = 1 - beta2 ** n
            bias_correction3_sqrt = jt.sqrt(1 - beta3 ** n)
            step_size_diff = lr * beta2 * bias_correction3_sqrt / bias_correction2
            step_size = lr * bias_correction3_sqrt / bias_correction1
            eps_bias_sqrt = eps * bias_correction3_sqrt

            for p, g, m, v, d, pre_g in zip(pg["params"], pg["grads"], pg["m"], pg["v"], pg["d"], pg["pre_grad"]):
                if p.is_stop_grad(): continue
                if self.n_step > 0:
                    pre_g.update(g - pre_g)  # update pre_g as grad_diff
                m.update(beta1 * m + (1 - beta1) * g)
                d.update(beta2 * d + (1 - beta2) * pre_g)  # use pre_g as grad_diff
                pre_g.update(jt.multiply(pre_g, beta2) + g)  # update pre_g as update (g + beta2 * grad_diff)
                v.update(beta3 * v + (1 - beta3) * pre_g * pre_g)  # use pre_g as update
                p.update(p - (step_size * m + step_size_diff * d) / (jt.sqrt(v) + eps_bias_sqrt))
                p.update(p / (1 + lr * weight_decay))
                pre_g.update(g)  # update pre_g for the next iteration
        self.post_step()
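
# A usage sketch for Adan (comment only; `model` and `loss` are hypothetical).
# When max_grad_norm > 0, step() calls clip_grad_norm(max_grad_norm) on the
# accumulated gradients before applying the update:
#
#   optimizer = Adan(model.parameters(), lr=1e-3, betas=(0.98, 0.92, 0.99),
#                    weight_decay=0.02, max_grad_norm=1.0)
#   optimizer.step(loss)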

class AdanBelief(Optimizer):
    """ AdanBelief Optimizer.

    AdanBelief is a variant of Adan whose second-order moment tracks the squared
    deviation of the update from the first moment m (the "belief") instead of the
    squared update itself.

    Adan was proposed in
    `Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models`,
    arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677

    Adan is an efficient optimizer for most DNN frameworks:

    - About half the computational cost of SOTA optimizers
    - Robust to training settings and batch size
    - Easy to plug in and play

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for the first-,
            difference- and second-order moments. (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip the global grad norm (default: 0.0, no clip)
    """
    def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-8, weight_decay=0.0, max_grad_norm=0.0):
        super().__init__(params, lr)
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.max_grad_norm = max_grad_norm

        for pg in self.param_groups:
            pg["m"] = []
            pg["v"] = []
            pg["d"] = []
            pg["pre_grad"] = []
            for p in pg["params"]:
                pg["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
                pg["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        group["m"] = []
        group["v"] = []
        group["d"] = []
        group["pre_grad"] = []
        for p in group["params"]:
            group["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
            group["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            betas = pg.get("betas", self.betas)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            max_grad_norm = pg.get("max_grad_norm", self.max_grad_norm)
            if max_grad_norm > 0:
                self.clip_grad_norm(max_grad_norm)
            beta1, beta2, beta3 = betas
            bias_correction1 = 1 - beta1 ** n
            bias_correction2 = 1 - beta2 ** n
            bias_correction3_sqrt = jt.sqrt(1 - beta3 ** n)
            step_size_diff = lr * beta2 * bias_correction3_sqrt / bias_correction2
            step_size = lr * bias_correction3_sqrt / bias_correction1
            eps_bias_sqrt = eps * bias_correction3_sqrt

            for p, g, m, v, d, pre_g in zip(pg["params"], pg["grads"], pg["m"], pg["v"], pg["d"], pg["pre_grad"]):
                if p.is_stop_grad(): continue
                if self.n_step > 0:
                    pre_g.update(g - pre_g)  # update pre_g as grad_diff
                m.update(beta1 * m + (1 - beta1) * g)
                d.update(beta2 * d + (1 - beta2) * pre_g)  # use pre_g as grad_diff
                pre_g.update(jt.multiply(pre_g, beta2) + g)  # update pre_g as update (g + beta2 * grad_diff)
                v.update(beta3 * v + (1 - beta3) * (pre_g - m) * (pre_g - m))  # track the squared belief (pre_g - m)
                p.update(p - (step_size * m + step_size_diff * d) / (jt.sqrt(v) + eps_bias_sqrt))
                p.update(p / (1 + lr * weight_decay))
                # AdanBelief best result 0.7358 (300 epoch basic)
                pre_g.update(g)  # update pre_g for the next iteration
        self.post_step()


class LRScheduler:
    def __init__(self, optimizer, last_epoch=-1):
        assert isinstance(optimizer, Optimizer)
        self.optimizer = optimizer

        if last_epoch == -1:
            for gp in optimizer.param_groups:
                gp.setdefault('initial_lr', gp.get('lr', optimizer.lr))
        else:
            for gp in optimizer.param_groups:
                assert 'initial_lr' in gp

        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.last_epoch = last_epoch
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step()

    def get_lr(self):
        raise NotImplementedError

    def get_last_lr(self):
        return self._last_lr

    def step(self, epoch=None):
        self._step_count += 1
        if epoch is None:
            self.last_epoch += 1
            values = self.get_lr()
        else:
            self.last_epoch = epoch
            values = self.get_lr()

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
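
# A minimal sketch of a custom scheduler (comment only): subclasses implement
# get_lr(), which returns one learning rate per param group; LRScheduler.step()
# then writes those values back into param_group['lr']. `MyStepLR`, `step_size`
# and `gamma` are hypothetical names:
#
#   class MyStepLR(LRScheduler):
#       def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
#           self.step_size = step_size
#           self.gamma = gamma
#           super().__init__(optimizer, last_epoch)
#
#       def get_lr(self):
#           return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
#                   for base_lr in self.base_lrs]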

class LambdaLR(LRScheduler):
    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        if not isinstance(lr_lambda, (list, tuple)):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        super(LambdaLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * lmbda(self.last_epoch) for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
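
# A usage sketch for LambdaLR (comment only; `model` is hypothetical). The
# lambda receives last_epoch and scales the corresponding initial_lr, so the
# schedule below multiplies the learning rate by 0.95 per call to
# scheduler.step():
#
#   optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
#   scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
#   for epoch in range(100):
#       ...  # training loop: optimizer.step(loss) per batch
#       scheduler.step()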