|
|
@@ -0,0 +1,164 @@ |
|
|
|
|
|
|
|
import os |
|
|
|
|
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0' |
|
|
|
import jittor as jt |
|
|
|
from jittor import transform |
|
|
|
from jittor.optim import Adam, AdamW, SGD, Adan |
|
|
|
from PIL import Image |
|
|
|
from datetime import datetime |
|
|
|
from natsort import natsorted |
|
|
|
|
|
|
|
class EarlyStop(object):
    """Early-stopping monitor driven by a scalar loss.

    Training is stopped in two situations:

    1. The loss has stayed within ``delta`` of the reference loss for
       ``patience`` calls (plateau).
    2. The loss has been more than ``delta`` above the reference loss for
       ``patience_up`` calls (getting worse).

    The reference loss (``last_loss``) is only replaced when a call reports a
    loss more than ``delta`` below it; such a call also resets both counters.
    """

    def __init__(self, patience=7, delta=0.0001, patience_up=20):
        # Thresholds.
        self.patience = patience        # plateau calls tolerated
        self.patience_up = patience_up  # "worse" calls tolerated
        self.delta = delta              # minimum change considered significant

        # Running state.
        self.counter = 0        # plateau calls since the last improvement
        self.counter_up = 0     # "worse" calls since the last improvement
        self.last_loss = None   # reference loss; None until the first call
        self.early_stop = False  # set to True once a stop has been signalled

    def __call__(self, loss):
        """Record ``loss`` and report whether training should stop.

        Args:
            loss (float): The current loss value.

        Returns:
            bool: True when either patience budget is exhausted, otherwise
            False. The very first call only stores the loss and returns False.
        """
        # First observation: nothing to compare against yet.
        if self.last_loss is None:
            self.last_loss = loss
            return False

        # Significant improvement: reset both counters and move the reference.
        if loss < self.last_loss - self.delta:
            self.counter = 0
            self.counter_up = 0
            self.last_loss = loss
            return False

        if loss > self.last_loss + self.delta:
            # Significantly worse than the reference loss.
            self.counter_up += 1
            exhausted = self.counter_up >= self.patience_up
        else:
            # Within +/- delta of the reference loss: treat as a plateau.
            self.counter += 1
            exhausted = self.counter >= self.patience

        if exhausted:
            self.early_stop = True
            return True
        return False
