You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import os
  2. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  3. import jittor as jt
  4. from jittor import transform
  5. from jittor.optim import Adam, AdamW, SGD, Adan
  6. from PIL import Image
  7. from datetime import datetime
  8. from natsort import natsorted
  9. class EarlyStop(object):
  10. """早停
  11. 1. 当模型的损失长时间不下降时,停止训练
  12. 2. 当模型的损失长时间增大时,也提前停止训练
  13. """
  14. def __init__(self, patience=7, delta=0.0001, patience_up=20):
  15. self.patience = patience
  16. self.delta = delta
  17. self.counter = 0
  18. self.counter_up = 0
  19. self.last_loss = None
  20. self.early_stop = False
  21. self.patience_up = patience_up
  22. def __call__(self, loss):
  23. """当输入的loss多次不下降或者上升的时候,返回True,正常时返回False
  24. Args:
  25. loss (float): 当前的损失值
  26. Returns:
  27. bool: 是否早停
  28. """
  29. if self.last_loss is None:
  30. self.last_loss = loss
  31. return False
  32. # loss下降明显低于delta,当前清零
  33. if loss < self.last_loss - self.delta:
  34. self.counter = 0
  35. self.counter_up = 0
  36. self.last_loss = loss
  37. # loss上升明显高于delta,counter_up开始计数
  38. elif loss > self.last_loss + self.delta:
  39. self.counter_up += 1
  40. if self.counter_up >= self.patience_up:
  41. self.early_stop = True
  42. return True
  43. # loss上升和下降均小于delta,在区间震荡,counter开始计数
  44. else:
  45. self.counter += 1
  46. if self.counter >= self.patience:
  47. self.early_stop = True
  48. return True
  49. return False
  50. def accuracy(model, dataloader, zeroshot_weights):
  51. """计算模型的准确率"""
  52. model.eval()
  53. corrct = 0
  54. total_count = 0
  55. with jt.no_grad():
  56. for i, batch in enumerate(dataloader):
  57. images, targets, texts = batch
  58. total_count += len(images)
  59. image_features = model.encode_image(images)
  60. image_features = image_features / image_features.norm(dim=1, keepdim=True)
  61. logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
  62. preds = jt.argmax(logits, dim=1)[0]
  63. corrct += jt.equal(preds, targets).sum().item()
  64. return corrct / total_count
  65. def get_current_date(end_time='day'):
  66. # 获取当前日期时间对象
  67. current_date = datetime.now()
  68. # 格式化日期为月日时分格式
  69. if end_time == 'day':
  70. formatted_date = current_date.strftime("%m-%d")
  71. elif end_time == 'minute':
  72. formatted_date = current_date.strftime("%m-%d_%H:%M")
  73. return formatted_date
  74. def get_save_path(given_path, optimizer):
  75. """获取tensorboard日志/模型保存路径"""
  76. # 文件保存路径如下:
  77. # given_path/date/optimizer/version_x
  78. path = os.path.join(given_path, get_current_date(end_time='day'))
  79. os.makedirs(path, exist_ok=True)
  80. try:
  81. last_version = int(natsorted(os.listdir(path))[-1].split('_')[-1])
  82. current_path = os.path.join(path, f'version_{last_version + 1}')
  83. os.makedirs(current_path, exist_ok=True)
  84. except IndexError:
  85. current_path = os.path.join(path, 'version_0')
  86. os.makedirs(current_path, exist_ok=True)
  87. return current_path
  88. def get_optimizer(args, model):
  89. """根据输入参数获取优化器"""
  90. if args.optimizer == 'Adam':
  91. optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
  92. betas=args.betas, eps=args.eps)
  93. elif args.optimizer == 'AdamW':
  94. optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
  95. betas=args.betas, eps=args.eps)
  96. elif args.optimizer == 'Adan':
  97. if len(args.betas) == 2:
  98. raise ValueError('Adan optimizer requires betas has the shape like (0.9,0.98, 0.99)')
  99. optimizer = Adan(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
  100. betas=args.betas, eps=args.eps)
  101. elif args.optimizer == 'SGD':
  102. optimizer = SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
  103. momentum=0.9)
  104. else:
  105. raise ValueError('Unsupported optimizer, please check the optimizer name.')
  106. return optimizer
  107. def get_scheduler(optimizer, args):
  108. """根据输入参数获取学习率调度器"""
  109. pass
  110. def get_transform(args):
  111. """根据输入参数获取数据预处理"""
  112. if args.data_preprocess == 1:
  113. transforms = transform.Compose([
  114. transform.Resize(224, mode=Image.BICUBIC),
  115. transform.CenterCrop(224), lambda image: image.convert("RGB"),
  116. transform.ImageNormalize(mean=(0.48145466, 0.4578275, 0.40821073),
  117. std=(0.26862954, 0.26130258, 0.27577711))
  118. ])
  119. return transforms
  120. elif args.data_preprocess == 2:
  121. transforms = transform.Compose([
  122. transform.Resize(224, mode=Image.BICUBIC),
  123. transform.CenterCrop(224), lambda image: image.convert("RGB"),
  124. transform.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.4, hue=0.1),
  125. transform.RandomRotation(10),
  126. transform.RandomHorizontalFlip(),
  127. transform.ImageNormalize(mean=(0.48145466, 0.4578275, 0.40821073),
  128. std=(0.26862954, 0.26130258, 0.27577711))])
  129. def compute_loss(logits_image, logits_text):
  130. """计算损失函数,用来建立文本与图像的语义关系,实现语义对其"""
  131. ground_truth = jt.arange(len(logits_image), dtype=jt.int32)
  132. loss = (jt.nn.cross_entropy_loss(logits_image, ground_truth) +\
  133. jt.nn.cross_entropy_loss(logits_text, ground_truth)) / 2
  134. return loss

冻结ViT-B/32版本的CLIP模型中的全部图像层,用Adan优化器训练模型,训练100个epoch,每隔5个epoch对模型进行保存;完成CLIP模型训练后,运行test_clip.py用测试集中的数据和自定义的提示词对保存的模型进行测试,选取测试精度最好的模型和对应的提示词,运行predict.py文件,选择“min_loss.pth”模型,提交官方系统测试,top1的精度是0.6788。

Contributors (1)