"""Training script: loads a CLIP model via ``load_clip`` and trains it with Jittor."""
import os

# Must be set before jittor is imported; presumably so that the framework
# only ever sees GPU 0 -- confirm against the deployment environment.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import jittor as jt
from tqdm import tqdm
from tabulate import tabulate
from tensorboardX import SummaryWriter
from utils.logger import setup_logger
from utils.helper import (EarlyStop, get_optimizer, get_save_path,
                          accuracy, compute_loss)
from dataset import get_dataloader
from jittor.lr_scheduler import CosineAnnealingLR
from test_clip import zeroshot_classifier
from models import load_clip

jt.flags.use_cuda = 1


class Args:
    """Container for the training hyper-parameters and run settings."""

    def __init__(self):
        self.seed = 123  # random seed
        self.optimizer = 'Adan'
        self.lr = 3e-6
        self.betas = (0.9, 0.98, 0.98)
        self.eps = 1e-8
        self.weight_decay = 0.2
        self.batch_size = 256
        self.num_workers = 8  # number of workers used to load data
        self.epochs = 100  # total number of training epochs
        self.early_stop = False
        self.patience = 10  # early-stopping patience
        self.delta = 0.0001  # early-stopping threshold
        # Caption version used for the image descriptions:
        #   1. "a photo of xxx"
        #   2. dataset-specific custom descriptions
        self.caption_version = 2
        self.data_augment = 0  # data-augmentation version
        self.model_save_path = 'ckptFE'
        self.log_save_path = 'logs'
        # Compute training-set accuracy every N epochs; 0 disables it.
        self.compute_acc_frequency = 0
        self.use_scheduler = False  # whether to step the LR scheduler
        # Save a checkpoint every N epochs; 0 disables periodic checkpoints.
        self.save_frequency = 5
        self.freeze_version = 2

    def __str__(self):
        """Return the settings rendered as a two-column grid table."""
        # Work on a shallow copy: formatting must not mutate the instance
        # (the original aliased ``self.__dict__`` and turned ``betas`` into
        # a list as a side effect of printing).
        args_dict = dict(self.__dict__)
        # Show tuples as lists so tabulate renders them as a single cell.
        if isinstance(args_dict['betas'], tuple):
            args_dict['betas'] = list(args_dict['betas'])
        return tabulate(list(args_dict.items()),
                        tablefmt="grid",
                        headers=["Parameter", "Value"])


def train(model, optimizer, train_loader, scheduler, args, log_dir, save_dir, logger):
    """Run the training loop, logging metrics and saving checkpoints.

    Args:
        model: Model called as ``model(images, texts)``; must return
            ``(logits_per_image, logits_per_text)``.
        optimizer: Jittor optimizer; stepped with ``optimizer.step(loss)``.
        train_loader: Iterable yielding ``(images, label, texts)`` batches.
        scheduler: LR scheduler; stepped once per epoch only when
            ``args.use_scheduler`` is true.
        args: Settings object (see ``Args``).
        log_dir: Directory handed to the TensorBoard ``SummaryWriter``.
        save_dir: Directory in which ``.pth`` checkpoints are written.
        logger: Logger used for progress messages.
    """
    print('model parameters: \n', args)
    logger.info('model parameters: \n')
    logger.info(args)
    logger.info('\n\nStart training...')

    early_stop = None
    if args.early_stop:
        early_stop = EarlyStop(patience=args.patience, delta=args.delta)

    min_loss = float('inf')
    # NOTE(review): computed once from the model as it is before training and
    # reused for every accuracy evaluation -- confirm this is intended if the
    # text encoder is being updated.
    zeroshot_weights = zeroshot_classifier(model)

    writer = SummaryWriter(log_dir=log_dir)
    pbar = tqdm(range(1, args.epochs + 1))
    try:
        for epoch in pbar:
            model.train()
            running_loss = 0.0
            for images, _label, texts in train_loader:
                logits_per_image, logits_per_text = model(images, texts)
                loss = compute_loss(logits_per_image, logits_per_text)
                optimizer.step(loss)
                running_loss += loss.item()

            # Reset every epoch so that a value measured in an earlier epoch
            # is never logged or plotted as if it belonged to this one.
            acc = None
            if (args.compute_acc_frequency != 0
                    and epoch % args.compute_acc_frequency == 0):
                acc = accuracy(model, train_loader, zeroshot_weights)

            if args.use_scheduler:
                scheduler.step()

            # Always keep the checkpoint with the lowest summed epoch loss.
            if running_loss < min_loss:
                min_loss = running_loss
                jt.save(model.state_dict(),
                        '{}.pth'.format(os.path.join(save_dir, 'min_loss')))

            # A frequency of 0 disables periodic checkpoints (mirrors
            # compute_acc_frequency) instead of raising ZeroDivisionError.
            if args.save_frequency != 0 and epoch % args.save_frequency == 0:
                jt.save(model.state_dict(),
                        '{}.pth'.format(os.path.join(save_dir, 'epoch_{}'.format(epoch))))

            if acc is not None:
                writer.add_scalar('Train/Acc', acc, epoch)
                logger.info(f"Epoch: {epoch}, Loss: {running_loss:.4f} Min Loss: {min_loss:.4f}, Acc: {acc:.4f}")
            else:
                logger.info(f"Epoch: {epoch}, Loss: {running_loss:.4f} Min Loss: {min_loss:.4f}")
            writer.add_scalar('Train/Loss', running_loss, epoch)

            if early_stop is not None and early_stop(running_loss):
                logger.info(f'Early stop triggered..., epoch: {epoch}')
                break
        logger.info('\n\nFinish training...')
    finally:
        # Release the progress bar and flush TensorBoard events even when
        # training is interrupted by an exception.
        pbar.close()
        writer.close()


def main(args):
    """Build the model, data loader, optimizer and logger, then train.

    Args:
        args: Settings object (see ``Args``).
    """
    model, transforms = load_clip(freeze_version=args.freeze_version)
    train_loader = get_dataloader(transforms, args.batch_size, args.num_workers,
                                  shuffle=True, version=args.caption_version)

    optimizer = get_optimizer(args, model)
    # Cosine schedule over all epochs, ending at 10% of the base LR; it is
    # only stepped when args.use_scheduler is set.
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.lr * 0.1)

    model_save_path = get_save_path(args.model_save_path, args.optimizer)
    log_save_path = get_save_path(args.log_save_path, args.optimizer)
    logger = setup_logger(log_save_path)

    print(f'Model will be saved at {model_save_path}')
    logger.info('Model will be saved at {}'.format(model_save_path))

    train(model, optimizer, train_loader, scheduler, args,
          log_save_path, model_save_path, logger)


if __name__ == "__main__":
    args = Args()
    jt.misc.set_global_seed(args.seed)
    main(args)