|
|
|
|
|
|
|
|
|
|
|
def accuracy(model, dataloader, zeroshot_weights):
    """Compute the top-1 accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode (and is not switched back). For every
    batch the image features are L2-normalised along dim 1, multiplied by
    ``zeroshot_weights`` (scaled by 100), and the arg-max class is compared
    with the batch targets.

    Args:
        model: Object providing ``eval()`` and ``encode_image(images)``.
        dataloader: Iterable yielding ``(images, targets, texts)`` triples;
            ``texts`` is not used here.
        zeroshot_weights: Right-hand operand of the feature/weight matmul
            (presumably shaped ``(feature_dim, num_classes)`` — TODO confirm).

    Returns:
        float: ``correct / total``; ``0.0`` when the dataloader yields no
        samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total_count = 0

    with jt.no_grad():
        for images, targets, _texts in dataloader:
            total_count += len(images)

            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)

            # Element 0 of jt.argmax's result is kept as the predicted class
            # indices (Jittor returns an (index, value) pair).
            preds = jt.argmax(logits, dim=1)[0]
            correct += jt.equal(preds, targets).sum().item()

    # Guard against an empty dataloader.
    if total_count == 0:
        return 0.0
    return correct / total_count
|
|
|
|
|
|
|
def get_current_date(end_time='day'):
    """Return the current local date/time formatted as a short string.

    Args:
        end_time (str): Resolution of the result.
            ``'day'`` gives ``"MM-DD"``; ``'minute'`` gives ``"MM-DD_HH:MM"``.

    Returns:
        str: The formatted timestamp.

    Raises:
        ValueError: If ``end_time`` is neither ``'day'`` nor ``'minute'``.
    """
    if end_time == 'day':
        date_format = "%m-%d"
    elif end_time == 'minute':
        # NOTE: the ':' in this format is not a valid file-name character on
        # Windows; keep that in mind if the result is used in a path.
        date_format = "%m-%d_%H:%M"
    else:
        # Previously this fell through to an unbound local variable.
        raise ValueError(
            f"Unsupported end_time {end_time!r}; expected 'day' or 'minute'."
        )

    return datetime.now().strftime(date_format)
|
|
|
|
|
|
|
def get_save_path(given_path, optimizer):
    """Create and return a fresh directory for tensorboard logs / checkpoints.

    The directory layout is ``given_path/<MM-DD>/version_<n>``, where ``<n>``
    is one greater than the highest existing ``version_<n>`` entry in today's
    directory, or ``0`` when there is none. Entries that do not match
    ``version_<integer>`` are ignored, so stray files no longer break the
    version lookup.

    Args:
        given_path (str): Root directory under which runs are stored.
        optimizer: NOTE(review): accepted but not used; no optimizer-specific
            directory level is created — confirm whether one was intended.

    Returns:
        str: Path of the newly created ``version_<n>`` directory.
    """
    path = os.path.join(given_path, get_current_date(end_time='day'))
    os.makedirs(path, exist_ok=True)

    prefix = 'version_'
    existing_versions = [
        int(entry[len(prefix):])
        for entry in os.listdir(path)
        if entry.startswith(prefix) and entry[len(prefix):].isdecimal()
    ]
    next_version = max(existing_versions) + 1 if existing_versions else 0

    current_path = os.path.join(path, f'version_{next_version}')
    os.makedirs(current_path, exist_ok=True)
    return current_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_optimizer(args, model):
    """Build the optimizer selected by ``args.optimizer`` for ``model``.

    Args:
        args: Namespace-like object. ``optimizer`` selects the optimizer
            (``'Adam'``, ``'AdamW'``, ``'Adan'`` or ``'SGD'``); ``lr`` and
            ``weight_decay`` are always forwarded; ``betas`` and ``eps`` are
            forwarded for every optimizer except SGD.
        model: Object whose ``parameters()`` are handed to the optimizer.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names,
            or if Adan is requested with ``betas`` that does not contain
            exactly three values.
    """
    if args.optimizer == 'Adam':
        optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                         betas=args.betas, eps=args.eps)

    elif args.optimizer == 'AdamW':
        optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                          betas=args.betas, eps=args.eps)

    elif args.optimizer == 'Adan':
        # Adan needs three beta coefficients; reject every other length, not
        # only the two-element tuple used by Adam/AdamW.
        if len(args.betas) != 3:
            raise ValueError('Adan optimizer requires betas has the shape like (0.9,0.98, 0.99)')
        optimizer = Adan(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                         betas=args.betas, eps=args.eps)

    elif args.optimizer == 'SGD':
        # Momentum is fixed at 0.9; betas/eps do not apply to SGD.
        optimizer = SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                        momentum=0.9)

    else:
        raise ValueError('Unsupported optimizer, please check the optimizer name.')

    return optimizer
|
|
|
|
|
|
|
def get_scheduler(optimizer, args):
    """Placeholder for building a learning-rate scheduler from ``args``.

    Not implemented yet: both arguments are ignored and ``None`` is returned.
    """
    return None
|
|
|
|
|
|
|
def get_transform(args):
    """Build the image preprocessing pipeline selected by ``args.data_preprocess``.

    * ``1`` — resize to 224 (bicubic), centre-crop 224, convert to RGB,
      normalise.
    * ``2`` — the same pipeline with colour jitter, a random rotation of up
      to 10 degrees and a random horizontal flip inserted before the
      normalisation.

    Args:
        args: Namespace-like object with a ``data_preprocess`` attribute.

    Returns:
        A ``transform.Compose`` pipeline for modes 1 and 2, or ``None`` for
        any other value (kept for backward compatibility).
    """
    mode = args.data_preprocess
    if mode != 1 and mode != 2:
        return None

    def to_rgb(image):
        return image.convert("RGB")

    steps = [
        transform.Resize(224, mode=Image.BICUBIC),
        transform.CenterCrop(224),
        to_rgb,
    ]

    if mode == 2:
        # Augmentations used only by the second preprocessing mode.
        steps.extend([
            transform.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.4, hue=0.1),
            transform.RandomRotation(10),
            transform.RandomHorizontalFlip(),
        ])

    # Per-channel statistics (presumably the CLIP values — TODO confirm).
    steps.append(
        transform.ImageNormalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                 std=(0.26862954, 0.26130258, 0.27577711))
    )

    # Mode 2 previously built its pipeline but never returned it.
    return transform.Compose(steps)
|
|
|
|
|
|
|
|
|
|
|
def compute_loss(logits_image, logits_text):
    """Compute the symmetric image/text cross-entropy loss.

    The i-th row of each logits matrix is given target index ``i``, and the
    image-side and text-side cross-entropy losses are averaged. This is the
    loss used to align image and text semantics.

    Args:
        logits_image: Image-side logits; ``len()`` of it sets the number of
            targets.
        logits_text: Text-side logits, scored against the same targets.

    Returns:
        The mean of the two cross-entropy losses.
    """
    targets = jt.arange(len(logits_image), dtype=jt.int32)
    image_loss = jt.nn.cross_entropy_loss(logits_image, targets)
    text_loss = jt.nn.cross_entropy_loss(logits_text, targets)
    return (image_loss + text_loss) / 2