diff --git a/utils/helper.py b/utils/helper.py new file mode 100644 index 0000000..0fe388a --- /dev/null +++ b/utils/helper.py @@ -0,0 +1,164 @@ + +import os + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' +import jittor as jt +from jittor import transform +from jittor.optim import Adam, AdamW, SGD, Adan +from PIL import Image +from datetime import datetime +from natsort import natsorted + +class EarlyStop(object): + """早停 + 1. 当模型的损失长时间不下降时,停止训练 + 2. 当模型的损失长时间增大时,也提前停止训练 + """ + def __init__(self, patience=7, delta=0.0001, patience_up=20): + self.patience = patience + self.delta = delta + self.counter = 0 + self.counter_up = 0 + self.last_loss = None + self.early_stop = False + self.patience_up = patience_up + + def __call__(self, loss): + """当输入的loss多次不下降或者上升的时候,返回True,正常时返回False + + Args: + loss (float): 当前的损失值 + + Returns: + bool: 是否早停 + """ + if self.last_loss is None: + self.last_loss = loss + return False + + # loss下降明显低于delta,当前清零 + if loss < self.last_loss - self.delta: + self.counter = 0 + self.counter_up = 0 + self.last_loss = loss + + # loss上升明显高于delta,counter_up开始计数 + elif loss > self.last_loss + self.delta: + self.counter_up += 1 + if self.counter_up >= self.patience_up: + self.early_stop = True + return True + + # loss上升和下降均小于delta,在区间震荡,counter开始计数 + else: + self.counter += 1 + if self.counter >= self.patience: + self.early_stop = True + return True + + return False + + +def accuracy(model, dataloader, zeroshot_weights): + """计算模型的准确率""" + model.eval() + corrct = 0 + total_count = 0 + with jt.no_grad(): + for i, batch in enumerate(dataloader): + images, targets, texts = batch + total_count += len(images) + image_features = model.encode_image(images) + image_features = image_features / image_features.norm(dim=1, keepdim=True) + logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1) + preds = jt.argmax(logits, dim=1)[0] + corrct += jt.equal(preds, targets).sum().item() + return corrct / total_count + +def get_current_date(end_time='day'): + # 获取当前日期时间对象 + current_date = datetime.now() + # 格式化日期为月日时分格式 + if end_time == 'day': + formatted_date = current_date.strftime("%m-%d") + elif end_time == 'minute': + formatted_date = current_date.strftime("%m-%d_%H:%M") + return formatted_date + +def get_save_path(given_path, optimizer): + """获取tensorboard日志/模型保存路径""" + # 文件保存路径如下: + # given_path/date/optimizer/version_x + + path = os.path.join(given_path, get_current_date(end_time='day')) + os.makedirs(path, exist_ok=True) + + try: + last_version = int(natsorted(os.listdir(path))[-1].split('_')[-1]) + current_path = os.path.join(path, f'version_{last_version + 1}') + os.makedirs(current_path, exist_ok=True) + except IndexError: + current_path = os.path.join(path, 'version_0') + os.makedirs(current_path, exist_ok=True) + return current_path + + + + + +def get_optimizer(args, model): + """根据输入参数获取优化器""" + if args.optimizer == 'Adam': + optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, + betas=args.betas, eps=args.eps) + + elif args.optimizer == 'AdamW': + optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, + betas=args.betas, eps=args.eps) + + elif args.optimizer == 'Adan': + if len(args.betas) == 2: + raise ValueError('Adan optimizer requires betas has the shape like (0.9,0.98, 0.99)') + optimizer = Adan(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, + betas=args.betas, eps=args.eps) + + elif args.optimizer == 'SGD': + optimizer = SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, + momentum=0.9) + + else: + raise ValueError('Unsupported optimizer, please check the optimizer name.') + + return optimizer + +def get_scheduler(optimizer, args): + """根据输入参数获取学习率调度器""" + pass + +def get_transform(args): + """根据输入参数获取数据预处理""" + if args.data_preprocess == 1: + transforms = transform.Compose([ + transform.Resize(224, mode=Image.BICUBIC), + transform.CenterCrop(224), lambda image: image.convert("RGB"), + transform.ImageNormalize(mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)) + ]) + return transforms + elif args.data_preprocess == 2: + transforms = transform.Compose([ + transform.Resize(224, mode=Image.BICUBIC), + transform.CenterCrop(224), lambda image: image.convert("RGB"), + transform.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.4, hue=0.1), + transform.RandomRotation(10), + transform.RandomHorizontalFlip(), + transform.ImageNormalize(mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711))]) + + +def compute_loss(logits_image, logits_text): + """计算损失函数,用来建立文本与图像的语义关系,实现语义对其""" + ground_truth = jt.arange(len(logits_image), dtype=jt.int32) + loss = (jt.nn.cross_entropy_loss(logits_image, ground_truth) +\ + jt.nn.cross_entropy_loss(logits_text, ground_truth)) / 2 + return loss \ No newline at end of file