| @@ -0,0 +1,192 @@ | |||||
| import os | |||||
| import random | |||||
| import time | |||||
| import jittor as jt | |||||
| from PIL import Image | |||||
| from jittor.dataset import Dataset | |||||
| import jittor.transform as T | |||||
| import jclip as clip | |||||
| from path_config import DATASET_ROOT_DIR | |||||
| jt.flags.use_cuda = 1 | |||||
| SEED = 123 | |||||
| # 训练数据集 | |||||
| class CLIPDataset(Dataset): | |||||
| """训练数据集""" | |||||
| def __init__(self, | |||||
| img_list, | |||||
| txt_list, | |||||
| label_list, | |||||
| transform, | |||||
| batch_size, | |||||
| num_workers, | |||||
| shuffle, | |||||
| return_desc | |||||
| ): | |||||
| super(CLIPDataset, self).__init__() | |||||
| self.image_path = img_list | |||||
| self.label_list = label_list | |||||
| self.texts = clip.tokenize(txt_list).numpy() | |||||
| self.transform = transform | |||||
| self.return_desc = return_desc | |||||
| self.set_attrs(batch_size=batch_size, | |||||
| num_workers=num_workers, | |||||
| shuffle=shuffle) | |||||
| def __len__(self): | |||||
| return len(self.label_list) | |||||
| def __getitem__(self, idx): | |||||
| image = self.transform(Image.open(self.image_path[idx])) | |||||
| text = self.texts[idx] | |||||
| label = self.label_list[idx] | |||||
| if self.return_desc: | |||||
| return image, label, text | |||||
| return image, label | |||||
def get_map_dict():
    """Map each numeric class id to a human-readable class name.

    Reads ``classes.txt`` (lines of the form ``<raw_name> <id>``). The raw
    name carries a dataset-specific prefix whose length depends on the id
    range (animal / caltech / food / dog); the prefix is stripped and
    underscores are turned into spaces.
    """
    with open(os.path.join(DATASET_ROOT_DIR, 'classes.txt'), 'r') as fh:
        entries = fh.read().splitlines()

    # (exclusive upper bound of the id range, prefix length to strip)
    prefix_rules = ((52, 7), (143, 12), (244, 9), (float('inf'), 8))

    mapping = {}
    for entry in entries:
        parts = entry.split(' ')
        raw_name, idx = parts[0], int(parts[1])
        for bound, strip in prefix_rules:
            if idx < bound:
                mapping[idx] = raw_name[strip:].lower().replace('_', ' ')
                break
    return mapping
def get_description(id_list, version, custom_desc=None):
    """Return one text prompt per class id.

    Args:
        id_list: iterable of integer class ids.
        version: prompt template — 1: plain "a photo of X";
            2: adds a per-domain suffix (animal / food / dog);
            3: returns ``custom_desc`` verbatim.
        custom_desc: pre-built prompt list, required when version == 3.

    Raises:
        AssertionError: version is 3 but custom_desc is None.
        ValueError: version is not 1, 2, or 3.
    """
    # FIX: handle version 3 and invalid versions before touching
    # get_map_dict(), which reads classes.txt from disk — the original
    # performed that I/O even when the result was never used.
    if version == 3:
        # FIX: assert message typo ("custon_desc" -> "custom_desc").
        assert custom_desc is not None, 'custom_desc must not be None'
        return custom_desc
    if version not in (1, 2):
        raise ValueError("version must be 1 or 2 or 3")

    id2cls = get_map_dict()
    if version == 1:
        return ['a photo of {}'.format(id2cls[idx]) for idx in id_list]

    # version == 2: append a coarse domain hint based on the id range.
    desc = []
    for idx in id_list:
        if idx < 52:
            desc.append('a photo of {}, a type of animal'.format(id2cls[idx]))
        elif idx < 143:
            desc.append('a photo of {}'.format(id2cls[idx]))
        elif idx < 244:
            desc.append('a photo of {}, a type of food'.format(id2cls[idx]))
        else:
            desc.append('a photo of {}, a type of dog'.format(id2cls[idx]))
    return desc
