import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import functools

from typing import List
import jittor as jt
import numpy as np
from tqdm import tqdm
from PIL import Image
import jclip as clip
from tabulate import tabulate
from template import templatesv1, templatesv2, templatesv3
from path_config import DATASET_ROOT_DIR
from utils.logger import setup_logger
from utils.color_print import yellow_print


jt.flags.use_cuda = 1


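# valid.lst is assumed to hold one '<relative image path> <label id>' pair per line,
# which is how the two fields are parsed below.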
@functools.lru_cache()
def get_valid_dataset():
    dataset_path = os.path.join(DATASET_ROOT_DIR, 'valid.lst')
    with open(dataset_path, 'r') as f:
        img_label_pairs = f.read().splitlines()
    image_list = [l.split(' ')[0] for l in img_label_pairs]
    id_list = [int(l.split(' ')[1]) for l in img_label_pairs]
    return image_list, id_list


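# Split the validation set into its four sub-datasets. The label thresholds below
# (animal < 52, caltech < 143, food < 244, dogs >= 244) mirror the ranges used in
# get_classnames().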
def split_valid_dataset(image_list, id_list):
    valid_dataset = {
        'animal': {
            'image_list': [],
            'id_list': []
        },
        'caltech': {
            'image_list': [],
            'id_list': []
        },
        'dogs': {
            'image_list': [],
            'id_list': []
        },
        'food': {
            'image_list': [],
            'id_list': []
        }
    }
    for image, label in zip(image_list, id_list):
        if label < 52:
            valid_dataset['animal']['image_list'].append(image)
            valid_dataset['animal']['id_list'].append(label)
        elif label < 143:
            valid_dataset['caltech']['image_list'].append(image)
            valid_dataset['caltech']['id_list'].append(label)
        elif label < 244:
            valid_dataset['food']['image_list'].append(image)
            valid_dataset['food']['id_list'].append(label)
        else:
            valid_dataset['dogs']['image_list'].append(image)
            valid_dataset['dogs']['id_list'].append(label)

    return valid_dataset


class TestSet(jt.dataset.Dataset):
    def __init__(self,
                 image_list,
                 id_list,
                 transform,
                 batch_size=256,
                 num_workers=8):
        super(TestSet, self).__init__()
        self.image_list = image_list
        self.id_list = id_list
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform

        self.set_attrs(
            batch_size=self.batch_size,
            total_len=len(self.image_list),
            num_workers=self.num_workers,
            buffer_size=1024 * 1024 * 1024
        )

    def __getitem__(self, idx):
        image_path = self.image_list[idx]
        label = self.id_list[idx]
        image = Image.open(f'{DATASET_ROOT_DIR}/{image_path}').convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        img = np.asarray(image)
        return img, label


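# classes.txt is assumed to hold '<prefixed class name> <label id>' pairs; the
# fixed-length slices below (7/12/9/8 characters) strip each sub-dataset's name
# prefix, and underscores are replaced with spaces.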
@functools.lru_cache()
def get_classnames() -> List[str]:
    """Map numeric class ids to human-readable class names."""
    with open(os.path.join(DATASET_ROOT_DIR, 'classes.txt'), 'r') as file:
        classes = file.read().splitlines()
    map_dict = {}
    id_list = [int(line.split(' ')[1]) for line in classes]
    classnames = [line.split(' ')[0] for line in classes]
    for idx, name in zip(id_list, classnames):
        if idx < 52:
            map_dict[idx] = ' '.join(name[7:].lower().split('_'))
        elif idx < 143:
            map_dict[idx] = ' '.join(name[12:].lower().split('_'))
        elif idx < 244:
            map_dict[idx] = ' '.join(name[9:].lower().split('_'))
        else:
            map_dict[idx] = ' '.join(name[8:].lower().split('_'))
    return list(map_dict.values())


def zeroshot_classifier(model: clip.model.CLIP,
                        classnames: List[str] = get_classnames(),
                        weights_version: int = 1) -> jt.Var:
    """
    Build a zero-shot classifier from the CLIP text encoder.

    Args:
    - model (clip.model.CLIP): the loaded CLIP model instance.
    - classnames (list): list of all class names.
    - weights_version (int): which prompt-template set to use (1, 2 or 3).

    Returns:
    - jt.Var: zero-shot weight tensor of shape (embedding_size, num_classes).

    Notes:
    - Assumes the model is already on the GPU and can handle tokenization and
      text embedding of the prompt strings.
    - The returned tensor holds the averaged, L2-normalised text embedding of
      the prompt ensemble for each class.
    """

    if weights_version == 1:
        templates = templatesv1
    elif weights_version == 2:
        templates = templatesv2
    elif weights_version == 3:
        templates = templatesv3
    else:
        raise ValueError("weights_version must be 1, 2, or 3")

    model.eval()
    with jt.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames, desc='Extracting class embeddings'):
            texts = [template.format(classname) for template in templates]  # format templates with the class name
            texts = clip.tokenize(texts)  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with the text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = jt.stack(zeroshot_weights, dim=1)
    return zeroshot_weights


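# Zero-shot evaluation: encode each image batch, compare the normalised image
# features against the precomputed class text embeddings, and count top-1 matches.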
def evaluate(model, dataloader, zeroshot_weights, name):
    model.eval()
    correct = 0
    total_count = 0
    with jt.no_grad():
        print(f"\nTesting on {name}")
        bar = tqdm(dataloader)
        for i, batch in enumerate(bar):
            images, targets = batch
            total_count += len(images)
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
            preds = jt.argmax(logits, dim=1)[0]
            correct += jt.equal(preds, targets).sum().item()
            # iterating the tqdm object already advances the bar once per batch
            bar.set_description(f'{name} : {correct}/{total_count} Acc: {correct / total_count:.4f}')
        bar.close()
    return correct


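# Load CLIP (optionally restoring fine-tuned weights), build the zero-shot
# classifier once, then evaluate each of the four validation splits separately.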
def main(best_model_path, weights_version):
    model, transform = clip.load("/jittor-competiiton/pretrained/ViT-B-32.pkl")
    if best_model_path is not None:
        model.load_state_dict(jt.load(best_model_path))
        yellow_print('Loaded weights from {}'.format(best_model_path))
    else:
        yellow_print('weights not loaded! using pretrained weights')
    zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version)

    image_list, id_list = get_valid_dataset()

    valid_dataset = split_valid_dataset(image_list, id_list)

    batch_sizes = [256, 256, 512, 512]
    num_works = [4, 4, 8, 8]

    animal = TestSet(valid_dataset['animal']['image_list'],
                     valid_dataset['animal']['id_list'],
                     transform=transform,
                     batch_size=batch_sizes[0],
                     num_workers=num_works[0])

    caltech = TestSet(valid_dataset['caltech']['image_list'],
                      valid_dataset['caltech']['id_list'],
                      transform=transform,
                      batch_size=batch_sizes[1],
                      num_workers=num_works[1])

    food = TestSet(valid_dataset['food']['image_list'],
                   valid_dataset['food']['id_list'],
                   transform=transform,
                   batch_size=batch_sizes[2],
                   num_workers=num_works[2])

    dogs = TestSet(valid_dataset['dogs']['image_list'],
                   valid_dataset['dogs']['id_list'],
                   transform=transform,
                   batch_size=batch_sizes[3],
                   num_workers=num_works[3])

    animal_total = len(valid_dataset['animal']['image_list'])
    caltech_total = len(valid_dataset['caltech']['image_list'])
    food_total = len(valid_dataset['food']['image_list'])
    dogs_total = len(valid_dataset['dogs']['image_list'])

    total = animal_total + caltech_total + food_total + dogs_total
    animal_correct = evaluate(model, animal, zeroshot_weights, 'animal')
    caltech_correct = evaluate(model, caltech, zeroshot_weights, 'caltech')
    food_correct = evaluate(model, food, zeroshot_weights, 'food')
    dogs_correct = evaluate(model, dogs, zeroshot_weights, 'dogs')

    correct_total = animal_correct + caltech_correct + food_correct + dogs_correct

    metrics = [animal_correct / animal_total,
               caltech_correct / caltech_total,
               food_correct / food_total,
               dogs_correct / dogs_total,
               correct_total / total]

    print('Average Acc: ', metrics[-1])

    return [round(acc, 4) for acc in metrics]


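# Evaluate every checkpoint with every prompt-template version and report the
# per-split and overall accuracies as a table.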
if __name__ == '__main__':
    # Directory containing the model checkpoints to test
    model_dir = '/jittor-competiiton/ckptFE/07-30/version_1'

    logger = setup_logger(model_dir, type_='test')

    # Checkpoint names to test, e.g. min_loss, 20, 30
    model_name = ['min_loss', 20, 50, 70, 90, 100]

    # Prompt-template versions to test
    # 1. basic: a photo of
    # 2. custom:
    # 3. from imagenet
    test_weights_version = [1, 2, 3]

    table_header = ['Epoch', 'Prompt', 'Animal', 'Caltech', 'Food', 'Dogs', 'Total']
    table_data = []
    prompt_names = {
        1: 'basic',
        2: 'custom',
        3: 'imagenet'
    }

    for epoch in model_name:
        if isinstance(epoch, str):
            model_path = os.path.join(model_dir, 'min_loss.pth')
        elif isinstance(epoch, int):
            model_path = os.path.join(model_dir, f'epoch_{epoch}.pth')
        print(f'Testing with {model_path}')
        logger.info(f'Testing with {model_path}')
        for weights_version in test_weights_version:
            metrics = main(model_path, weights_version)
            metrics.insert(0, prompt_names[weights_version])
            metrics.insert(0, epoch)
            table_data.append(metrics)
    print(tabulate(table_data, headers=table_header, tablefmt='fancy_grid'))
    logger.info(tabulate(table_data, headers=table_header, tablefmt='fancy_grid'))