import os
import random
import time

import jittor as jt
from PIL import Image
from jittor.dataset import Dataset
import jittor.transform as T

import jclip as clip
from path_config import DATASET_ROOT_DIR

jt.flags.use_cuda = 1

SEED = 123

# Number of label groups; the original code orders them as
# animal, caltech, food, dogs.
_NUM_CATEGORIES = 4

# Number of leading characters dropped from a raw class name, per label group.
# NOTE(review): presumably the length of a per-dataset prefix in classes.txt
# -- confirm against the data file.
_NAME_PREFIX_LENGTHS = (7, 12, 9, 8)

# Prompt templates used by description version 2, per label group.
_V2_TEMPLATES = (
    'a photo of {}, a type of animal',
    'a photo of {}',
    'a photo of {}, a type of food',
    'a photo of {}, a type of dog',
)


def _category_index(label):
    """Return the label-group index (0..3) that a numeric class id falls in.

    Ids below 52 map to 0, below 143 to 1, below 244 to 2, everything else
    to 3. This is the single source of truth for the boundaries used when
    cleaning class names, building prompts and bucketing validation samples.
    """
    if label < 52:
        return 0
    if label < 143:
        return 1
    if label < 244:
        return 2
    return 3


# Training dataset
class CLIPDataset(Dataset):
    """Training dataset yielding transformed images, labels and token rows.

    Args:
        img_list: image file paths, indexed in parallel with ``label_list``.
        txt_list: one text description per sample; tokenized once up front
            with ``clip.tokenize`` and kept as the result of ``.numpy()``.
        label_list: class id per sample; its length defines ``len(dataset)``.
        transform: callable applied to the opened ``PIL.Image``.
        batch_size: forwarded to ``Dataset.set_attrs``.
        num_workers: forwarded to ``Dataset.set_attrs``.
        shuffle: forwarded to ``Dataset.set_attrs``.
        return_desc: when truthy, items are ``(image, label, text)``;
            otherwise ``(image, label)``.
    """

    def __init__(self,
                 img_list,
                 txt_list,
                 label_list,
                 transform,
                 batch_size,
                 num_workers,
                 shuffle,
                 return_desc):
        super().__init__()
        self.image_path = img_list
        self.label_list = label_list
        self.texts = clip.tokenize(txt_list).numpy()
        self.transform = transform
        self.return_desc = return_desc
        self.set_attrs(batch_size=batch_size,
                       num_workers=num_workers,
                       shuffle=shuffle)

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.label_list)

    def __getitem__(self, idx):
        """Return ``(image, label, text)`` or ``(image, label)`` for ``idx``.

        Note the order: the label comes before the tokenized text.
        """
        image = self.transform(Image.open(self.image_path[idx]))
        label = self.label_list[idx]
        if self.return_desc:
            return image, label, self.texts[idx]
        return image, label


def get_map_dict():
    """Build the mapping from numeric class id to readable class text.

    Reads ``classes.txt`` under ``DATASET_ROOT_DIR``; every non-blank line
    is expected to be ``<name> <id>`` separated by a single space. The name
    has a group-dependent number of leading characters removed, is
    lower-cased, and has underscores replaced by spaces.

    Returns:
        dict mapping ``int`` class id to the cleaned class name.
    """
    with open(os.path.join(DATASET_ROOT_DIR, 'classes.txt'), 'r') as file:
        classes = file.read().splitlines()

    map_dict = {}
    for line in classes:
        # Tolerate blank lines instead of failing on the index below.
        if not line.strip():
            continue
        fields = line.split(' ')
        name, idx = fields[0], int(fields[1])
        prefix_len = _NAME_PREFIX_LENGTHS[_category_index(idx)]
        map_dict[idx] = name[prefix_len:].lower().replace('_', ' ')
    return map_dict


def get_description(id_list, version, custom_desc=None):
    """Return the text description for each class id in ``id_list``.

    Args:
        id_list: iterable of numeric class ids.
        version: ``1`` for the plain ``'a photo of {}'`` prompt, ``2`` for
            the group-specific prompts, ``3`` to return ``custom_desc``
            unchanged.
        custom_desc: caller-supplied descriptions, required for version 3.

    Returns:
        A list of description strings (or ``custom_desc`` for version 3).

    Raises:
        AssertionError: if ``version`` is 3 and ``custom_desc`` is ``None``.
        ValueError: if ``version`` is not 1, 2 or 3.
    """
    if version == 3:
        # Custom descriptions need no class-name lookup, so classes.txt is
        # not read on this path.
        assert custom_desc is not None, 'custom_desc must not be None'
        return custom_desc
    if version not in (1, 2):
        raise ValueError("version must be 1 or 2 or 3")

    id2cls = get_map_dict()
    if version == 1:
        return ['a photo of {}'.format(id2cls[idx]) for idx in id_list]
    return [
        _V2_TEMPLATES[_category_index(idx)].format(id2cls[idx])
        for idx in id_list
    ]


