
dataset.py

import os
import random
import time

import jittor as jt
from PIL import Image
from jittor.dataset import Dataset
import jittor.transform as T
import jclip as clip

from path_config import DATASET_ROOT_DIR

jt.flags.use_cuda = 1
SEED = 123


class CLIPDataset(Dataset):
    """Training dataset."""

    def __init__(self,
                 img_list,
                 txt_list,
                 label_list,
                 transform,
                 batch_size,
                 num_workers,
                 shuffle,
                 return_desc):
        super(CLIPDataset, self).__init__()
        self.image_path = img_list
        self.label_list = label_list
        self.texts = clip.tokenize(txt_list).numpy()
        self.transform = transform
        self.return_desc = return_desc
        self.set_attrs(batch_size=batch_size,
                       num_workers=num_workers,
                       shuffle=shuffle)

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, idx):
        image = self.transform(Image.open(self.image_path[idx]))
        text = self.texts[idx]
        label = self.label_list[idx]
        if self.return_desc:
            return image, label, text
        return image, label


def get_map_dict():
    """Mapping from class id (integer) to class-name text."""
    with open(os.path.join(DATASET_ROOT_DIR, 'classes.txt'), 'r') as file:
        classes = file.read().splitlines()
    map_dict = {}
    id_list = [int(line.split(' ')[1]) for line in classes]
    classnames = [line.split(' ')[0] for line in classes]
    for idx, name in zip(id_list, classnames):
        # Strip the fixed-length dataset prefix, lower-case, and replace '_' with spaces.
        if idx < 52:
            map_dict[idx] = ' '.join(name[7:].lower().split('_'))
        elif idx < 143:
            map_dict[idx] = ' '.join(name[12:].lower().split('_'))
        elif idx < 244:
            map_dict[idx] = ' '.join(name[9:].lower().split('_'))
        else:
            map_dict[idx] = ' '.join(name[8:].lower().split('_'))
    return map_dict


def get_description(id_list, version, custom_desc=None):
    """Build the text description for each class id."""
    id2cls = get_map_dict()
    if version == 1:
        return ['a photo of {}'.format(id2cls[idx]) for idx in id_list]
    elif version == 2:
        desc = []
        for idx in id_list:
            if idx < 52:
                desc.append('a photo of {}, a type of animal'.format(id2cls[idx]))
            elif idx < 143:
                desc.append('a photo of {}'.format(id2cls[idx]))
            elif idx < 244:
                desc.append('a photo of {}, a type of food'.format(id2cls[idx]))
            else:
                desc.append('a photo of {}, a type of dog'.format(id2cls[idx]))
        return desc
    elif version == 3:
        assert custom_desc is not None, 'custom_desc must not be None'
        return custom_desc
    raise ValueError("version must be 1, 2, or 3")


def split_data(seed, version, custom_desc=None):
    """Split the dataset: 4 images per class for training, the rest for validation."""
    # Load the training data
    with open(os.path.join(DATASET_ROOT_DIR, 'train.txt'), 'r') as file:
        img_label_pairs = file.read().splitlines()
    random.seed(seed)
    random.shuffle(img_label_pairs)
    total_paths = [l.split(' ')[0] for l in img_label_pairs]
    total_labels = [int(l.split(' ')[1]) for l in img_label_pairs]
    cnt = {}
    train_paths = []
    train_labels = []
    # animal, caltech, food, dogs
    test_paths = [[] for _ in range(4)]
    test_labels = [[] for _ in range(4)]
    for path, label in zip(total_paths, total_labels):
        if label not in cnt:
            cnt[label] = 0
        if cnt[label] < 4:
            train_paths.append(f'{DATASET_ROOT_DIR}/{path}')
            train_labels.append(label)
            cnt[label] += 1
        else:
            if label < 52:
                index = 0
            elif label < 143:
                index = 1
            elif label < 244:
                index = 2
            else:
                index = 3
            test_paths[index].append(f'{DATASET_ROOT_DIR}/{path}')
            test_labels[index].append(label)
    train_text_desc = get_description(train_labels, version, custom_desc)
    return {
        'train_paths': train_paths,
        'train_labels': train_labels,
        'train_text_desc': train_text_desc,
        'test_paths': test_paths,
        'test_labels': test_labels
    }


def get_dataloader(transform, batch_size, num_workers, shuffle=True, version=2):
    data = split_data(SEED, version=version)
    return CLIPDataset(data['train_paths'],
                       data['train_text_desc'],
                       data['train_labels'],
                       transform,
                       batch_size=batch_size,
                       num_workers=num_workers,
                       shuffle=shuffle,
                       return_desc=True)


if __name__ == "__main__":
    SEED = 123
    # Load the data only, to locate any data-reading bottleneck.
    print('loading model...')
    model, preprocess = clip.load("pretrained/ViT-B-32.pkl")
    print('model loaded!')
    data = split_data(SEED, version=2)
    # preview descriptions
    for item in data['train_text_desc']:
        print(item)
    train_loader = CLIPDataset(data['train_paths'],
                               data['train_text_desc'],
                               data['train_labels'],
                               preprocess,
                               batch_size=128,
                               num_workers=8,
                               shuffle=True,
                               return_desc=True)
    times = []
    for epoch in range(10):
        s = time.time()
        # CLIPDataset.__getitem__ returns (image, label, text); unpack in that order.
        for i, (img, label, text) in enumerate(train_loader):
            print(img.shape, label.shape, text.shape)
        e = time.time()
        times.append(e - s)
        print(f'cost time: {e - s}')
    print(f'average cost time: {sum(times) / len(times)}')
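For context, the hard-coded offsets in get_map_dict (7, 12, 9, 8) and the id boundaries (52, 143, 244) line up with the four sub-datasets noted in split_data (animal, caltech, food, dogs): each class name in classes.txt is assumed to carry a fixed-length dataset prefix that is stripped before the name is lower-cased. A minimal sketch of that parsing, using invented class names (the prefixes Animal_ and Thu-dog_ are assumptions consistent with the offsets, not values copied from classes.txt):

# Illustration only: these lines are invented; classes.txt is assumed to hold
# '<PrefixedClassName> <id>' pairs like the ones get_map_dict() parses.
sample_lines = [
    'Animal_Brown_Bear 3',        # id < 52   -> strip a 7-character prefix
    'Thu-dog_Border_Collie 250',  # id >= 244 -> strip an 8-character prefix
]
for line in sample_lines:
    name, idx = line.split(' ')[0], int(line.split(' ')[1])
    offset = 7 if idx < 52 else 12 if idx < 143 else 9 if idx < 244 else 8
    print(idx, ' '.join(name[offset:].lower().split('_')))
# 3 brown bear
# 250 border collie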

Freeze all image layers of the ViT-B/32 CLIP model and train it with the Adan optimizer for 100 epochs, saving a checkpoint every 5 epochs. After training, run test_clip.py to evaluate the saved checkpoints on the test split with custom prompts, and pick the checkpoint and prompt with the best test accuracy. Finally, run predict.py with the "min_loss.pth" model and submit the result to the official evaluation system; the top-1 accuracy is 0.6788.
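A minimal training sketch matching that description, under several assumptions not confirmed by this repository: jclip is taken to mirror the OpenAI CLIP layout (the image encoder lives in model.visual, and model(images, texts) returns logits_per_image and logits_per_text), jt.optim.Adam stands in for Adan (swap in an Adan implementation to reproduce the original setup), and the learning rate and checkpoint paths are placeholders.

import os
import jittor as jt
from jittor import nn
import jclip as clip
from dataset import get_dataloader

jt.flags.use_cuda = 1

model, preprocess = clip.load("pretrained/ViT-B-32.pkl")

# Freeze the image tower: only parameters outside model.visual reach the optimizer.
frozen = set(id(p) for p in model.visual.parameters())
trainable = [p for p in model.parameters() if id(p) not in frozen]
optimizer = jt.optim.Adam(trainable, lr=1e-5)  # placeholder; the original setup uses Adan

train_loader = get_dataloader(preprocess, batch_size=128, num_workers=8)
os.makedirs('checkpoints', exist_ok=True)
best_loss = float('inf')

for epoch in range(100):
    epoch_loss = 0.0
    # CLIPDataset yields (image, label, text); labels are unused here because the
    # contrastive loss only needs matched image/text pairs within the batch.
    for images, labels, texts in train_loader:
        logits_per_image, logits_per_text = model(images, texts)
        targets = jt.arange(images.shape[0])  # matching pairs share an index
        loss = (nn.cross_entropy_loss(logits_per_image, targets) +
                nn.cross_entropy_loss(logits_per_text, targets)) / 2
        optimizer.step(loss)  # Jittor optimizers run backward and update in one call
        epoch_loss += loss.item()
    if epoch_loss < best_loss:  # keep the lowest-loss weights ("min_loss.pth")
        best_loss = epoch_loss
        model.save('checkpoints/min_loss.pth')
    if (epoch + 1) % 5 == 0:  # periodic checkpoint every 5 epochs
        model.save(f'checkpoints/epoch_{epoch + 1}.pth')

The symmetric cross-entropy over image-to-text and text-to-image logits is the standard CLIP contrastive objective; whether the repository's actual training script uses exactly this loss is also an assumption.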
