From f4f3192acce0d64980abb63108c037e6bab8ee3d Mon Sep 17 00:00:00 2001
From: BIT2024
Date: Wed, 21 Aug 2024 11:05:11 +0800
Subject: [PATCH] ADD file via upload

---
 optim.py | 729 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 729 insertions(+)
 create mode 100644 optim.py

diff --git a/optim.py b/optim.py
new file mode 100644
index 0000000..11620d4
--- /dev/null
+++ b/optim.py
@@ -0,0 +1,729 @@
+# ***************************************************************
+# Copyright (c) 2023 Jittor. All Rights Reserved.
+# Maintainers:
+#     Guowei Yang <471184555@qq.com>
+#     Guoye Yang <498731903@qq.com>
+#     Wenyang Zhou <576825820@qq.com>
+#     Meng-Hao Guo
+#     Dun Liang .
+#
+#
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import jittor as jt
+import numpy as np
+
+class Optimizer(object):
+    """ Base class of optimizers.
+
+    Example::
+
+        optimizer = nn.SGD(model.parameters(), lr)
+        optimizer.step(loss)
+    """
+    def __init__(self, params, lr, param_sync_iter=10000):
+        self.param_groups = []
+        self.lr = lr
+        self.param_sync_iter = param_sync_iter
+
+        assert len(params) > 0, "Length of parameters should not be zero"
+        if not isinstance(params[0], dict):
+            params = [{'params': params}]
+        for pg in params:
+            assert isinstance(pg, dict)
+            self.param_groups.append(pg)
+        self.n_step = 0
+        # __zero_grad is a flag for quickly determining whether the grads
+        # are zero, so the 0+x accumulation can be omitted
+        self.__zero_grad = True
+        self._grad_map = {}
+
+    def add_param_group(self, group):
+        self.param_groups.append(group)
+
+    def clip_grad_norm(self, max_norm:float, norm_type:int=2):
+        r"""Clips gradient norm of this optimizer.
+        The norm is computed over all gradients together.
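+        Concretely, with ``p = norm_type``, every gradient is scaled in
+        place by ``min(max_norm / (total_norm + 1e-6), 1)``, where
+        ``total_norm`` is the p-norm of all gradients concatenated into a
+        single vector; clipping is therefore a no-op when the total norm
+        is already below ``max_norm``.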
+
+        Args:
+            max_norm (float or int): max norm of the gradients
+            norm_type (int): 1-norm or 2-norm
+
+        Example::
+
+            a = jt.ones(2)
+            opt = jt.optim.SGD([a], 0.1)
+
+            loss = a*a
+            opt.zero_grad()
+            opt.backward(loss)
+
+            print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83
+            opt.clip_grad_norm(0.01, 2)
+            print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01
+
+            opt.step()
+        """
+        if self.__zero_grad: return
+        grads = []
+        for pg in self.param_groups:
+            for p, g in zip(pg["params"], pg["grads"]):
+                if p.is_stop_grad(): continue
+                grads.append(g.flatten())
+        if len(grads) == 0: return
+        total_norm = jt.norm(jt.concat(grads), norm_type)
+        clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0)
+        for pg in self.param_groups:
+            for p, g in zip(pg["params"], pg["grads"]):
+                if p.is_stop_grad(): continue
+                g.update(g*clip_coef)
+
+    @property
+    def defaults(self):
+        exclude = set(("defaults", "pre_step", "step"))
+        return { k:v for k, v in self.__dict__.items()
+            if k[0] != '_' and k not in exclude and not callable(v) }
+
+    def state_dict(self):
+        state = {"defaults": self.defaults}
+        return state
+
+    def load_state_dict(self, state):
+
+        def dfs(x):
+            if isinstance(x, list):
+                for i in range(len(x)):
+                    x[i] = dfs(x[i])
+            elif isinstance(x, dict):
+                for k in x:
+                    x[k] = dfs(x[k])
+            elif isinstance(x, np.ndarray):
+                return jt.array(x).stop_grad()
+            elif isinstance(x, jt.Var):
+                return x.stop_grad()
+            return x
+
+        exclude = set(("param_groups", "params"))
+        for k, v in state["defaults"].items():
+            if k not in exclude:
+                setattr(self, k, dfs(v))
+        param_groups = dfs(state["defaults"].get('param_groups', None))
+        if param_groups is not None:
+            exclude = set(("params",))
+            for i in range(len(param_groups)):
+                for k, v in param_groups[i].items():
+                    if k not in exclude:
+                        self.param_groups[i][k] = v
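+
+    # A minimal round-trip sketch for the two methods above, assuming
+    # jt.save/jt.load (Jittor's pickle-based helpers) and a hypothetical
+    # checkpoint path:
+    #
+    #     opt = SGD(model.parameters(), lr=0.1)
+    #     jt.save(opt.state_dict(), "opt.pkl")
+    #     opt.load_state_dict(jt.load("opt.pkl"))
+    #
+    # load_state_dict restores hyper-parameters and every param_groups
+    # entry except the parameters themselves ("params" is excluded).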
+
+    def zero_grad(self):
+        self.__zero_grad = True
+
+    def backward(self, loss, retain_graph=False):
+        '''
+        optimizer.backward(loss) is used to accumulate gradients over
+        multiple steps. It can be used as follows:
+
+        Original source code ::
+
+            n_iter = 10000
+            batch_size = 100
+            ...
+            for i in range(n_iter):
+                ...
+                loss = calc_loss()
+                optimizer.step(loss)
+
+        Accumulation version ::
+
+            n_iter = 10000
+            batch_size = 100
+            accumulation_steps = 10
+            n_iter *= accumulation_steps
+            batch_size //= accumulation_steps
+            ...
+            for i in range(n_iter):
+                ...
+                loss = calc_loss()
+                # if loss is a mean across the batch, divide it by accumulation_steps
+                optimizer.backward(loss / accumulation_steps)
+                if (i+1) % accumulation_steps == 0:
+                    optimizer.step()
+        '''
+        # clean prev grads
+        params = []
+        params_has_grad = []
+        for pg in self.param_groups:
+            for p in pg['params']:
+                params.append(p)
+                if not p.is_stop_grad():
+                    params_has_grad.append(p)
+
+        # sync prev params
+        jt.sync(params_has_grad)
+
+        # get gradient
+        grads = jt.grad(loss, params_has_grad, retain_graph)
+
+        # sync grads and model if in mpi
+        if jt.in_mpi:
+            dep = []
+            def add_dep(v):
+                nonlocal dep
+                v._add_dependency(dep)
+                dep = [v]
+
+            for g in grads:
+                g.assign(g.mpi_all_reduce("mean"))
+                add_dep(g._input(0))
+            if self.n_step % self.param_sync_iter == 0:
+                for p in params:
+                    p.assign(p.mpi_broadcast())
+                    add_dep(p)
+        self.n_step += 1
+
+        # set up grads in param_groups
+        pid = 0
+        for pg in self.param_groups:
+            if "grads" not in pg:
+                pg["grads"] = [ jt.zeros_like(p).stop_grad().stop_fuse() for p in pg['params'] ]
+            pg_grads = pg["grads"]
+            for i, p in enumerate(pg['params']):
+                if not p.is_stop_grad():
+                    # accumulate grad and stop grad of grad
+                    g = grads[pid].stop_grad()
+                    if not self.__zero_grad:
+                        g = g + pg_grads[i]
+                    pg_grads[i].update(g)
+                    pid += 1
+        self.__zero_grad = False
+
+    def pre_step(self, loss, retain_graph=False):
+        """ Things that should be done before each step, such as computing
+        gradients, MPI sync, and so on.
+
+        Example::
+
+            class MyOptimizer(Optimizer):
+                def step(self, loss):
+                    self.pre_step(loss)
+                    ...
+                    self.post_step()
+        """
+        if loss is not None:
+            self.backward(loss, retain_graph)
+        jt.flags.node_order = 1
+
+    def post_step(self):
+        """ Things that should be done after each step, such as zeroing
+        grads, and so on.
+
+        Example::
+
+            class MyOptimizer(Optimizer):
+                def step(self, loss):
+                    self.pre_step(loss)
+                    ...
+                    self.post_step()
+        """
+        jt.flags.node_order = 0
+        self.zero_grad()
+
+    def step(self, loss=None, retain_graph=False):
+        self.pre_step(loss, retain_graph)
+        for pg in self.param_groups:
+            lr = pg.get("lr", self.lr)
+            for p, g in zip(pg["params"], pg["grads"]):
+                if p.is_stop_grad(): continue
+                p.update(p - g * lr)
+        self.post_step()
+
+    def _build_grad_map(self):
+        _grad_map = {}
+        for pg in self.param_groups:
+            for p, g in zip(pg["params"], pg["grads"]):
+                _grad_map[id(p)] = g
+        self._grad_map = _grad_map
+
+    def find_grad(self, v:jt.Var) -> jt.Var:
+        if id(v) not in self._grad_map:
+            self._build_grad_map()
+            if id(v) not in self._grad_map:
+                raise RuntimeError("This variable is not managed by this optimizer")
+        return self._grad_map[id(v)]
+
+def opt_grad(v:jt.Var, opt:Optimizer):
+    ''' Get the grad of a certain variable in the optimizer. Example::
+
+        model = Model()
+        optimizer = SGD(model.parameters())
+        ...
+        optimizer.backward(loss)
+
+        for p in model.parameters():
+            grad = p.opt_grad(optimizer)
+    '''
+    return opt.find_grad(v)
+
+jt.Var.opt_grad = opt_grad
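+
+# A minimal sketch of a custom optimizer built on pre_step/post_step,
+# following the same pattern as the classes below (hypothetical example,
+# not part of this patch; assumes jt.sign, Jittor's elementwise sign op):
+#
+#     class SignSGD(Optimizer):
+#         def step(self, loss=None, retain_graph=False):
+#             self.pre_step(loss, retain_graph)
+#             for pg in self.param_groups:
+#                 lr = pg.get("lr", self.lr)
+#                 for p, g in zip(pg["params"], pg["grads"]):
+#                     if p.is_stop_grad(): continue
+#                     p.update(p - jt.sign(g) * lr)
+#             self.post_step()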
+
+class SGD(Optimizer):
+    """ SGD Optimizer.
+
+    Example::
+
+        optimizer = nn.SGD(model.parameters(), lr, momentum=0.9)
+        optimizer.step(loss)
+    """
+    def __init__(self, params, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False):
+        super().__init__(params, lr)
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.dampening = dampening
+        self.nesterov = nesterov
+
+        # initialize required arguments
+        for pg in self.param_groups:
+            values = pg["values"] = []
+            for p in pg["params"]:
+                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
+
+    def add_param_group(self, group):
+        values = group["values"] = []
+        for p in group["params"]:
+            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
+        self.param_groups.append(group)
+
+    def step(self, loss=None, retain_graph=False):
+        self.pre_step(loss, retain_graph)
+        jt.flags.node_order = 1
+        for pg in self.param_groups:
+            # get arguments from each param_groups
+            lr = pg.get("lr", self.lr)
+            momentum = pg.get("momentum", self.momentum)
+            weight_decay = pg.get("weight_decay", self.weight_decay)
+            dampening = pg.get("dampening", self.dampening)
+            nesterov = pg.get("nesterov", self.nesterov)
+
+            # optimize main body
+            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
+                if p.is_stop_grad(): continue
+                dp = p * weight_decay + g
+                v.update(momentum * v + dp * (1 - dampening))
+                if nesterov:
+                    p.update(p - (dp + momentum * v) * lr)
+                else:
+                    p.update(p - v * lr)
+        self.post_step()
+
+class RMSprop(Optimizer):
+    """ RMSprop Optimizer.
+
+    Args:
+        params(list): parameters of model.
+        lr(float): learning rate.
+        eps(float): term added to the denominator to avoid division by zero, default 1e-8.
+        alpha(float): smoothing constant, default 0.99.
+
+    Example::
+
+        optimizer = nn.RMSprop(model.parameters(), lr)
+        optimizer.step(loss)
+    """
+    def __init__(self, params, lr=1e-2, eps=1e-8, alpha=0.99):
+        super().__init__(params, lr)
+        self.eps = eps
+        self.alpha = alpha
+
+        # initialize required arguments for each param_groups
+        for pg in self.param_groups:
+            values = pg["values"] = []
+            for p in pg["params"]:
+                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
+
+    def add_param_group(self, group):
+        values = group["values"] = []
+        for p in group["params"]:
+            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
+        self.param_groups.append(group)
+
+    def step(self, loss=None, retain_graph=False):
+        self.pre_step(loss, retain_graph)
+        for pg in self.param_groups:
+            # get arguments from each param_groups
+            lr = pg.get("lr", self.lr)
+            eps = pg.get("eps", self.eps)
+            alpha = pg.get("alpha", self.alpha)
+            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
+                if p.is_stop_grad(): continue
+                v.update(alpha * v + (1-alpha) * g * g)
+                p.update(p - lr * g / (jt.sqrt(v) + eps))
+        self.post_step()
+
+class Adam(Optimizer):
+    """ Adam Optimizer.
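+
+    The step below implements the standard Adam rule: with bias
+    corrections bc1 = 1 - b0**t and bc2 = 1 - b1**t,
+
+        m = b0 * m + (1-b0) * g
+        v = b1 * v + (1-b1) * g * g
+        p = p - lr * sqrt(bc2) / bc1 * m / (sqrt(v) + eps)
+
+    where g already includes the L2 term weight_decay * p, and eps is
+    added to the uncorrected sqrt(v) (both corrections are folded into
+    the step size).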
+
+    Example::
+
+        optimizer = nn.Adam(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
+        optimizer.step(loss)
+    """
+    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
+        super().__init__(params, lr)
+        self.eps = eps
+        self.betas = betas
+        self.weight_decay = weight_decay
+
+        # initialize required arguments for each param_groups
+        for pg in self.param_groups:
+            values = pg["values"] = []
+            m = pg["m"] = []
+            for p in pg["params"]:
+                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
+                m.append(jt.zeros(p.shape, p.dtype).stop_grad())
+
+    def add_param_group(self, group):
+        values = group["values"] = []
+        m = group["m"] = []
+        for p in group["params"]:
+            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
+            m.append(jt.zeros(p.shape, p.dtype).stop_grad())
+        self.param_groups.append(group)
+
+    def step(self, loss=None, retain_graph=False):
+        self.pre_step(loss, retain_graph)
+        n = float(self.n_step)
+        jt.flags.node_order = 1
+        for pg in self.param_groups:
+            # get arguments from each param_groups
+            lr = pg.get("lr", self.lr)
+            eps = pg.get("eps", self.eps)
+            weight_decay = pg.get("weight_decay", self.weight_decay)
+            b0, b1 = pg.get("betas", self.betas)
+            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
+                if p.is_stop_grad(): continue
+                g = p * weight_decay + g
+                m.update(b0 * m + (1-b0) * g)
+                v.update(b1 * v + (1-b1) * g * g)
+                step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
+                p.update(p - m * step_size / (jt.sqrt(v) + eps))
+        self.post_step()
+
+class AdamW(Optimizer):
+    """ AdamW Optimizer.
+
+    Unlike Adam above, the weight decay is decoupled: it is applied
+    directly to the parameter (p *= 1 - lr * weight_decay) instead of
+    being added to the gradient.
+
+    Example::
+
+        optimizer = nn.AdamW(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
+        optimizer.step(loss)
+    """
+    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
+        super().__init__(params, lr)
+        self.eps = eps
+        self.betas = betas
+        self.weight_decay = weight_decay
+
+        # initialize required arguments for each param_groups
+        for pg in self.param_groups:
+            values = pg["values"] = []
+            m = pg["m"] = []
+            for p in pg["params"]:
+                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
+                m.append(jt.zeros(p.shape, p.dtype).stop_grad())
+
+    def add_param_group(self, group):
+        values = group["values"] = []
+        m = group["m"] = []
+        for p in group["params"]:
+            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
+            m.append(jt.zeros(p.shape, p.dtype).stop_grad())
+        self.param_groups.append(group)
+
+    def step(self, loss=None, retain_graph=False):
+        self.pre_step(loss, retain_graph)
+        n = float(self.n_step)
+        for pg in self.param_groups:
+            # get arguments from each param_groups
+            lr = pg.get("lr", self.lr)
+            eps = pg.get("eps", self.eps)
+            weight_decay = pg.get("weight_decay", self.weight_decay)
+            b0, b1 = pg.get("betas", self.betas)
+            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
+                if p.is_stop_grad(): continue
+                p.update(p * (1 - lr * weight_decay))
+                bias_correction1 = 1 - b0 ** n
+                bias_correction2 = 1 - b1 ** n
+                m.update(b0 * m + (1-b0) * g)  # exp_avg
+                v.update(b1 * v + (1-b1) * g * g)  # exp_avg_sq
+                denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps
+                step_size = lr / bias_correction1
+                p.update(p - step_size * m / denom)
+        self.post_step()
+
+class Adan(Optimizer):
+    """ Adan Optimizer.
+
+    Adan was proposed in
+    Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing
+    Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022.
+    https://arxiv.org/abs/2208.06677
+
+    Adan is an efficient optimizer for most DNN frameworks:
+    - About 2x less computation than SOTA optimizers
+    - Robust to training settings and batch size
+    - Easy to plug and play
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or
+            dicts defining parameter groups.
+        lr (float, optional): learning rate. (default: 1e-3)
+        betas (Tuple[float, float, float], optional): coefficients used for
+            the first-order, gradient-difference and second-order moments.
+            (default: (0.98, 0.92, 0.99))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability. (default: 1e-8)
+        weight_decay (float, optional): decoupled weight decay
+            (L2 penalty) (default: 0)
+        max_grad_norm (float, optional): value used to clip
+            global grad norm (default: 0.0, no clipping)
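+
+    Example::
+
+        optimizer = Adan(model.parameters(), lr=1e-3)
+        optimizer.step(loss)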
+    """
+    def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99),
+                 eps=1e-8, weight_decay=0.0, max_grad_norm=0.0):
+        super().__init__(params, lr)
+        self.betas = betas
+        self.eps = eps
+        self.weight_decay = weight_decay
+        self.max_grad_norm = max_grad_norm
+
+        for pg in self.param_groups:
+            pg["m"] = []
+            pg["v"] = []
+            pg["d"] = []
+            pg["pre_grad"] = []
+            for p in pg["params"]:
+                pg["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+                pg["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+                pg["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+                pg["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+
+    def add_param_group(self, group):
+        group["m"] = []
+        group["v"] = []
+        group["d"] = []
+        group["pre_grad"] = []
+        for p in group["params"]:
+            group["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+            group["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+            group["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+            group["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+        self.param_groups.append(group)
+
+    def step(self, loss=None, retain_graph=False):
+        self.pre_step(loss, retain_graph)
+        n = float(self.n_step)
+        for pg in self.param_groups:
+            lr = pg.get("lr", self.lr)
+            betas = pg.get("betas", self.betas)
+            eps = pg.get("eps", self.eps)
+            weight_decay = pg.get("weight_decay", self.weight_decay)
+            max_grad_norm = pg.get("max_grad_norm", self.max_grad_norm)
+            if max_grad_norm > 0: self.clip_grad_norm(max_grad_norm)
+            beta1, beta2, beta3 = betas
+
+            bias_correction1 = 1 - beta1 ** n
+            bias_correction2 = 1 - beta2 ** n
+            bias_correction3_sqrt = jt.sqrt(1 - beta3 ** n)
+
+            step_size_diff = lr * beta2 * bias_correction3_sqrt / bias_correction2
+            step_size = lr * bias_correction3_sqrt / bias_correction1
+            eps_bias_sqrt = eps * bias_correction3_sqrt
+
+            for p, g, m, v, d, pre_g in zip(pg["params"],
+                                            pg["grads"],
+                                            pg["m"],
+                                            pg["v"],
+                                            pg["d"],
+                                            pg["pre_grad"]):
+                if p.is_stop_grad(): continue
+
+                if self.n_step > 0:
+                    pre_g.update(g - pre_g)  # pre_g now holds grad_diff
+
+                m.update(beta1 * m + (1 - beta1) * g)
+                d.update(beta2 * d + (1 - beta2) * pre_g)  # EMA of grad_diff
+
+                pre_g.update(jt.multiply(pre_g, beta2) + g)  # pre_g now holds the update g + beta2 * grad_diff
+
+                v.update(beta3 * v + (1 - beta3) * pre_g * pre_g)  # second moment of the update
+
+                p.update(p - (step_size * m + step_size_diff * d) / (jt.sqrt(v) + eps_bias_sqrt))
+                p.update(p / (1 + lr * weight_decay))
+
+                pre_g.update(g)  # store the raw grad for the next iteration
+        self.post_step()
+
+class AdanBelief(Optimizer):
+    """ AdanBelief Optimizer.
+
+    A variant of Adan (see above) that replaces the second moment of the
+    update with an AdaBelief-style belief term: the EMA of the squared
+    deviation between the update and its first moment m.
+
+    Adan was proposed in
+    Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing
+    Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022.
+    https://arxiv.org/abs/2208.06677
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or
+            dicts defining parameter groups.
+        lr (float, optional): learning rate. (default: 1e-3)
+        betas (Tuple[float, float, float], optional): coefficients used for
+            the first-order, gradient-difference and second-order moments.
+            (default: (0.98, 0.92, 0.99))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability. (default: 1e-8)
+        weight_decay (float, optional): decoupled weight decay
+            (L2 penalty) (default: 0)
+        max_grad_norm (float, optional): value used to clip
+            global grad norm (default: 0.0, no clipping)
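+
+    Example::
+
+        optimizer = AdanBelief(model.parameters(), lr=1e-3)
+        optimizer.step(loss)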
+    """
+    def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99),
+                 eps=1e-8, weight_decay=0.0, max_grad_norm=0.0):
+        super().__init__(params, lr)
+        self.betas = betas
+        self.eps = eps
+        self.weight_decay = weight_decay
+        self.max_grad_norm = max_grad_norm
+
+        for pg in self.param_groups:
+            pg["m"] = []
+            pg["v"] = []
+            pg["d"] = []
+            pg["pre_grad"] = []
+            for p in pg["params"]:
+                pg["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+                pg["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+                pg["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+                pg["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+
+    def add_param_group(self, group):
+        group["m"] = []
+        group["v"] = []
+        group["d"] = []
+        group["pre_grad"] = []
+        for p in group["params"]:
+            group["m"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+            group["v"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+            group["d"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+            group["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad())
+        self.param_groups.append(group)
+
+    def step(self, loss=None, retain_graph=False):
+        self.pre_step(loss, retain_graph)
+        n = float(self.n_step)
+        for pg in self.param_groups:
+            lr = pg.get("lr", self.lr)
+            betas = pg.get("betas", self.betas)
+            eps = pg.get("eps", self.eps)
+            weight_decay = pg.get("weight_decay", self.weight_decay)
+            max_grad_norm = pg.get("max_grad_norm", self.max_grad_norm)
+            if max_grad_norm > 0: self.clip_grad_norm(max_grad_norm)
+            beta1, beta2, beta3 = betas
+
+            bias_correction1 = 1 - beta1 ** n
+            bias_correction2 = 1 - beta2 ** n
+            bias_correction3_sqrt = jt.sqrt(1 - beta3 ** n)
+
+            step_size_diff = lr * beta2 * bias_correction3_sqrt / bias_correction2
+            step_size = lr * bias_correction3_sqrt / bias_correction1
+            eps_bias_sqrt = eps * bias_correction3_sqrt
+
+            for p, g, m, v, d, pre_g in zip(pg["params"],
+                                            pg["grads"],
+                                            pg["m"],
+                                            pg["v"],
+                                            pg["d"],
+                                            pg["pre_grad"]):
+                if p.is_stop_grad(): continue
+
+                if self.n_step > 0:
+                    pre_g.update(g - pre_g)  # pre_g now holds grad_diff
+
+                m.update(beta1 * m + (1 - beta1) * g)
+                d.update(beta2 * d + (1 - beta2) * pre_g)  # EMA of grad_diff
+
+                pre_g.update(jt.multiply(pre_g, beta2) + g)  # pre_g now holds the update g + beta2 * grad_diff
+
+                v.update(beta3 * v + (1 - beta3) * (pre_g - m) * (pre_g - m))  # belief: squared deviation of the update from m
+
+                p.update(p - (step_size * m + step_size_diff * d) / (jt.sqrt(v) + eps_bias_sqrt))
+                p.update(p / (1 + lr * weight_decay))  # AdanBelief best result 0.7358 (300 epoch basic)
+                pre_g.update(g)  # store the raw grad for the next iteration
+        self.post_step()
+
+class LRScheduler:
+    def __init__(self, optimizer, last_epoch=-1):
+        assert isinstance(optimizer, Optimizer)
+        self.optimizer = optimizer
+
+        if last_epoch == -1:
+            for gp in optimizer.param_groups:
+                gp.setdefault('initial_lr', gp.get('lr', optimizer.lr))
+        else:
+            for gp in optimizer.param_groups:
+                assert 'initial_lr' in gp
+
+        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+        self.last_epoch = last_epoch
+        self.optimizer._step_count = 0
+        self._step_count = 0
+        self.step()
+
+    def get_lr(self):
+        raise NotImplementedError
+
+    def get_last_lr(self):
+        return self._last_lr
+
+    def step(self, epoch=None):
+        self._step_count += 1
+
+        if epoch is None:
+            self.last_epoch += 1
+        else:
+            self.last_epoch = epoch
+        values = self.get_lr()
+
+        for param_group, lr in zip(self.optimizer.param_groups, values):
+            param_group['lr'] = lr
+
+        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
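+
+# A minimal sketch of a concrete scheduler built on LRScheduler above,
+# showing the get_lr() contract (hypothetical example, not part of this
+# patch; note that attributes must be set before super().__init__, which
+# already calls self.step()):
+#
+#     class StepDecayLR(LRScheduler):
+#         def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
+#             self.step_size = step_size
+#             self.gamma = gamma
+#             super().__init__(optimizer, last_epoch)
+#
+#         def get_lr(self):
+#             return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
+#                     for base_lr in self.base_lrs]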
+
+class LambdaLR(LRScheduler):
+
+    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
+        if not isinstance(lr_lambda, (list, tuple)):
+            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
+        else:
+            if len(lr_lambda) != len(optimizer.param_groups):
+                raise ValueError("Expected {} lr_lambdas, but got {}".format(
+                    len(optimizer.param_groups), len(lr_lambda)))
+            self.lr_lambdas = list(lr_lambda)
+
+        super(LambdaLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        return [base_lr * lmbda(self.last_epoch)
+                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
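+
+# A minimal usage sketch for LambdaLR (model/loader/calc_loss are
+# hypothetical names), decaying each group's lr exponentially per epoch:
+#
+#     optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
+#     scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
+#     for epoch in range(epochs):
+#         for batch in loader:
+#             optimizer.step(calc_loss(batch))
+#         scheduler.step()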