def split_data(seed, version, custom_desc=None, num_train_per_class=4):
    """Split the dataset into a few-shot training set and validation sets.

    Reads ``train.txt`` under ``DATASET_ROOT_DIR`` (non-blank lines of the
    form ``<relative_path> <label>``), shuffles the lines with ``seed``, and
    takes the first ``num_train_per_class`` samples seen for each label as
    training data. Every other sample goes to validation, bucketed by label
    group.

    Args:
        seed: seed passed to ``random.seed``. This reseeds the module-level
            RNG, which is kept deliberately so later ``random`` calls behave
            as before.
        version: description version forwarded to ``get_description``.
        custom_desc: custom descriptions forwarded to ``get_description``.
        num_train_per_class: training samples kept per label (default 4).

    Returns:
        dict with keys ``train_paths``, ``train_labels``,
        ``train_text_desc``, ``test_paths`` and ``test_labels``. The two
        ``test_*`` values are lists of four lists, ordered
        animal, caltech, food, dogs.
    """
    # Load the training data
    with open(os.path.join(DATASET_ROOT_DIR, 'train.txt'), 'r') as file:
        img_label_pairs = [
            line for line in file.read().splitlines() if line.strip()
        ]

    random.seed(seed)
    random.shuffle(img_label_pairs)

    cnt = {}
    train_paths = []
    train_labels = []
    # animal, caltech, food, dogs
    test_paths = [[] for _ in range(_NUM_CATEGORIES)]
    test_labels = [[] for _ in range(_NUM_CATEGORIES)]

    for pair in img_label_pairs:
        fields = pair.split(' ')
        path, label = fields[0], int(fields[1])
        full_path = f'{DATASET_ROOT_DIR}/{path}'

        seen = cnt.get(label, 0)
        if seen < num_train_per_class:
            train_paths.append(full_path)
            train_labels.append(label)
            cnt[label] = seen + 1
        else:
            index = _category_index(label)
            test_paths[index].append(full_path)
            test_labels[index].append(label)

    train_text_desc = get_description(train_labels, version, custom_desc)

    return {
        'train_paths': train_paths,
        'train_labels': train_labels,
        'train_text_desc': train_text_desc,
        'test_paths': test_paths,
        'test_labels': test_labels
    }


def get_dataloader(transform, batch_size, num_workers, shuffle=True, version=2):
    """Build the training ``CLIPDataset`` from the fixed-seed data split.

    Args:
        transform: image transform applied to every sample.
        batch_size: batch size of the loader.
        num_workers: number of loader workers.
        shuffle: whether the loader shuffles samples.
        version: description version used for the text prompts.

    Returns:
        A ``CLIPDataset`` whose items are ``(image, label, text)``.
    """
    data = split_data(SEED, version=version)
    return CLIPDataset(data['train_paths'],
                       data['train_text_desc'],
                       data['train_labels'],
                       transform,
                       batch_size=batch_size,
                       num_workers=num_workers,
                       shuffle=shuffle,
                       return_desc=True)


def _main():
    """Time ten passes over the training loader to spot read bottlenecks."""
    # load data only to find data read bottleneck
    print('loading model...')
    _, preprocess = clip.load("pretrained/ViT-B-32.pkl")
    print('model loaded!')

    data = split_data(SEED, version=2)
    # preview descriptions
    for item in data['train_text_desc']:
        print(item)

    train_loader = CLIPDataset(data['train_paths'],
                               data['train_text_desc'],
                               data['train_labels'],
                               preprocess,
                               batch_size=128,
                               num_workers=8,
                               shuffle=True,
                               return_desc=True)
    times = []
    for _epoch in range(10):
        s = time.time()
        # CLIPDataset yields (image, label, text); unpack in that order so
        # each printed shape belongs to the name it is printed under.
        for img, label, text in train_loader:
            print(img.shape, text.shape, label.shape)
        e = time.time()
        times.append(e - s)
        print(f'cost time: {e - s}')
    print(f'average cost time: {sum(times) / len(times)}')


if __name__ == "__main__":
    _main()