Browse Source

ADD file via upload

master
BIT2024 1 year ago
parent
commit
95f6d18108
1 changed files with 284 additions and 0 deletions
  1. +284
    -0
      test_clip.py

+ 284
- 0
test_clip.py View File

@@ -0,0 +1,284 @@
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import functools

from typing import List
import jittor as jt
import numpy as np
from tqdm import tqdm
from PIL import Image
import jclip as clip
from tabulate import tabulate
from template import templatesv1, templatesv2, templatesv3
from path_config import DATASET_ROOT_DIR
from utils.logger import setup_logger
from utils.color_print import yellow_print


jt.flags.use_cuda = 1


@functools.lru_cache()
def get_valid_dataset():
dataset_path = os.path.join(DATASET_ROOT_DIR, 'valid.lst')
img_label_pairs = open(dataset_path, 'r').read().splitlines()
image_list = [l.split(' ')[0] for l in img_label_pairs]
id_list = [int(l.split(' ')[1]) for l in img_label_pairs]
return image_list, id_list


def split_valid_dataset(image_list, id_list):
valid_dataset = {
'animal': {
'image_list': [],
'id_list': []
},
'caltech': {
'image_list': [],
'id_list': []
},
'dogs': {
'image_list': [],
'id_list': []
},
'food': {
'image_list': [],
'id_list': []
}
}
for image, label in zip(image_list, id_list):
if label < 52:
valid_dataset['animal']['image_list'].append(image)
valid_dataset['animal']['id_list'].append(label)
elif label < 143:
valid_dataset['caltech']['image_list'].append(image)
valid_dataset['caltech']['id_list'].append(label)
elif label < 244:
valid_dataset['food']['image_list'].append(image)
valid_dataset['food']['id_list'].append(label)
else:
valid_dataset['dogs']['image_list'].append(image)
valid_dataset['dogs']['id_list'].append(label)
return valid_dataset

class TestSet(jt.dataset.Dataset):
def __init__(self,
image_list,
id_list,
transform,
batch_size=256,
num_workers=8):
super(TestSet, self).__init__()
self.image_list = image_list
self.id_list = id_list
self.batch_size = batch_size
self.num_workers = num_workers
self.transform = transform
self.set_attrs(
batch_size=self.batch_size,
total_len=len(self.image_list),
num_workers=self.num_workers,
buffer_size=1024 * 1024 * 1024
)


def __getitem__(self, idx):
image_path = self.image_list[idx]
label = self.id_list[idx]
image = Image.open(f'{DATASET_ROOT_DIR}/{image_path}').convert('RGB')
if self.transform is not None:
image = self.transform(image)
img = np.asarray(image)
return img, label

@functools.lru_cache()
def get_classnames() -> List[str]:
"""映射:从类别(数字)到文本"""
with open(os.path.join(DATASET_ROOT_DIR, 'classes.txt'), 'r') as file:
classes = file.read().splitlines()
map_dict = {}
id_list = [int(line.split(' ')[1]) for line in classes]
classnames = [line.split(' ')[0] for line in classes]
for idx, name in zip(id_list, classnames):
if idx < 52:
map_dict[idx] = ' '.join(name[7:].lower().split('_'))
elif idx < 143:
map_dict[idx] = ' '.join(name[12:].lower().split('_'))
elif idx < 244:
map_dict[idx] = ' '.join(name[9:].lower().split('_'))
else:
map_dict[idx] = ' '.join(name[8:].lower().split('_'))
return list(map_dict.values())


def zeroshot_classifier(model: clip.model.CLIP,
classnames: List[str] = get_classnames(),
weights_version: int = 1) -> jt.Var:
"""
使用 CLIP 模型进行零样本分类器的构建。

Args:
- model (clip.model.CLIP): 加载的 CLIP 模型实例。
- classnames (list): 包含所有类别名称的列表。
- templates (list): 包含模板字符串的列表,用于生成每个类别的文本输入。

Returns:
- torch.Tensor: 形状为 (embedding_size, num_classes) 的零样本权重张量。

注:
- 此函数假设模型已经在 GPU 上,并且模型可以处理字符串的 tokenization 和文本嵌入。
- 输出的张量将包含每个类别的平均嵌入向量,并进行了归一化处理。
"""
if weights_version == 1:
templates = templatesv1
elif weights_version == 2:
templates = templatesv2
elif weights_version == 3:
templates = templatesv3
else:
raise ValueError("weights_version must be 1, 2, or 3")
model.eval()
with jt.no_grad():
zeroshot_weights = []
for classname in tqdm(classnames, desc='Extracting class embeddings'):
texts = [template.format(classname) for template in templates] #format with class
texts = clip.tokenize(texts) #tokenize
class_embeddings = model.encode_text(texts) #embed with text encoder
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
class_embedding = class_embeddings.mean(dim=0)
class_embedding /= class_embedding.norm()
zeroshot_weights.append(class_embedding)
zeroshot_weights = jt.stack(zeroshot_weights, dim=1)
return zeroshot_weights


