Browse Source

ADD file via upload

master
BIT2024 1 year ago
parent
commit
3f0cc9c540
1 changed files with 149 additions and 0 deletions
  1. +149
    -0
      train_clip.py

+ 149
- 0
train_clip.py View File

@@ -0,0 +1,149 @@
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import jittor as jt
from tqdm import tqdm
from tabulate import tabulate
from tensorboardX import SummaryWriter
from utils.logger import setup_logger
from utils.helper import (EarlyStop, get_optimizer, get_save_path,
accuracy, compute_loss)
from dataset import get_dataloader
from jittor.lr_scheduler import CosineAnnealingLR
from test_clip import zeroshot_classifier
from models import load_clip

jt.flags.use_cuda = 1


class Args:
"""设置训练参数"""
def __init__(self):
self.seed = 123 # 随机种子
self.optimizer = 'Adan'
self.lr = 3e-6
self.betas = (0.9, 0.98, 0.98)
self.eps = 1e-8
self.weight_decay = 0.2
self.batch_size = 256
self.num_workers = 8 # 导入数据使用的线程数
self.epochs = 100 # 训练总的轮数
self.early_stop = False
self.patience = 10 # 早停“忍耐”的次数
self.delta = 0.0001 # 早停的阈值
self.caption_version = 2 # 图片描述使用的版本
# 1. a photo of xxx
# 2. 针对数据集自定义的描述
self.data_augment = 0 # 数据增强版本
self.model_save_path = 'ckptFE'
self.log_save_path = 'logs'
self.compute_acc_frequency = 0 # 每隔多少个epoch计算一次训练集acc,0表示不计算
self.use_scheduler = False # 是否使用学习率调度策略
self.save_frequency = 5 # 模型保存频率,每隔多少个epoch进行保存
self.freeze_version = 2

def __str__(self):
# 将参数转换为字典
args_dict = self.__dict__
# 将元组转换为列表,以便tabulate可以正确处理
if isinstance(args_dict['betas'], tuple):
args_dict['betas'] = list(args_dict['betas'])
# 使用tabulate生成表格
table = tabulate([(key, value) for key, value in args_dict.items()],
tablefmt="grid", headers=["Parameter", "Value"])
return table

def train(model, optimizer, train_loader, scheduler, args, log_dir, save_dir, logger):
print('model parameters: \n', args)
logger.info('model parameters: \n')
logger.info(args)
logger.info('\n\nStart training...')
if args.early_stop:
early_stop = EarlyStop(patience=args.patience, delta=args.delta)
writer = SummaryWriter(log_dir=log_dir)
min_loss = float('inf')
acc = None
zeroshot_weights = zeroshot_classifier(model)

pbar = tqdm(range(1, args.epochs+1))
for epoch in pbar:
model.train()
running_loss = 0.0
for images, label, texts in train_loader:
logits_per_image, logits_per_text = model(images, texts)
loss = compute_loss(logits_per_image, logits_per_text)
optimizer.step(loss)
running_loss += loss.item()
if args.compute_acc_frequency != 0 and epoch % args.compute_acc_frequency == 0 :
acc = accuracy(model, train_loader, zeroshot_weights)

if args.use_scheduler:
scheduler.step()
if running_loss < min_loss:
min_loss = running_loss
jt.save(model.state_dict(), '{}.pth'.format(os.path.join(save_dir, 'min_loss')))
if epoch % args.save_frequency == 0:
jt.save(model.state_dict(), '{}.pth'.format(os.path.join(save_dir, 'epoch_{}'.format(epoch))))
if acc is not None:
# pbar.set_description(f"Epoch: {epoch}, Loss: {running_loss:.4f}, Min Loss: {min_loss:.4f}, Acc: {acc:.4f}")
writer.add_scalar('Train/Acc', acc, epoch)
logger.info(f"Epoch: {epoch}, Loss: {running_loss:.4f} Min Loss: {min_loss:.4f}, Acc: {acc:.4f}")
else:
# pbar.set_description(f"Epoch: {epoch}, Loss: {running_loss:.4f}, Min Loss: {min_loss:.4f}")
logger.info(f"Epoch: {epoch}, Loss: {running_loss:.4f} Min Loss: {min_loss:.4f}")
writer.add_scalar('Train/Loss', running_loss, epoch)
if args.early_stop and early_stop(running_loss):
logger.info(f'Early stop triggered..., epoch: {epoch}')
# print(f'early stop triggered..., epoch: {epoch}')
break
pbar.close()
logger.info('\n\nFinish training...')
writer.close()


def main(args):
model, transforms = load_clip(freeze_version=args.freeze_version)
# model, transforms = clip.load('pretrained/ViT-B-32.pkl')
train_loader = get_dataloader(transforms, args.batch_size,
args.num_workers, shuffle=True, version=args.caption_version)
optimizer = get_optimizer(args, model)
scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.lr * 0.1)

model_save_path = get_save_path(args.model_save_path, args.optimizer)
log_save_path = get_save_path(args.log_save_path, args.optimizer)
logger = setup_logger(log_save_path)
print(f'Model will be saved at {model_save_path}')
logger.info('Model will be saved at {}'.format(model_save_path))
train(model, optimizer, train_loader, scheduler,
args, log_save_path, model_save_path, logger)

if __name__ == "__main__":
args = Args()
jt.misc.set_global_seed(args.seed)
main(args)

Loading…
Cancel
Save