| @@ -0,0 +1,164 @@ | |||
| import os | |||
| os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |||
| import jittor as jt | |||
| from jittor import transform | |||
| from jittor.optim import Adam, AdamW, SGD, Adan | |||
| from PIL import Image | |||
| from datetime import datetime | |||
| from natsort import natsorted | |||
| class EarlyStop(object): | |||
| """早停 | |||
| 1. 当模型的损失长时间不下降时,停止训练 | |||
| 2. 当模型的损失长时间增大时,也提前停止训练 | |||
| """ | |||
| def __init__(self, patience=7, delta=0.0001, patience_up=20): | |||
| self.patience = patience | |||
| self.delta = delta | |||
| self.counter = 0 | |||
| self.counter_up = 0 | |||
| self.last_loss = None | |||
| self.early_stop = False | |||
| self.patience_up = patience_up | |||
| def __call__(self, loss): | |||
| """当输入的loss多次不下降或者上升的时候,返回True,正常时返回False | |||
| Args: | |||
| loss (float): 当前的损失值 | |||
| Returns: | |||
| bool: 是否早停 | |||
| """ | |||
| if self.last_loss is None: | |||
| self.last_loss = loss | |||
| return False | |||
| # loss下降明显低于delta,当前清零 | |||
| if loss < self.last_loss - self.delta: | |||
| self.counter = 0 | |||
| self.counter_up = 0 | |||
| self.last_loss = loss | |||
| # loss上升明显高于delta,counter_up开始计数 | |||
| elif loss > self.last_loss + self.delta: | |||
| self.counter_up += 1 | |||
| if self.counter_up >= self.patience_up: | |||
| self.early_stop = True | |||
| return True | |||
| # loss上升和下降均小于delta,在区间震荡,counter开始计数 | |||
| else: | |||
| self.counter += 1 | |||
| if self.counter >= self.patience: | |||
| self.early_stop = True | |||
| return True | |||
| return False | |||
| def accuracy(model, dataloader, zeroshot_weights): | |||
| """计算模型的准确率""" | |||
| model.eval() | |||
| corrct = 0 | |||
| total_count = 0 | |||
| with jt.no_grad(): | |||
| for i, batch in enumerate(dataloader): | |||
| images, targets, texts = batch | |||
| total_count += len(images) | |||
| image_features = model.encode_image(images) | |||
| image_features = image_features / image_features.norm(dim=1, keepdim=True) | |||
| logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1) | |||
| preds = jt.argmax(logits, dim=1)[0] | |||
| corrct += jt.equal(preds, targets).sum().item() | |||
| return corrct / total_count | |||
| def get_current_date(end_time='day'): | |||
| # 获取当前日期时间对象 | |||
| current_date = datetime.now() | |||
| # 格式化日期为月日时分格式 | |||
| if end_time == 'day': | |||
| formatted_date = current_date.strftime("%m-%d") | |||
| elif end_time == 'minute': | |||
| formatted_date = current_date.strftime("%m-%d_%H:%M") | |||
| return formatted_date | |||
| def get_save_path(given_path, optimizer): | |||
| """获取tensorboard日志/模型保存路径""" | |||
| # 文件保存路径如下: | |||
| # given_path/date/optimizer/version_x | |||
| path = os.path.join(given_path, get_current_date(end_time='day')) | |||
| os.makedirs(path, exist_ok=True) | |||
| try: | |||
| last_version = int(natsorted(os.listdir(path))[-1].split('_')[-1]) | |||
| current_path = os.path.join(path, f'version_{last_version + 1}') | |||
| os.makedirs(current_path, exist_ok=True) | |||
| except IndexError: | |||
| current_path = os.path.join(path, 'version_0') | |||
| os.makedirs(current_path, exist_ok=True) | |||
| return current_path | |||
| def get_optimizer(args, model): | |||
| """根据输入参数获取优化器""" | |||
| if args.optimizer == 'Adam': | |||
| optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, | |||
| betas=args.betas, eps=args.eps) | |||
| elif args.optimizer == 'AdamW': | |||
| optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, | |||
| betas=args.betas, eps=args.eps) | |||
| elif args.optimizer == 'Adan': | |||
| if len(args.betas) == 2: | |||
| raise ValueError('Adan optimizer requires betas has the shape like (0.9,0.98, 0.99)') | |||
| optimizer = Adan(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, | |||
| betas=args.betas, eps=args.eps) | |||
| elif args.optimizer == 'SGD': | |||
| optimizer = SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, | |||
| momentum=0.9) | |||
| else: | |||
| raise ValueError('Unsupported optimizer, please check the optimizer name.') | |||
| return optimizer | |||
| def get_scheduler(optimizer, args): | |||
| """根据输入参数获取学习率调度器""" | |||
| pass | |||
| def get_transform(args): | |||
| """根据输入参数获取数据预处理""" | |||
| if args.data_preprocess == 1: | |||
| transforms = transform.Compose([ | |||
| transform.Resize(224, mode=Image.BICUBIC), | |||
| transform.CenterCrop(224), lambda image: image.convert("RGB"), | |||
| transform.ImageNormalize(mean=(0.48145466, 0.4578275, 0.40821073), | |||
| std=(0.26862954, 0.26130258, 0.27577711)) | |||
| ]) | |||
| return transforms | |||
| elif args.data_preprocess == 2: | |||
| transforms = transform.Compose([ | |||
| transform.Resize(224, mode=Image.BICUBIC), | |||
| transform.CenterCrop(224), lambda image: image.convert("RGB"), | |||
| transform.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.4, hue=0.1), | |||
| transform.RandomRotation(10), | |||
| transform.RandomHorizontalFlip(), | |||
| transform.ImageNormalize(mean=(0.48145466, 0.4578275, 0.40821073), | |||
| std=(0.26862954, 0.26130258, 0.27577711))]) | |||
| def compute_loss(logits_image, logits_text): | |||
| """计算损失函数,用来建立文本与图像的语义关系,实现语义对其""" | |||
| ground_truth = jt.arange(len(logits_image), dtype=jt.int32) | |||
| loss = (jt.nn.cross_entropy_loss(logits_image, ground_truth) +\ | |||
| jt.nn.cross_entropy_loss(logits_text, ground_truth)) / 2 | |||
| return loss | |||