def split_data(seed, version, custom_desc=None):
    """Few-shot split: 4 images per class for training, the rest for test.

    Test samples are bucketed by domain in the order
    [animal, caltech, food, dogs], determined by class-id range.

    Args:
        seed: shuffle seed. NOTE: this seeds the *global* random module.
        version: prompt template version forwarded to get_description.
        custom_desc: forwarded to get_description (used when version == 3).

    Returns:
        dict with keys 'train_paths', 'train_labels', 'train_text_desc',
        'test_paths', 'test_labels' (the last two are lists of 4 buckets).
    """
    with open(os.path.join(DATASET_ROOT_DIR, 'train.txt'), 'r') as fh:
        entries = fh.read().splitlines()

    random.seed(seed)
    random.shuffle(entries)

    train_paths, train_labels = [], []
    # Four domain buckets: animal, caltech, food, dogs.
    test_paths = [[], [], [], []]
    test_labels = [[], [], [], []]
    taken = {}  # per-class count of samples already placed in the train set

    for entry in entries:
        parts = entry.split(' ')
        rel_path, label = parts[0], int(parts[1])
        count = taken.get(label, 0)
        if count < 4:
            train_paths.append(f'{DATASET_ROOT_DIR}/{rel_path}')
            train_labels.append(label)
            taken[label] = count + 1
        else:
            # Bucket index = number of range boundaries at or below label:
            # <52 -> 0, <143 -> 1, <244 -> 2, else 3.
            bucket = sum(label >= bound for bound in (52, 143, 244))
            test_paths[bucket].append(f'{DATASET_ROOT_DIR}/{rel_path}')
            test_labels[bucket].append(label)

    return {
        'train_paths': train_paths,
        'train_labels': train_labels,
        'train_text_desc': get_description(train_labels, version, custom_desc),
        'test_paths': test_paths,
        'test_labels': test_labels
    }
def get_dataloader(transform, batch_size, num_workers, shuffle=True, version=2,
                   custom_desc=None):
    """Build the training CLIPDataset over the few-shot split.

    Args:
        transform: image preprocessing callable.
        batch_size / num_workers / shuffle: dataloader settings.
        version: prompt template version (see get_description).
        custom_desc: custom prompt list for version == 3. FIX: the
            original never forwarded custom descriptions, so version 3
            always failed its assertion through this entry point; the new
            keyword defaults to None and is fully backward-compatible.

    Returns:
        CLIPDataset yielding (image, label, tokenized_text) batches.
    """
    data = split_data(SEED, version=version, custom_desc=custom_desc)
    return CLIPDataset(data['train_paths'],
                       data['train_text_desc'],
                       data['train_labels'],
                       transform,
                       batch_size=batch_size,
                       num_workers=num_workers,
                       shuffle=shuffle,
                       return_desc=True)
| if __name__ == "__main__": | |||||
| SEED = 123 | |||||
| # load data only to find data read bottleneck | |||||
| print('loading model...') | |||||
| model, preprocess = clip.load("pretrained/ViT-B-32.pkl") | |||||
| print('model loaded!') | |||||
| data = split_data(SEED, version=2) | |||||
| # preview descriptions | |||||
| for item in data['train_text_desc']: | |||||
| print(item) | |||||
| train_loader = CLIPDataset(data['train_paths'], | |||||
| data['train_text_desc'], | |||||
| data['train_labels'], | |||||
| preprocess, | |||||
| batch_size=128, | |||||
| num_workers=8, | |||||
| shuffle=True, | |||||
| return_desc=True) | |||||
| times = [] | |||||
| for epoch in range(10): | |||||
| s = time.time() | |||||
| for i, (img, text, label) in enumerate(train_loader): | |||||
| print(img.shape, text.shape, label.shape) | |||||
| e = time.time() | |||||
| times.append(e - s) | |||||
| print(f'cost time: {e - s}') | |||||
| print(f'average cost time: {sum(times) / len(times)}') | |||||