import os
import random
import time

import jittor as jt
from PIL import Image
from jittor.dataset import Dataset
import jclip as clip
from path_config import DATASET_ROOT_DIR

jt.flags.use_cuda = 1
SEED = 123


class CLIPDataset(Dataset):
    """Training dataset: yields (image, label[, tokenized text]) per sample."""
    def __init__(self,
                 img_list,
                 txt_list,
                 label_list,
                 transform,
                 batch_size,
                 num_workers,
                 shuffle,
                 return_desc
                 ):
        super().__init__()
        self.image_path = img_list
        self.label_list = label_list
        # Tokenize all text descriptions once up front instead of per batch.
        self.texts = clip.tokenize(txt_list).numpy()
        self.transform = transform
        self.return_desc = return_desc
        self.set_attrs(batch_size=batch_size,
                       num_workers=num_workers,
                       shuffle=shuffle)
    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, idx):
        # convert('RGB') guards against grayscale/RGBA images in the data.
        image = self.transform(Image.open(self.image_path[idx]).convert('RGB'))
        text = self.texts[idx]
        label = self.label_list[idx]
        if self.return_desc:
            return image, label, text
        return image, label
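

# A minimal standalone-usage sketch (hypothetical helper, not used by the
# training pipeline), assuming jittor.transform provides Resize/CenterCrop/
# ToTensor. In practice the `preprocess` transform returned by clip.load
# (see __main__ below) is the safer choice, since it matches CLIP's expected
# normalization.
def _build_demo_dataset(paths, descs, labels):
    import jittor.transform as T
    transform = T.Compose([
        T.Resize(224),        # shorter side to 224
        T.CenterCrop(224),    # square crop for the ViT-B/32 input size
        T.ToTensor(),
    ])
    return CLIPDataset(paths, descs, labels, transform,
                       batch_size=32, num_workers=0,
                       shuffle=False, return_desc=True)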


def get_map_dict():
    """Map class ids (ints) to human-readable class names."""
    with open(os.path.join(DATASET_ROOT_DIR, 'classes.txt'), 'r') as file:
        classes = file.read().splitlines()
    map_dict = {}
    id_list = [int(line.split(' ')[1]) for line in classes]
    classnames = [line.split(' ')[0] for line in classes]
    for idx, name in zip(id_list, classnames):
        # Strip the per-dataset prefix, then replace underscores with spaces.
        if idx < 52:
            map_dict[idx] = ' '.join(name[7:].lower().split('_'))
        elif idx < 143:
            map_dict[idx] = ' '.join(name[12:].lower().split('_'))
        elif idx < 244:
            map_dict[idx] = ' '.join(name[9:].lower().split('_'))
        else:
            map_dict[idx] = ' '.join(name[8:].lower().split('_'))
    return map_dict
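

# Note on the slicing above (an inference from the stripped prefix lengths
# 7/12/9/8, not guaranteed by this file): classes.txt entries appear to be
#     Animal_<name> <id>        (ids 0-51)
#     Caltech-101_<name> <id>   (ids 52-142)
#     Food-101_<name> <id>      (ids 143-243)
#     Thu-dog_<name> <id>       (ids 244+)
# so a hypothetical line 'Food-101_apple_pie 150' would map 150 -> 'apple pie'.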


def get_description(id_list, version, custom_desc=None):
    """Build a text description (prompt) for each class id."""
    id2cls = get_map_dict()
    if version == 1:
        return ['a photo of {}'.format(id2cls[idx]) for idx in id_list]
    elif version == 2:
        desc = []
        for idx in id_list:
            if idx < 52:
                desc.append('a photo of {}, a type of animal'.format(id2cls[idx]))
            elif idx < 143:
                desc.append('a photo of {}'.format(id2cls[idx]))
            elif idx < 244:
                desc.append('a photo of {}, a type of food'.format(id2cls[idx]))
            else:
                desc.append('a photo of {}, a type of dog'.format(id2cls[idx]))
        return desc
    elif version == 3:
        assert custom_desc is not None, 'custom_desc must not be None'
        return custom_desc
    raise ValueError("version must be 1, 2, or 3")
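

# Example for get_description (class names are hypothetical, taken from
# whatever classes.txt contains): version=2 adds a domain hint to each prompt:
#     get_description([0, 250], version=2)
#     -> ['a photo of <animal name>, a type of animal',
#         'a photo of <dog breed>, a type of dog']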


def split_data(seed, version, custom_desc=None):
    """Split the data: 4 images per class for training, the rest for validation."""
    # Load the training data
    with open(os.path.join(DATASET_ROOT_DIR, 'train.txt'), 'r') as file:
        img_label_pairs = file.read().splitlines()
    random.seed(seed)
    random.shuffle(img_label_pairs)
    total_paths = [l.split(' ')[0] for l in img_label_pairs]
    total_labels = [int(l.split(' ')[1]) for l in img_label_pairs]
    cnt = {}
    train_paths = []
    train_labels = []
    # One validation bucket per sub-dataset: animal, caltech, food, dogs
    test_paths = [[] for _ in range(4)]
    test_labels = [[] for _ in range(4)]
    for path, label in zip(total_paths, total_labels):
        if label not in cnt:
            cnt[label] = 0
        if cnt[label] < 4:
            train_paths.append(f'{DATASET_ROOT_DIR}/{path}')
            train_labels.append(label)
            cnt[label] += 1
        else:
            if label < 52:
                index = 0
            elif label < 143:
                index = 1
            elif label < 244:
                index = 2
            else:
                index = 3
            test_paths[index].append(f'{DATASET_ROOT_DIR}/{path}')
            test_labels[index].append(label)
    train_text_desc = get_description(train_labels, version, custom_desc)
    return {
        'train_paths': train_paths,
        'train_labels': train_labels,
        'train_text_desc': train_text_desc,
        'test_paths': test_paths,
        'test_labels': test_labels
    }
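

# A quick sanity check on the split (hypothetical helper, safe to delete):
# the three train lists must stay aligned, and each sub-dataset gets its own
# validation bucket.
def _preview_split(seed=SEED, version=2):
    data = split_data(seed, version=version)
    assert len(data['train_paths']) == len(data['train_labels']) \
        == len(data['train_text_desc'])
    for name, paths in zip(['animal', 'caltech', 'food', 'dogs'],
                           data['test_paths']):
        print(f'{name}: {len(paths)} validation images')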


def get_dataloader(transform, batch_size, num_workers, shuffle=True, version=2,
                   custom_desc=None):
    data = split_data(SEED, version=version, custom_desc=custom_desc)
    return CLIPDataset(data['train_paths'],
                       data['train_text_desc'],
                       data['train_labels'],
                       transform,
                       batch_size=batch_size,
                       num_workers=num_workers,
                       shuffle=shuffle,
                       return_desc=True)
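

# Typical wiring (a sketch): pair get_dataloader with the preprocess transform
# returned by clip.load, as the smoke test in __main__ below does.
#     model, preprocess = clip.load("pretrained/ViT-B-32.pkl")
#     loader = get_dataloader(preprocess, batch_size=64, num_workers=4)
#     for img, label, text in loader:
#         ...  # img: (B, 3, 224, 224), label: (B,), text: (B, 77) token ids,
#              # assuming CLIP's usual 224px input and context length of 77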


if __name__ == "__main__":
    # Load the data only, to profile the data-reading bottleneck.
    print('loading model...')
    model, preprocess = clip.load("pretrained/ViT-B-32.pkl")
    print('model loaded!')
    data = split_data(SEED, version=2)
    # preview descriptions
    for item in data['train_text_desc']:
        print(item)
    train_loader = CLIPDataset(data['train_paths'],
                               data['train_text_desc'],
                               data['train_labels'],
                               preprocess,
                               batch_size=128,
                               num_workers=8,
                               shuffle=True,
                               return_desc=True)
    times = []
    for epoch in range(10):
        s = time.time()
        # __getitem__ returns (image, label, text), so unpack in that order.
        for i, (img, label, text) in enumerate(train_loader):
            print(img.shape, label.shape, text.shape)
        e = time.time()
        times.append(e - s)
        print(f'cost time: {e - s}')
    print(f'average cost time: {sum(times) / len(times)}')