def evaluate(model, dataloader, zeroshot_weights, name):
model.eval()
corrct = 0
total_count = 0
with jt.no_grad():
print(f"\nTesting on {name}")
bar = tqdm(dataloader)
for i, batch in enumerate(bar):
images, targets = batch
total_count += len(images)
image_features = model.encode_image(images)
image_features = image_features / image_features.norm(dim=1, keepdim=True)
logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
preds = jt.argmax(logits, dim=1)[0]
corrct += jt.equal(preds, targets).sum().item()
bar.set_description(f'{name} : {corrct}/{total_count} Acc: {corrct / total_count:.4f}')
bar.update(1)
bar.close()
return corrct


def main(best_model_path, weights_version):
model, transform = clip.load("/jittor-competiiton/pretrained/ViT-B-32.pkl")
if best_model_path is not None:
model.load_state_dict(jt.load(best_model_path))
yellow_print('Loaded weights from {}'.format(best_model_path))
else:
yellow_print('weights not loaded! using pretrained weights')
zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version)
image_list, id_list = get_valid_dataset()
valid_dataset = split_valid_dataset(image_list, id_list)
batch_sizes = [256, 256, 512, 512]
num_works = [4, 4, 8, 8]
animal = TestSet(valid_dataset['animal']['image_list'],
valid_dataset['animal']['id_list'],
transform=transform,
batch_size=batch_sizes[0],
num_workers=num_works[0])
caltech = TestSet(valid_dataset['caltech']['image_list'],
valid_dataset['caltech']['id_list'],
transform=transform,
batch_size=batch_sizes[1],
num_workers=num_works[1])
food = TestSet(valid_dataset['food']['image_list'],
valid_dataset['food']['id_list'],
transform=transform,
batch_size=batch_sizes[2],
num_workers=num_works[2])
dogs = TestSet(valid_dataset['dogs']['image_list'],
valid_dataset['dogs']['id_list'],
transform=transform,
batch_size=batch_sizes[3],
num_workers=num_works[3]
)
animal_total = len(valid_dataset['animal']['image_list'])
caltech_total = len(valid_dataset['caltech']['image_list'])
food_total = len(valid_dataset['food']['image_list'])
dogs_total = len(valid_dataset['dogs']['image_list'])
total = animal_total + caltech_total + food_total + dogs_total
animal_correct = evaluate(model, animal, zeroshot_weights, 'animal')
caltech_correct = evaluate(model, caltech, zeroshot_weights, 'caltech')
food_correct = evaluate(model, food, zeroshot_weights, 'food')
dogs_correct = evaluate(model, dogs, zeroshot_weights, 'dogs')
correct_total = animal_correct + caltech_correct + food_correct + dogs_correct
metrics = [animal_correct/ animal_total,
caltech_correct/ caltech_total,
food_correct/ food_total,
dogs_correct/ dogs_total,
correct_total / total]
print('Average Acc: ', metrics[-1])
return [round(acc, 4) for acc in metrics]
if __name__ == '__main__':
# 待测试的模型路径,是一个文件夹
model_dir = '/jittor-competiiton/ckptFE/07-30/version_1'
logger = setup_logger(model_dir, type_='test')
# 需要测试的模型文件名, 如min_loss, 20, 30

model_name = ['min_loss', 20, 50, 70, 90, 100]
# 需要测试的提示词版本
# 1. basic: a photo of
# 2. custom:
# 3. from imagenet
test_weights_version = [1, 2, 3]
table_header = ['Epoch', '提示词', 'Animal', 'Caltech', 'Food', 'Dogs', 'Total']
table_data = []
promot = {
1: 'basic',
2: 'custom',
3: 'imagenet'
}
for epoch in model_name:
if isinstance(epoch, str):
model_path = os.path.join(model_dir, 'min_loss.pth')
elif isinstance(epoch, int):
model_path = os.path.join(model_dir, f'epoch_{epoch}.pth')
print(f'Testing with {model_path}')
logger.info(f'Testing with {model_path}')
for weights_version in test_weights_version:
metrics = main(model_path, weights_version)
metrics.insert(0, promot[weights_version])
metrics.insert(0, epoch)
table_data.append(metrics)
print(tabulate(table_data, headers=table_header, tablefmt='fancy_grid'))
logger.info(tabulate(table_data, headers=table_header, tablefmt='fancy_grid'))

Loading…
Cancel
Save