| @@ -0,0 +1,284 @@ | |||||
| import os | |||||
| os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |||||
| import functools | |||||
| from typing import List | |||||
| import jittor as jt | |||||
| import numpy as np | |||||
| from tqdm import tqdm | |||||
| from PIL import Image | |||||
| import jclip as clip | |||||
| from tabulate import tabulate | |||||
| from template import templatesv1, templatesv2, templatesv3 | |||||
| from path_config import DATASET_ROOT_DIR | |||||
| from utils.logger import setup_logger | |||||
| from utils.color_print import yellow_print | |||||
| jt.flags.use_cuda = 1 | |||||
@functools.lru_cache()
def get_valid_dataset():
    """Load the validation list and split it into image paths and class ids.

    Reads ``valid.lst`` under ``DATASET_ROOT_DIR``. Every non-blank line is
    split on single spaces; the first field is taken as the image path and
    the second field is parsed as the integer class id.

    The result is memoised by ``functools.lru_cache``, so the file is read at
    most once per process and every caller receives the *same* list objects
    (do not mutate them in place).

    Returns:
        tuple: ``(image_list, id_list)`` -- two parallel lists holding the
        image paths (``str``) and the class ids (``int``).

    Raises:
        OSError: if ``valid.lst`` cannot be opened.
        IndexError, ValueError: if a non-blank line has no second field or
            the second field is not an integer.
    """
    dataset_path = os.path.join(DATASET_ROOT_DIR, 'valid.lst')
    # Use a context manager so the file handle is closed deterministically
    # instead of lingering until garbage collection.
    with open(dataset_path, 'r') as file:
        img_label_pairs = file.read().splitlines()

    image_list = []
    id_list = []
    for line in img_label_pairs:
        if not line.strip():
            # Tolerate stray blank lines instead of failing on fields[1].
            continue
        fields = line.split(' ')
        image_list.append(fields[0])
        id_list.append(int(fields[1]))
    return image_list, id_list
def split_valid_dataset(image_list, id_list):
    """Partition the validation samples into the four sub-datasets by class id.

    Class ids are bucketed by fixed ranges:

    * ``label < 52``          -> ``'animal'``
    * ``52 <= label < 143``   -> ``'caltech'``
    * ``143 <= label < 244``  -> ``'food'``
    * ``label >= 244``        -> ``'dogs'``

    Args:
        image_list: Iterable of image identifiers, parallel to ``id_list``.
        id_list: Iterable of class ids, parallel to ``image_list``.

    Returns:
        dict: Maps each subset name to ``{'image_list': [...], 'id_list': [...]}``
        with samples kept in their input order. Pairs are formed with ``zip``,
        so extra trailing items of the longer input are ignored.
    """
    subset_names = ('animal', 'caltech', 'dogs', 'food')
    buckets = {
        subset: {'image_list': [], 'id_list': []}
        for subset in subset_names
    }

    # (exclusive upper bound on the class id, subset receiving the sample);
    # anything at or above the last bound falls through to 'dogs'.
    upper_bounds = ((52, 'animal'), (143, 'caltech'), (244, 'food'))

    for image, label in zip(image_list, id_list):
        target = 'dogs'
        for bound, subset in upper_bounds:
            if label < bound:
                target = subset
                break
        bucket = buckets[target]
        bucket['image_list'].append(image)
        bucket['id_list'].append(label)

    return buckets
