From e2d58aa13c44bf3fd32f17a3a046406779040cfb Mon Sep 17 00:00:00 2001
From: yuanjunbin
Date: Thu, 5 Oct 2023 17:12:18 +0800
Subject: [PATCH] update

---
 push.sh                             |  3 ++
 push.sh.bak                         |  3 ++
 stock-forcast-models/utils/tools.py | 92 +++++++++++++++++++++++++++++++++++++
 3 files changed, 98 insertions(+)
 create mode 100644 push.sh
 create mode 100644 push.sh.bak
 create mode 100644 stock-forcast-models/utils/tools.py

diff --git a/push.sh b/push.sh
new file mode 100644
index 0000000..8b7a89a
--- /dev/null
+++ b/push.sh
@@ -0,0 +1,3 @@
+git add .
+git commit -a -m "update"
+git push
\ No newline at end of file
diff --git a/push.sh.bak b/push.sh.bak
new file mode 100644
index 0000000..6026404
--- /dev/null
+++ b/push.sh.bak
@@ -0,0 +1,3 @@
+git add.
+git commit -a -m "update"
+git push
\ No newline at end of file
diff --git a/stock-forcast-models/utils/tools.py b/stock-forcast-models/utils/tools.py
new file mode 100644
index 0000000..3930442
--- /dev/null
+++ b/stock-forcast-models/utils/tools.py
@@ -0,0 +1,92 @@
+import numpy as np
+import mindspore.numpy as mnp
+import mindspore
+from mindspore import Tensor, Parameter
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Decay the optimizer's learning rate according to args.lradj.
+
+    'type1' halves the initial rate each epoch; 'type2' follows a fixed
+    epoch -> lr table; any other value leaves the rate untouched.
+    """
+    if args.lradj == 'type1':
+        lr_adjust = {epoch: args.learning_rate * (0.5 ** (epoch - 1))}
+    elif args.lradj == 'type2':
+        lr_adjust = {2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
+                     10: 5e-7, 15: 1e-7, 20: 5e-8}
+    else:
+        lr_adjust = {}  # unknown schedule: keep the current rate
+    if epoch in lr_adjust:
+        lr = lr_adjust[epoch]
+        # MindSpore optimizers hold the rate in a Parameter; update it in
+        # place (optimizer.parameters is a tuple, and Parameter has no set_lr).
+        optimizer.learning_rate.set_data(Tensor(lr, mindspore.float32))
+        print('Updating learning rate to {}'.format(lr))
+
+class EarlyStopping:
+    """Stop training once validation loss stops improving for `patience` epochs."""
+
+    def __init__(self, patience=7, verbose=False, delta=0):
+        self.patience = patience    # allowed consecutive non-improving epochs
+        self.verbose = verbose      # print a message on every checkpoint save
+        self.counter = 0            # epochs since the last improvement
+        self.best_score = None      # best negated validation loss so far
+        self.early_stop = False     # flag polled by the training loop
+        self.val_loss_min = np.inf  # np.Inf alias was removed in NumPy 2.0
+        self.delta = delta          # minimum improvement that resets the counter
+
+    def __call__(self, val_loss, model, path):
+        score = -val_loss  # negate so that "higher is better"
+        if self.best_score is None:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model, path)
+        elif score < self.best_score + self.delta:
+            self.counter += 1
+            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model, path)
+            self.counter = 0
+
+    def save_checkpoint(self, val_loss, model, path):
+        """Save `model` to <path>/checkpoint.ckpt and record the new best loss."""
+        if self.verbose:
+            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
+        # A Cell has no save_checkpoint method; use the module-level API.
+        mindspore.save_checkpoint(model, path + '/' + 'checkpoint.ckpt')
+        self.val_loss_min = val_loss
+
+class dotdict(dict):
+    """dot.notation access to dictionary attributes"""
+    __getattr__ = dict.get  # missing keys read as None instead of raising
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+class StandardScaler():
+    """Per-feature zero-mean / unit-variance scaler for MindSpore tensors."""
+
+    def __init__(self):
+        self.mean = 0.
+        self.std = 1.
+
+    def fit(self, data):
+        # Column-wise statistics over axis 0 (samples).
+        self.mean = mnp.mean(data, 0)
+        self.std = mnp.std(data, 0)
+
+    def transform(self, data):
+        mean = Tensor(self.mean, mindspore.float32)
+        std = Tensor(self.std, mindspore.float32)
+        return (data - mean) / std
+
+    def inverse_transform(self, data):
+        mean = Tensor(self.mean, mindspore.float32)
+        std = Tensor(self.std, mindspore.float32)
+        # When predictions carry fewer features than were fitted, rescale
+        # with the last feature's stats (assumes target is the last column
+        # -- TODO confirm against the caller).
+        if data.shape[-1] != mean.shape[-1]:
+            mean = mean[-1:]
+            std = std[-1:]
+        return (data * std) + mean