|
|
# NOTE(review): stray diff hunk header "@@ -0,0 +1,149 @@" from a patch paste — commented out to keep the file valid Python.
|
|
|
# Pin the process to GPU 0; must be set before jittor initializes CUDA.
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import jittor as jt
from tqdm import tqdm
from tabulate import tabulate
from tensorboardX import SummaryWriter
from utils.logger import setup_logger
from utils.helper import (EarlyStop, get_optimizer, get_save_path,
                          accuracy, compute_loss)
from dataset import get_dataloader
from jittor.lr_scheduler import CosineAnnealingLR
from test_clip import zeroshot_classifier
from models import load_clip

# Run all jittor operations on the GPU.
jt.flags.use_cuda = 1
|
|
|
|
|
|
|
|
|
|
|
class Args:
    """Training hyper-parameters for CLIP fine-tuning."""

    def __init__(self):
        self.seed = 123                    # random seed
        self.optimizer = 'Adan'
        self.lr = 3e-6
        self.betas = (0.9, 0.98, 0.98)
        self.eps = 1e-8
        self.weight_decay = 0.2
        self.batch_size = 256
        self.num_workers = 8               # dataloader worker threads

        self.epochs = 100                  # total number of training epochs

        self.early_stop = False
        self.patience = 10                 # early-stop patience (epochs without improvement)
        self.delta = 0.0001                # early-stop improvement threshold
        self.caption_version = 2           # image-caption version:
                                           #   1. "a photo of xxx"
                                           #   2. dataset-specific custom captions

        self.data_augment = 0              # data-augmentation version
        self.model_save_path = 'ckptFE'
        self.log_save_path = 'logs'

        self.compute_acc_frequency = 0     # compute train acc every N epochs; 0 = never
        self.use_scheduler = False         # whether to use the LR scheduler

        self.save_frequency = 5            # save a checkpoint every N epochs

        self.freeze_version = 2

    def __str__(self):
        """Render the parameters as a grid table.

        Bug fix: the original converted ``self.betas`` from a tuple to a
        list *in place* while formatting, so merely printing the object
        mutated its state.  Work on a shallow copy of ``__dict__`` instead.
        """
        args_dict = dict(self.__dict__)
        # Tuples render awkwardly; show betas as a list in the table only.
        if isinstance(args_dict['betas'], tuple):
            args_dict['betas'] = list(args_dict['betas'])
        return tabulate([(key, value) for key, value in args_dict.items()],
                        tablefmt="grid", headers=["Parameter", "Value"])
|
|
|
|
|
|
|
def train(model, optimizer, train_loader, scheduler, args, log_dir, save_dir, logger):
    """Run the fine-tuning loop and write checkpoints + TensorBoard logs.

    Args:
        model: CLIP model; ``model(images, texts)`` returns
            ``(logits_per_image, logits_per_text)``.
        optimizer: jittor optimizer — ``step(loss)`` runs backward and update.
        train_loader: iterable of ``(images, label, texts)`` batches.
        scheduler: LR scheduler, stepped once per epoch when enabled.
        args: ``Args`` instance with all hyper-parameters.
        log_dir: directory for TensorBoard event files.
        save_dir: directory for ``.pth`` checkpoints.
        logger: configured logger for progress messages.
    """
    print('model parameters: \n', args)
    logger.info('model parameters: \n')
    logger.info(args)

    logger.info('\n\nStart training...')

    if args.early_stop:
        early_stop = EarlyStop(patience=args.patience, delta=args.delta)
    writer = SummaryWriter(log_dir=log_dir)

    min_loss = float('inf')
    acc = None
    # Fixed text-side class embeddings, reused for the optional train accuracy.
    zeroshot_weights = zeroshot_classifier(model)

    pbar = tqdm(range(1, args.epochs + 1))
    for epoch in pbar:
        model.train()
        epoch_loss = 0.0
        for images, label, texts in train_loader:
            logits_per_image, logits_per_text = model(images, texts)
            batch_loss = compute_loss(logits_per_image, logits_per_text)
            optimizer.step(batch_loss)  # jittor: backward + parameter update
            epoch_loss += batch_loss.item()

        if args.compute_acc_frequency != 0 and epoch % args.compute_acc_frequency == 0:
            acc = accuracy(model, train_loader, zeroshot_weights)

        if args.use_scheduler:
            scheduler.step()

        # Snapshot the best (lowest epoch loss) weights seen so far.
        if epoch_loss < min_loss:
            min_loss = epoch_loss
            jt.save(model.state_dict(), os.path.join(save_dir, 'min_loss') + '.pth')

        # Periodic checkpoint every `save_frequency` epochs.
        if epoch % args.save_frequency == 0:
            jt.save(model.state_dict(),
                    os.path.join(save_dir, f'epoch_{epoch}') + '.pth')

        message = f"Epoch: {epoch}, Loss: {epoch_loss:.4f} Min Loss: {min_loss:.4f}"
        if acc is not None:
            writer.add_scalar('Train/Acc', acc, epoch)
            logger.info(message + f", Acc: {acc:.4f}")
        else:
            logger.info(message)
        writer.add_scalar('Train/Loss', epoch_loss, epoch)

        if args.early_stop and early_stop(epoch_loss):
            logger.info(f'Early stop triggered..., epoch: {epoch}')
            break

    pbar.close()
    logger.info('\n\nFinish training...')
    writer.close()
|
|
|
|
|
|
|
|
|
|
|
def main(args):
    """Build the model, data pipeline, and optimizer, then launch training.

    Args:
        args: ``Args`` instance holding every hyper-parameter.
    """
    model, transforms = load_clip(freeze_version=args.freeze_version)
    train_loader = get_dataloader(
        transforms, args.batch_size, args.num_workers,
        shuffle=True, version=args.caption_version)

    optimizer = get_optimizer(args, model)
    # Cosine decay from args.lr down to 10% of it over the full run
    # (applied inside train() only when args.use_scheduler is True).
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.lr * 0.1)

    ckpt_dir = get_save_path(args.model_save_path, args.optimizer)
    log_dir = get_save_path(args.log_save_path, args.optimizer)
    logger = setup_logger(log_dir)

    print(f'Model will be saved at {ckpt_dir}')
    logger.info('Model will be saved at {}'.format(ckpt_dir))

    train(model, optimizer, train_loader, scheduler,
          args, log_dir, ckpt_dir, logger)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    args = Args()
    # Seed jittor's global RNGs so runs are reproducible.
    jt.misc.set_global_seed(args.seed)
    main(args)
|
|
|
|
|
|
|
|