You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

train_clip.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. import os
  2. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  3. import jittor as jt
  4. from tqdm import tqdm
  5. from tabulate import tabulate
  6. from tensorboardX import SummaryWriter
  7. from utils.logger import setup_logger
  8. from utils.helper import (EarlyStop, get_optimizer, get_save_path,
  9. accuracy, compute_loss)
  10. from dataset import get_dataloader
  11. from jittor.lr_scheduler import CosineAnnealingLR
  12. from test_clip import zeroshot_classifier
  13. from models import load_clip
  14. jt.flags.use_cuda = 1
  15. class Args:
  16. """设置训练参数"""
  17. def __init__(self):
  18. self.seed = 123 # 随机种子
  19. self.optimizer = 'Adan'
  20. self.lr = 3e-6
  21. self.betas = (0.9, 0.98, 0.98)
  22. self.eps = 1e-8
  23. self.weight_decay = 0.2
  24. self.batch_size = 256
  25. self.num_workers = 8 # 导入数据使用的线程数
  26. self.epochs = 100 # 训练总的轮数
  27. self.early_stop = False
  28. self.patience = 10 # 早停“忍耐”的次数
  29. self.delta = 0.0001 # 早停的阈值
  30. self.caption_version = 2 # 图片描述使用的版本
  31. # 1. a photo of xxx
  32. # 2. 针对数据集自定义的描述
  33. self.data_augment = 0 # 数据增强版本
  34. self.model_save_path = 'ckptFE'
  35. self.log_save_path = 'logs'
  36. self.compute_acc_frequency = 0 # 每隔多少个epoch计算一次训练集acc,0表示不计算
  37. self.use_scheduler = False # 是否使用学习率调度策略
  38. self.save_frequency = 5 # 模型保存频率,每隔多少个epoch进行保存
  39. self.freeze_version = 2
  40. def __str__(self):
  41. # 将参数转换为字典
  42. args_dict = self.__dict__
  43. # 将元组转换为列表,以便tabulate可以正确处理
  44. if isinstance(args_dict['betas'], tuple):
  45. args_dict['betas'] = list(args_dict['betas'])
  46. # 使用tabulate生成表格
  47. table = tabulate([(key, value) for key, value in args_dict.items()],
  48. tablefmt="grid", headers=["Parameter", "Value"])
  49. return table
def train(model, optimizer, train_loader, scheduler, args, log_dir, save_dir, logger):
    """Fine-tune CLIP with a contrastive image-text objective.

    Args:
        model: CLIP model (project type) callable as ``model(images, texts)``
            returning (logits_per_image, logits_per_text).
        optimizer: jittor optimizer; ``optimizer.step(loss)`` runs backward
            and the parameter update in jittor's API.
        train_loader: iterable yielding ``(images, label, texts)`` batches.
        scheduler: LR scheduler, stepped once per epoch when
            ``args.use_scheduler`` is set.
        args: Args config object.
        log_dir: directory for TensorBoard event files.
        save_dir: directory for ``.pth`` checkpoints.
        logger: configured ``logging.Logger``.

    Side effects: writes TensorBoard scalars, log lines, and checkpoints
    ("min_loss.pth" on each new best epoch loss, plus periodic "epoch_N.pth").
    """
    print('model parameters: \n', args)
    logger.info('model parameters: \n')
    logger.info(args)
    logger.info('\n\nStart training...')
    if args.early_stop:
        early_stop = EarlyStop(patience=args.patience, delta=args.delta)
    writer = SummaryWriter(log_dir=log_dir)
    min_loss = float('inf')
    acc = None
    # NOTE(review): zero-shot weights are built once from the *initial* model;
    # if the text encoder is being trained they go stale for later accuracy
    # checks — confirm this is intended.
    zeroshot_weights = zeroshot_classifier(model)
    pbar = tqdm(range(1, args.epochs+1))
    for epoch in pbar:
        model.train()
        running_loss = 0.0
        for images, label, texts in train_loader:
            logits_per_image, logits_per_text = model(images, texts)
            loss = compute_loss(logits_per_image, logits_per_text)
            # jittor idiom: step(loss) performs backward + update internally
            optimizer.step(loss)
            running_loss += loss.item()
        # optionally evaluate train-set accuracy every N epochs (0 disables)
        if args.compute_acc_frequency != 0 and epoch % args.compute_acc_frequency == 0:
            acc = accuracy(model, train_loader, zeroshot_weights)
        if args.use_scheduler:
            scheduler.step()
        # checkpoint on a new best *summed* (not averaged) epoch loss
        if running_loss < min_loss:
            min_loss = running_loss
            jt.save(model.state_dict(), '{}.pth'.format(os.path.join(save_dir, 'min_loss')))
        if epoch % args.save_frequency == 0:
            jt.save(model.state_dict(), '{}.pth'.format(os.path.join(save_dir, 'epoch_{}'.format(epoch))))
        # `acc` keeps its last computed value between evaluation epochs
        if acc is not None:
            writer.add_scalar('Train/Acc', acc, epoch)
            logger.info(f"Epoch: {epoch}, Loss: {running_loss:.4f} Min Loss: {min_loss:.4f}, Acc: {acc:.4f}")
        else:
            logger.info(f"Epoch: {epoch}, Loss: {running_loss:.4f} Min Loss: {min_loss:.4f}")
        writer.add_scalar('Train/Loss', running_loss, epoch)
        # `early_stop` only exists when args.early_stop is True; `and` short-circuits
        if args.early_stop and early_stop(running_loss):
            logger.info(f'Early stop triggered..., epoch: {epoch}')
            break
    pbar.close()
    logger.info('\n\nFinish training...')
    writer.close()
  94. def main(args):
  95. model, transforms = load_clip(freeze_version=args.freeze_version)
  96. # model, transforms = clip.load('pretrained/ViT-B-32.pkl')
  97. train_loader = get_dataloader(transforms, args.batch_size,
  98. args.num_workers, shuffle=True, version=args.caption_version)
  99. optimizer = get_optimizer(args, model)
  100. scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.lr * 0.1)
  101. model_save_path = get_save_path(args.model_save_path, args.optimizer)
  102. log_save_path = get_save_path(args.log_save_path, args.optimizer)
  103. logger = setup_logger(log_save_path)
  104. print(f'Model will be saved at {model_save_path}')
  105. logger.info('Model will be saved at {}'.format(model_save_path))
  106. train(model, optimizer, train_loader, scheduler,
  107. args, log_save_path, model_save_path, logger)
  108. if __name__ == "__main__":
  109. args = Args()
  110. jt.misc.set_global_seed(args.seed)
  111. main(args)

冻结ViT-B/32版本的CLIP模型中的全部图像层,用Adan优化器训练模型,训练100个epoch,每隔5个epoch对模型进行保存;完成CLIP模型训练后,运行test_clip.py用测试集中的数据和自定义的提示词对保存的模型进行测试,选取测试精度最好的模型和对应的提示词,运行predict.py文件,选择“min_loss.pth”模型,提交官方系统测试,top1的精度是0.6788。

Contributors (1)