class TestSet(jt.dataset.Dataset):
    """Jittor dataset that yields ``(image_array, label)`` pairs for evaluation.

    Images are loaded lazily in ``__getitem__`` from
    ``DATASET_ROOT_DIR/<image_path>``, converted to RGB, optionally passed
    through ``transform`` and returned as a NumPy array.

    Args:
        image_list: Sequence of image paths relative to ``DATASET_ROOT_DIR``.
        id_list: Sequence of labels, parallel to ``image_list``.
        transform: Callable applied to the RGB ``PIL.Image``; ``None``
            disables preprocessing.
        batch_size: Batch size forwarded to ``set_attrs``.
        num_workers: Worker count forwarded to ``set_attrs``.
    """

    def __init__(self,
                 image_list,
                 id_list,
                 transform,
                 batch_size=256,
                 num_workers=8):
        super(TestSet, self).__init__()
        self.image_list = image_list
        self.id_list = id_list
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform

        self.set_attrs(
            batch_size=self.batch_size,
            total_len=len(self.image_list),
            num_workers=self.num_workers,
            # 1 GiB; NOTE(review): presumably the loader's worker buffer size
            # in bytes -- confirm against the Jittor Dataset documentation.
            buffer_size=1024 * 1024 * 1024
        )

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is the NumPy view of the (optionally transformed) RGB image.
        """
        image_path = self.image_list[idx]
        label = self.id_list[idx]

        # convert() produces a new, fully loaded image, so the underlying file
        # can be closed right away instead of waiting for garbage collection.
        with Image.open(f'{DATASET_ROOT_DIR}/{image_path}') as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        img = np.asarray(image)
        return img, label
@functools.lru_cache()
def get_classnames() -> List[str]:
    """Return the human-readable class names, ordered by class id.

    Reads ``classes.txt`` under ``DATASET_ROOT_DIR``. Every non-blank line is
    split on single spaces: the first field is the raw class name and the
    second field is the integer class id. A fixed number of leading
    characters is dropped from the raw name depending on the id range
    (7 for ids below 52, 12 below 143, 9 below 244, otherwise 8), then the
    remainder is lower-cased and its underscores are replaced by spaces.
    NOTE(review): the dropped characters are presumably a per-dataset name
    prefix -- confirm against ``classes.txt``.

    The list is ordered by ascending class id rather than by line order,
    because callers use the position in the list as the class id. If an id
    occurs more than once, the last occurrence wins.

    The result is memoised by ``functools.lru_cache``; all callers share the
    same list object.

    Returns:
        List[str]: One cleaned class name per distinct class id.

    Raises:
        OSError: if ``classes.txt`` cannot be opened.
        IndexError, ValueError: if a non-blank line is malformed.
    """
    # (exclusive upper bound on the class id, leading characters to drop)
    prefix_rules = ((52, 7), (143, 12), (244, 9))
    default_prefix_len = 8

    with open(os.path.join(DATASET_ROOT_DIR, 'classes.txt'), 'r') as file:
        lines = file.read().splitlines()

    id_to_name = {}
    for line in lines:
        if not line.strip():
            # Tolerate stray blank lines instead of failing on fields[1].
            continue
        fields = line.split(' ')
        raw_name, idx = fields[0], int(fields[1])
        prefix_len = next(
            (length for bound, length in prefix_rules if idx < bound),
            default_prefix_len,
        )
        id_to_name[idx] = ' '.join(raw_name[prefix_len:].lower().split('_'))

    return [id_to_name[idx] for idx in sorted(id_to_name)]
def zeroshot_classifier(model: clip.model.CLIP,
                        classnames: List[str] = None,
                        weights_version: int = 1) -> jt.Var:
    """Build the zero-shot classifier weights from CLIP text embeddings.

    For every class name, each prompt template of the selected version is
    formatted with the name, tokenized and encoded by the text encoder. The
    per-prompt embeddings are L2-normalised, averaged, and the average is
    normalised again. The per-class vectors are stacked along dimension 1.

    Args:
        model (clip.model.CLIP): Loaded CLIP model; it is switched to eval
            mode by this function.
        classnames (list): Class names to embed. When ``None`` (the default)
            the names returned by ``get_classnames()`` are used. They are
            looked up at call time so importing this module does no file I/O.
        weights_version (int): Selects the prompt templates: 1 ->
            ``templatesv1``, 2 -> ``templatesv2``, 3 -> ``templatesv3``.

    Returns:
        jt.Var: Zero-shot weights of shape ``(embedding_size, num_classes)``.

    Raises:
        ValueError: if ``weights_version`` is not 1, 2 or 3.

    Note:
        NOTE(review): the original documentation states the model is assumed
        to already live on the GPU -- nothing in this function enforces it.
    """
    if classnames is None:
        classnames = get_classnames()

    if weights_version == 1:
        templates = templatesv1
    elif weights_version == 2:
        templates = templatesv2
    elif weights_version == 3:
        templates = templatesv3
    else:
        raise ValueError("weights_version must be 1, 2, or 3")

    model.eval()
    with jt.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames, desc='Extracting class embeddings'):
            # One prompt per template for this class.
            texts = [template.format(classname) for template in templates]
            texts = clip.tokenize(texts)
            class_embeddings = model.encode_text(texts)
            # Normalise each prompt embedding before averaging ...
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            # ... and re-normalise the averaged class vector.
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        # Classes become columns: (embedding_size, num_classes).
        zeroshot_weights = jt.stack(zeroshot_weights, dim=1)
    return zeroshot_weights
def evaluate(model, dataloader, zeroshot_weights, name):
    """Count the correct zero-shot predictions of ``model`` on ``dataloader``.

    Each batch of images is encoded, L2-normalised along dimension 1, and
    multiplied with ``zeroshot_weights`` (scaled by 100, then softmaxed). The
    arg-max class index is compared against the batch targets.

    Args:
        model: Model exposing ``eval()`` and ``encode_image(images)``.
        dataloader: Iterable yielding ``(images, targets)`` batches.
        zeroshot_weights: Class weight matrix; its columns are assumed to be
            indexed by class id, since predictions are compared to targets.
        name: Label used in the progress output.

    Returns:
        The number of correctly classified samples (not a ratio).
    """
    model.eval()
    correct = 0
    total_count = 0
    with jt.no_grad():
        print(f"\nTesting on {name}")
        # Iterating the tqdm wrapper already advances the bar once per batch,
        # so no manual update() is issued (that would double-count progress).
        # The context manager closes the bar even if a batch raises.
        with tqdm(dataloader) as bar:
            for images, targets in bar:
                total_count += len(images)
                image_features = model.encode_image(images)
                image_features = image_features / image_features.norm(dim=1, keepdim=True)
                probs = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
                # Element 0 of jt.argmax's result is used as the predicted
                # indices (NOTE(review): Jittor returns (indices, values)).
                preds = jt.argmax(probs, dim=1)[0]
                correct += jt.equal(preds, targets).sum().item()
                bar.set_description(f'{name} : {correct}/{total_count} Acc: {correct / total_count:.4f}')
    return correct
def main(best_model_path, weights_version):
    """Evaluate one checkpoint with one prompt version on all validation subsets.

    Loads the pretrained CLIP model, optionally overrides its weights with
    the checkpoint at ``best_model_path``, builds the zero-shot weights for
    ``weights_version`` and evaluates the 'animal', 'caltech', 'food' and
    'dogs' validation subsets in that order.

    Args:
        best_model_path: Path of the checkpoint to load with ``jt.load``;
            ``None`` keeps the pretrained weights.
        weights_version: Prompt template version forwarded to
            ``zeroshot_classifier``.

    Returns:
        list: Accuracies rounded to 4 decimals, in the order
        ``[animal, caltech, food, dogs, overall]``. ``overall`` is the total
        number of correct predictions divided by the total number of samples.
        A subset without samples is not evaluated and reports ``0.0``
        instead of raising ``ZeroDivisionError``.
    """
    model, transform = clip.load("/jittor-competiiton/pretrained/ViT-B-32.pkl")
    if best_model_path is not None:
        model.load_state_dict(jt.load(best_model_path))
        yellow_print('Loaded weights from {}'.format(best_model_path))
    else:
        yellow_print('weights not loaded! using pretrained weights')

    zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version)

    image_list, id_list = get_valid_dataset()
    valid_dataset = split_valid_dataset(image_list, id_list)

    # (subset name, batch size, worker count), listed in report order.
    subset_configs = (
        ('animal', 256, 4),
        ('caltech', 256, 4),
        ('food', 512, 8),
        ('dogs', 512, 8),
    )

    # Build every loader up front, then evaluate them one after another.
    loaders = [
        (subset,
         TestSet(valid_dataset[subset]['image_list'],
                 valid_dataset[subset]['id_list'],
                 transform=transform,
                 batch_size=batch_size,
                 num_workers=num_workers))
        for subset, batch_size, num_workers in subset_configs
    ]

    def _accuracy(correct, count):
        # Guard against empty subsets so one missing split cannot crash the run.
        return correct / count if count else 0.0

    metrics = []
    correct_total = 0
    total = 0
    for subset, loader in loaders:
        subset_total = len(valid_dataset[subset]['image_list'])
        if subset_total:
            subset_correct = evaluate(model, loader, zeroshot_weights, subset)
        else:
            subset_correct = 0
        metrics.append(_accuracy(subset_correct, subset_total))
        correct_total += subset_correct
        total += subset_total

    metrics.append(_accuracy(correct_total, total))
    print('Average Acc: ', metrics[-1])
    return [round(acc, 4) for acc in metrics]
if __name__ == '__main__':
    # Directory that holds the checkpoints to test.
    model_dir = '/jittor-competiiton/ckptFE/07-30/version_1'
    logger = setup_logger(model_dir, type_='test')

    # Checkpoints to test: a string is a file stem (e.g. 'min_loss' ->
    # 'min_loss.pth'), an int is an epoch number (20 -> 'epoch_20.pth').
    model_name = ['min_loss', 20, 50, 70, 90, 100]

    # Prompt versions to test:
    #   1. basic: "a photo of"
    #   2. custom
    #   3. from ImageNet
    test_weights_version = [1, 2, 3]
    table_header = ['Epoch', '提示词', 'Animal', 'Caltech', 'Food', 'Dogs', 'Total']
    table_data = []
    prompt_names = {
        1: 'basic',
        2: 'custom',
        3: 'imagenet'
    }

    for epoch in model_name:
        if isinstance(epoch, str):
            # Honour the given name instead of always loading min_loss.pth.
            model_path = os.path.join(model_dir, f'{epoch}.pth')
        elif isinstance(epoch, int):
            model_path = os.path.join(model_dir, f'epoch_{epoch}.pth')
        else:
            # Fail loudly rather than reusing a stale or undefined model_path.
            raise TypeError(
                f'model_name entries must be str or int, got {type(epoch).__name__}')

        print(f'Testing with {model_path}')
        logger.info(f'Testing with {model_path}')

        for weights_version in test_weights_version:
            metrics = main(model_path, weights_version)
            # Row layout follows table_header: epoch, prompt name, accuracies.
            table_data.append([epoch, prompt_names[weights_version], *metrics])

    # Render the table once and send the same text to stdout and the log.
    table_text = tabulate(table_data, headers=table_header, tablefmt='fancy_grid')
    print(table_text)
    logger.info(table_text)