
helper.py 6.5 kB

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import jittor as jt
from jittor import transform
from jittor.optim import Adam, AdamW, SGD, Adan, AdanBelief
from PIL import Image
from datetime import datetime
from natsort import natsorted
# from jt.optim.Optimizer import AdamwBelief


class EarlyStop(object):
    """Early stopping.

    1. Stop training when the loss has not decreased for a long time.
    2. Also stop training early when the loss has kept increasing for a long time.
    """
    def __init__(self, patience=7, delta=0.0001, patience_up=20):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.counter_up = 0
        self.last_loss = None
        self.early_stop = False
        self.patience_up = patience_up

    def __call__(self, loss):
        """Return True when the loss has repeatedly failed to decrease or has kept rising; otherwise return False.

        Args:
            loss (float): current loss value
        Returns:
            bool: whether to stop early
        """
        if self.last_loss is None:
            self.last_loss = loss
            return False
        # The loss dropped by more than delta: reset both counters.
        if loss < self.last_loss - self.delta:
            self.counter = 0
            self.counter_up = 0
            self.last_loss = loss
        # The loss rose by more than delta: counter_up starts counting.
        elif loss > self.last_loss + self.delta:
            self.counter_up += 1
            if self.counter_up >= self.patience_up:
                self.early_stop = True
                return True
        # The loss changed by less than delta in either direction (oscillating within a band): counter starts counting.
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                return True
        return False


def accuracy(model, dataloader, zeroshot_weights):
    """Compute the model's accuracy."""
    model.eval()
    correct = 0
    total_count = 0
    with jt.no_grad():
        for i, batch in enumerate(dataloader):
            images, targets, texts = batch
            total_count += len(images)
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
            preds = jt.argmax(logits, dim=1)[0]
            correct += jt.equal(preds, targets).sum().item()
    return correct / total_count


def get_current_date(end_time='day'):
    # Get the current datetime object.
    current_date = datetime.now()
    # Format the date as month-day (or month-day_hour:minute).
    if end_time == 'day':
        formatted_date = current_date.strftime("%m-%d")
    elif end_time == 'minute':
        formatted_date = current_date.strftime("%m-%d_%H:%M")
    return formatted_date


def get_save_path(given_path, optimizer):
    """Get the save path for TensorBoard logs / model checkpoints."""
    # Files are saved under:
    # given_path/date/optimizer/version_x
    path = os.path.join(given_path, get_current_date(end_time='day'))
    os.makedirs(path, exist_ok=True)
    try:
        last_version = int(natsorted(os.listdir(path))[-1].split('_')[-1])
        current_path = os.path.join(path, f'version_{last_version + 1}')
        os.makedirs(current_path, exist_ok=True)
    except IndexError:
        current_path = os.path.join(path, 'version_0')
        os.makedirs(current_path, exist_ok=True)
    return current_path


def get_optimizer(args, model):
    """Build the optimizer from the input arguments."""
    if args.optimizer == 'Adam':
        optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                         betas=args.betas, eps=args.eps)
    elif args.optimizer == 'AdamW':
        optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                          betas=args.betas, eps=args.eps)
    elif args.optimizer == 'Adan':
        if len(args.betas) == 2:
            raise ValueError('Adan optimizer requires betas with a shape like (0.9, 0.98, 0.99)')
        optimizer = Adan(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                         betas=args.betas, eps=args.eps)
    elif args.optimizer == 'AdanBelief':
        if len(args.betas) == 2:
            raise ValueError('AdanBelief optimizer requires betas with a shape like (0.9, 0.98, 0.99)')
        optimizer = AdanBelief(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                               betas=args.betas, eps=args.eps)
    elif args.optimizer == 'SGD':
        optimizer = SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                        momentum=0.937)
    else:
        raise ValueError('Unsupported optimizer, please check the optimizer name.')
    return optimizer


def get_scheduler(optimizer, args):
    """Build the learning-rate scheduler from the input arguments."""
    pass


def get_transform(args):
    """Build the data preprocessing pipeline from the input arguments."""
    if args.data_preprocess == 1:
        transforms = transform.Compose([
            transform.Resize(224, mode=Image.BICUBIC),
            transform.CenterCrop(224), lambda image: image.convert("RGB"),
            transform.ImageNormalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                     std=(0.26862954, 0.26130258, 0.27577711))
        ])
        return transforms
    elif args.data_preprocess == 2:
        transforms = transform.Compose([
            transform.Resize(224, mode=Image.BICUBIC),
            transform.CenterCrop(224), lambda image: image.convert("RGB"),
            transform.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.4, hue=0.1),
            transform.RandomRotation(10),
            transform.RandomHorizontalFlip(),
            transform.ImageNormalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                     std=(0.26862954, 0.26130258, 0.27577711))])
        return transforms


def compute_loss(logits_image, logits_text):
    """Compute the loss that builds the semantic relation between text and images (semantic alignment)."""
    ground_truth = jt.arange(len(logits_image), dtype=jt.int32)
    loss = (jt.nn.cross_entropy_loss(logits_image, ground_truth) +
            jt.nn.cross_entropy_loss(logits_text, ground_truth)) / 2
    # loss = (jt.nn.cross_entropy_loss(logits_image, ground_truth) +
    #         jt.nn.smooth_l1_loss(logits_text, ground_truth)) / 2
    return loss
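
Taken together, these helpers cover optimizer construction, checkpoint paths, early stopping, and the CLIP-style contrastive loss. The following is a minimal sketch of how they might be wired into a training loop; `args`, `model`, and `train_loader` are hypothetical and assumed to come from the surrounding training script, and `model(images, texts)` is assumed to return CLIP-style paired image/text logits.

# Hypothetical training loop wiring the helpers together; `args`, `model`,
# and `train_loader` are illustrative assumptions, not part of helper.py.
import os
from helper import EarlyStop, get_optimizer, get_save_path, compute_loss

def train(args, model, train_loader):
    optimizer = get_optimizer(args, model)
    early_stop = EarlyStop(patience=7, delta=0.0001, patience_up=20)
    save_path = get_save_path(args.log_dir, args.optimizer)  # args.log_dir is illustrative
    for epoch in range(args.epochs):
        epoch_loss, num_batches = 0.0, 0
        for images, texts in train_loader:
            logits_image, logits_text = model(images, texts)
            loss = compute_loss(logits_image, logits_text)
            optimizer.step(loss)  # Jittor optimizers run backward inside step(loss)
            epoch_loss += loss.item()
            num_batches += 1
        # Stop when the mean epoch loss stagnates or keeps rising.
        if early_stop(epoch_loss / num_batches):
            break
    model.save(os.path.join(save_path, 'model.pkl'))  # jittor.Module.save
    return model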

First, all image layers of the ViT-B/32 CLIP model officially pretrained by OpenAI are frozen; the model is then trained with the AdanBelief optimizer. This optimizer is a fusion of the Adan and AdaBelief optimizers: the "Belief" idea is incorporated into Adan to improve the generalization of the trained model.
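
A minimal sketch of that setup follows, assuming a Jittor CLIP port whose image encoder is exposed as `model.visual` (as in OpenAI's reference layout); `load_clip`, the checkpoint name, and all hyper-parameter values are illustrative assumptions, not part of this repository.

# Freeze the image branch of a pretrained ViT-B/32 CLIP model, then build
# the AdanBelief optimizer for the remaining (text-side) parameters.
# `load_clip` and the hyper-parameter values below are illustrative.
from types import SimpleNamespace
from helper import get_optimizer

model = load_clip('ViT-B-32.pkl')        # hypothetical loader for the Jittor CLIP port
for param in model.visual.parameters():  # `visual` is the image encoder in OpenAI's CLIP layout
    param.stop_grad()                    # Jittor: exclude this Var from gradient updates

args = SimpleNamespace(optimizer='AdanBelief', lr=1e-5, weight_decay=0.02,
                       betas=(0.9, 0.98, 0.99),  # AdanBelief expects three betas
                       eps=1e-8)
optimizer = get_optimizer(args, model)   # only the unfrozen parameters will be updated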
