From 3f0cc9c5406e737d9859bab13f7cd9b2c5873a88 Mon Sep 17 00:00:00 2001
From: BIT2024
Date: Tue, 20 Aug 2024 14:56:20 +0800
Subject: [PATCH] ADD file via upload

---
 train_clip.py | 149 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 149 insertions(+)
 create mode 100644 train_clip.py

diff --git a/train_clip.py b/train_clip.py
new file mode 100644
index 0000000..74c5239
--- /dev/null
+++ b/train_clip.py
@@ -0,0 +1,149 @@
+import os
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import jittor as jt
+from tqdm import tqdm
+from tabulate import tabulate
+from tensorboardX import SummaryWriter
+from utils.logger import setup_logger
+from utils.helper import (EarlyStop, get_optimizer, get_save_path,
+                          accuracy, compute_loss)
+from dataset import get_dataloader
+from jittor.lr_scheduler import CosineAnnealingLR
+from test_clip import zeroshot_classifier
+from models import load_clip
+
+jt.flags.use_cuda = 1
+
+
+class Args:
+    """Training configuration."""
+    def __init__(self):
+        self.seed = 123  # random seed
+        self.optimizer = 'Adan'
+        self.lr = 3e-6
+        self.betas = (0.9, 0.98, 0.98)
+        self.eps = 1e-8
+        self.weight_decay = 0.2
+        self.batch_size = 256
+        self.num_workers = 8  # number of worker threads for data loading
+
+        self.epochs = 100  # total number of training epochs
+
+        self.early_stop = False
+        self.patience = 10  # early-stopping patience (epochs without improvement)
+        self.delta = 0.0001  # early-stopping improvement threshold
+        self.caption_version = 2  # caption version used for the images
+        # 1. a photo of xxx
+        # 2. dataset-specific custom captions
+
+        self.data_augment = 0  # data-augmentation version
+        self.model_save_path = 'ckptFE'
+        self.log_save_path = 'logs'
+
+        self.compute_acc_frequency = 0  # compute training accuracy every N epochs; 0 disables it
+        self.use_scheduler = False  # whether to use an LR schedule
+
+        self.save_frequency = 5  # save a checkpoint every N epochs
+
+        self.freeze_version = 2
+
+    def __str__(self):
+        # Copy the parameters into a dict so printing does not mutate the instance.
+        args_dict = dict(self.__dict__)
+        # Convert the tuple to a list so tabulate renders it correctly.
+        if isinstance(args_dict['betas'], tuple):
+            args_dict['betas'] = list(args_dict['betas'])
+        # Render the parameters as a grid table with tabulate.
+        table = tabulate([(key, value) for key, value in args_dict.items()],
+                         tablefmt="grid", headers=["Parameter", "Value"])
+        return table
+
+
+def train(model, optimizer, train_loader, scheduler, args, log_dir, save_dir, logger):
+    print('model parameters: \n', args)
+    logger.info('model parameters: \n')
+    logger.info(args)
+
+    logger.info('\n\nStart training...')
+
+    if args.early_stop:
+        early_stop = EarlyStop(patience=args.patience, delta=args.delta)
+    writer = SummaryWriter(log_dir=log_dir)
+
+    min_loss = float('inf')
+    acc = None
+    zeroshot_weights = zeroshot_classifier(model)
+
+    pbar = tqdm(range(1, args.epochs + 1))
+    for epoch in pbar:
+        model.train()
+        running_loss = 0.0
+        for images, label, texts in train_loader:
+            logits_per_image, logits_per_text = model(images, texts)
+            loss = compute_loss(logits_per_image, logits_per_text)
+            optimizer.step(loss)  # Jittor idiom: backward + parameter update in one call
+            running_loss += loss.item()
+
+        if args.compute_acc_frequency != 0 and epoch % args.compute_acc_frequency == 0:
+            acc = accuracy(model, train_loader, zeroshot_weights)
+
+        if args.use_scheduler:
+            scheduler.step()
+
+        if running_loss < min_loss:
+            min_loss = running_loss
+            jt.save(model.state_dict(), '{}.pth'.format(os.path.join(save_dir, 'min_loss')))
+
+        if epoch % args.save_frequency == 0:
+            jt.save(model.state_dict(), '{}.pth'.format(os.path.join(save_dir, 'epoch_{}'.format(epoch))))
+
+        if acc is not None:
+            # pbar.set_description(f"Epoch: {epoch}, Loss: {running_loss:.4f}, Min Loss: {min_loss:.4f}, Acc: {acc:.4f}")
+            writer.add_scalar('Train/Acc', acc, epoch)
+            logger.info(f"Epoch: {epoch}, Loss: {running_loss:.4f}, Min Loss: {min_loss:.4f}, Acc: {acc:.4f}")
+        else:
+            # pbar.set_description(f"Epoch: {epoch}, Loss: {running_loss:.4f}, Min Loss: {min_loss:.4f}")
+            logger.info(f"Epoch: {epoch}, Loss: {running_loss:.4f}, Min Loss: {min_loss:.4f}")
+        writer.add_scalar('Train/Loss', running_loss, epoch)
+
+        if args.early_stop and early_stop(running_loss):
+            logger.info(f'Early stop triggered..., epoch: {epoch}')
+            # print(f'early stop triggered..., epoch: {epoch}')
+            break
+
+    pbar.close()
+    logger.info('\n\nFinished training...')
+    writer.close()
+
+
+def main(args):
+    model, transforms = load_clip(freeze_version=args.freeze_version)
+    # model, transforms = clip.load('pretrained/ViT-B-32.pkl')
+    train_loader = get_dataloader(transforms, args.batch_size,
+                                  args.num_workers, shuffle=True, version=args.caption_version)
+
+    optimizer = get_optimizer(args, model)
+    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.lr * 0.1)
+
+    model_save_path = get_save_path(args.model_save_path, args.optimizer)
+    log_save_path = get_save_path(args.log_save_path, args.optimizer)
+    logger = setup_logger(log_save_path)
+
+    print(f'Model will be saved at {model_save_path}')
+    logger.info('Model will be saved at {}'.format(model_save_path))
+
+    train(model, optimizer, train_loader, scheduler,
+          args, log_save_path, model_save_path, logger)
+
+
+if __name__ == "__main__":
+    args = Args()
+    jt.misc.set_global_seed(args.seed)
+    main(args)
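
Reviewer note: `compute_loss` is imported from utils.helper and is not part of this
patch. For a CLIP-style model it is typically the symmetric cross-entropy over the
image-text similarity matrix; the sketch below is an assumption of that standard
objective, not the repository's actual helper, and the name `clip_contrastive_loss`
is hypothetical.

    # Minimal sketch (assumption) of the symmetric CLIP objective that a
    # helper like compute_loss usually implements; not this repo's code.
    import jittor as jt
    from jittor import nn

    def clip_contrastive_loss(logits_per_image, logits_per_text):
        # Both logit matrices are (batch, batch); the matched image-text
        # pair for sample i sits on the diagonal, so the target class is i.
        batch_size = logits_per_image.shape[0]
        labels = jt.arange(batch_size)
        loss_i = nn.cross_entropy_loss(logits_per_image, labels)  # image -> text
        loss_t = nn.cross_entropy_loss(logits_per_text, labels)   # text -> image
        return (loss_i + loss_t) / 2.0

Likewise, `EarlyStop` is only visible here through its call site
`early_stop(running_loss)`, so it is presumably a callable object that returns True
once the loss has not improved by at least `delta` for `patience` consecutive
epochs. A hypothetical reconstruction under that assumption:

    class EarlyStop:
        """Hypothetical sketch: signal a stop when the monitored loss has not
        improved by at least `delta` for `patience` consecutive calls."""
        def __init__(self, patience=10, delta=0.0):
            self.patience = patience
            self.delta = delta
            self.best = float('inf')
            self.counter = 0

        def __call__(self, loss):
            if loss < self.best - self.delta:
                # Improvement: record the new best and reset the counter.
                self.best = loss
                self.counter = 0
            else:
                self.counter += 1
            return self.counter >= self.patience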