
test_clip.py

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import functools
from typing import List

import jittor as jt
import numpy as np
from tqdm import tqdm
from PIL import Image
import jclip as clip
from tabulate import tabulate

from template import templatesv1, templatesv2, templatesv3
from path_config import DATASET_ROOT_DIR
from utils.logger import setup_logger
from utils.color_print import yellow_print

jt.flags.use_cuda = 1


@functools.lru_cache()
def get_valid_dataset():
    dataset_path = os.path.join(DATASET_ROOT_DIR, 'valid.lst')
    img_label_pairs = open(dataset_path, 'r').read().splitlines()
    image_list = [l.split(' ')[0] for l in img_label_pairs]
    id_list = [int(l.split(' ')[1]) for l in img_label_pairs]
    return image_list, id_list


def split_valid_dataset(image_list, id_list):
    valid_dataset = {
        'animal': {'image_list': [], 'id_list': []},
        'caltech': {'image_list': [], 'id_list': []},
        'dogs': {'image_list': [], 'id_list': []},
        'food': {'image_list': [], 'id_list': []}
    }
    for image, label in zip(image_list, id_list):
        if label < 52:
            valid_dataset['animal']['image_list'].append(image)
            valid_dataset['animal']['id_list'].append(label)
        elif label < 143:
            valid_dataset['caltech']['image_list'].append(image)
            valid_dataset['caltech']['id_list'].append(label)
        elif label < 244:
            valid_dataset['food']['image_list'].append(image)
            valid_dataset['food']['id_list'].append(label)
        else:
            valid_dataset['dogs']['image_list'].append(image)
            valid_dataset['dogs']['id_list'].append(label)
    return valid_dataset
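
# Note: the label ranges above imply the four sub-datasets are packed
# contiguously: 0-51 -> animal, 52-142 -> caltech, 143-243 -> food,
# 244 and up -> dogs (inferred from this split function).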


class TestSet(jt.dataset.Dataset):
    def __init__(self,
                 image_list,
                 id_list,
                 transform,
                 batch_size=256,
                 num_workers=8):
        super(TestSet, self).__init__()
        self.image_list = image_list
        self.id_list = id_list
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform
        self.set_attrs(
            batch_size=self.batch_size,
            total_len=len(self.image_list),
            num_workers=self.num_workers,
            buffer_size=1024 * 1024 * 1024
        )

    def __getitem__(self, idx):
        image_path = self.image_list[idx]
        label = self.id_list[idx]
        image = Image.open(f'{DATASET_ROOT_DIR}/{image_path}').convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        img = np.asarray(image)
        return img, label


@functools.lru_cache()
def get_classnames() -> List[str]:
    """Map numeric class ids to human-readable class names."""
    with open(os.path.join(DATASET_ROOT_DIR, 'classes.txt'), 'r') as file:
        classes = file.read().splitlines()
    map_dict = {}
    id_list = [int(line.split(' ')[1]) for line in classes]
    classnames = [line.split(' ')[0] for line in classes]
    for idx, name in zip(id_list, classnames):
        if idx < 52:
            map_dict[idx] = ' '.join(name[7:].lower().split('_'))
        elif idx < 143:
            map_dict[idx] = ' '.join(name[12:].lower().split('_'))
        elif idx < 244:
            map_dict[idx] = ' '.join(name[9:].lower().split('_'))
        else:
            map_dict[idx] = ' '.join(name[8:].lower().split('_'))
    return list(map_dict.values())
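
# The slice offsets above (7/12/9/8 characters) strip an assumed per-dataset
# prefix from each raw class name before lowercasing and replacing underscores
# with spaces; e.g. a 7-character prefix such as 'Animal_' for the first block.
# The concrete prefixes are inferred, not documented in this file.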


def zeroshot_classifier(model: clip.model.CLIP,
                        classnames: List[str] = get_classnames(),
                        weights_version: int = 1) -> jt.Var:
    """
    Build a zero-shot classifier from the CLIP text encoder.

    Args:
    - model (clip.model.CLIP): the loaded CLIP model instance.
    - classnames (list): names of all classes.
    - weights_version (int): which prompt-template set to use (1, 2, or 3).

    Returns:
    - jt.Var: zero-shot weight tensor of shape (embedding_size, num_classes).

    Notes:
    - Assumes the model is already on the GPU and can handle tokenization
      and text embedding of the prompt strings.
    - The output holds the normalized mean embedding of each class.
    """
    if weights_version == 1:
        templates = templatesv1
    elif weights_version == 2:
        templates = templatesv2
    elif weights_version == 3:
        templates = templatesv3
    else:
        raise ValueError("weights_version must be 1, 2, or 3")
    model.eval()
    with jt.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames, desc='Extracting class embeddings'):
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts)  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = jt.stack(zeroshot_weights, dim=1)
    return zeroshot_weights


def evaluate(model, dataloader, zeroshot_weights, name):
    model.eval()
    correct = 0
    total_count = 0
    with jt.no_grad():
        print(f"\nTesting on {name}")
        bar = tqdm(dataloader)
        for i, batch in enumerate(bar):
            images, targets = batch
            total_count += len(images)
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
            preds = jt.argmax(logits, dim=1)[0]  # jt.argmax returns (indices, values)
            correct += jt.equal(preds, targets).sum().item()
            bar.set_description(f'{name} : {correct}/{total_count} Acc: {correct / total_count:.4f}')
            bar.update(1)
        bar.close()
    return correct


def main(best_model_path, weights_version):
    model, transform = clip.load("/jittor-competiiton/pretrained/ViT-B-32.pkl")
    if best_model_path is not None:
        model.load_state_dict(jt.load(best_model_path))
        yellow_print('Loaded weights from {}'.format(best_model_path))
    else:
        yellow_print('Weights not loaded! Using pretrained weights.')
    zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version)
    image_list, id_list = get_valid_dataset()
    valid_dataset = split_valid_dataset(image_list, id_list)
    batch_sizes = [256, 256, 512, 512]
    num_workers = [4, 4, 8, 8]
    animal = TestSet(valid_dataset['animal']['image_list'],
                     valid_dataset['animal']['id_list'],
                     transform=transform,
                     batch_size=batch_sizes[0],
                     num_workers=num_workers[0])
    caltech = TestSet(valid_dataset['caltech']['image_list'],
                      valid_dataset['caltech']['id_list'],
                      transform=transform,
                      batch_size=batch_sizes[1],
                      num_workers=num_workers[1])
    food = TestSet(valid_dataset['food']['image_list'],
                   valid_dataset['food']['id_list'],
                   transform=transform,
                   batch_size=batch_sizes[2],
                   num_workers=num_workers[2])
    dogs = TestSet(valid_dataset['dogs']['image_list'],
                   valid_dataset['dogs']['id_list'],
                   transform=transform,
                   batch_size=batch_sizes[3],
                   num_workers=num_workers[3])
    animal_total = len(valid_dataset['animal']['image_list'])
    caltech_total = len(valid_dataset['caltech']['image_list'])
    food_total = len(valid_dataset['food']['image_list'])
    dogs_total = len(valid_dataset['dogs']['image_list'])
    total = animal_total + caltech_total + food_total + dogs_total
    animal_correct = evaluate(model, animal, zeroshot_weights, 'animal')
    caltech_correct = evaluate(model, caltech, zeroshot_weights, 'caltech')
    food_correct = evaluate(model, food, zeroshot_weights, 'food')
    dogs_correct = evaluate(model, dogs, zeroshot_weights, 'dogs')
    correct_total = animal_correct + caltech_correct + food_correct + dogs_correct
    metrics = [animal_correct / animal_total,
               caltech_correct / caltech_total,
               food_correct / food_total,
               dogs_correct / dogs_total,
               correct_total / total]
    print('Average Acc: ', metrics[-1])
    return [round(acc, 4) for acc in metrics]


if __name__ == '__main__':
    # Directory containing the checkpoints to test
    model_dir = '/jittor-competiiton/ckptFE/07-30/version_1'
    logger = setup_logger(model_dir, type_='test')
    # Checkpoint names to test, e.g. 'min_loss' or epoch numbers such as 20, 30
    model_name = ['min_loss', 20, 50, 70, 90, 100]
    # Prompt-template versions to test:
    # 1. basic: a photo of
    # 2. custom:
    # 3. from imagenet
    test_weights_version = [1, 2, 3]
    table_header = ['Epoch', 'Prompt', 'Animal', 'Caltech', 'Food', 'Dogs', 'Total']
    table_data = []
    prompt = {
        1: 'basic',
        2: 'custom',
        3: 'imagenet'
    }
    for epoch in model_name:
        if isinstance(epoch, str):
            model_path = os.path.join(model_dir, 'min_loss.pth')
        elif isinstance(epoch, int):
            model_path = os.path.join(model_dir, f'epoch_{epoch}.pth')
        print(f'Testing with {model_path}')
        logger.info(f'Testing with {model_path}')
        for weights_version in test_weights_version:
            metrics = main(model_path, weights_version)
            metrics.insert(0, prompt[weights_version])
            metrics.insert(0, epoch)
            table_data.append(metrics)
    print(tabulate(table_data, headers=table_header, tablefmt='fancy_grid'))
    logger.info(tabulate(table_data, headers=table_header, tablefmt='fancy_grid'))
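
For reference, the script assumes two space-separated metadata files under DATASET_ROOT_DIR. The formats below are inferred from the parsing code above; the concrete names and ids are illustrative only:

    classes.txt   one "<Prefixed_Class_Name> <integer id>" pair per line
    valid.lst     one "<relative/image/path> <integer id>" pair per line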

Training and evaluation procedure: freeze all image-side layers of the ViT-B/32 CLIP model and train it with the Adan optimizer for 100 epochs, saving a checkpoint every 5 epochs. After training, run test_clip.py to evaluate the saved checkpoints on the test-set data with the custom prompts, pick the checkpoint and prompt version with the best test accuracy, then run predict.py with the "min_loss.pth" model and submit to the official system; the resulting top-1 accuracy is 0.6788.
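
A minimal sketch of that training setup, assuming the jclip port mirrors OpenAI CLIP (image tower exposed as model.visual) and that an Adan implementation is available locally; the adan module, learning rate, and loop body are placeholders, not the project's actual training code:

    import jittor as jt
    import jclip as clip
    from adan import Adan  # hypothetical local Adan implementation

    model, transform = clip.load('/jittor-competiiton/pretrained/ViT-B-32.pkl')

    # Freeze the entire image tower; only text-side parameters keep gradients.
    for p in model.visual.parameters():
        p.stop_grad()

    trainable = [p for p in model.parameters() if not p.is_stop_grad()]
    optimizer = Adan(trainable, lr=1e-5)  # lr is a placeholder value

    for epoch in range(1, 101):
        ...  # one pass over the training set, stepping the optimizer
        if epoch % 5 == 0:  # save every 5 epochs, matching the description
            jt.save(model.state_dict(), f'epoch_{epoch}.pth')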
