|
|
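# Zero-shot evaluation of a (fine-tuned) CLIP model with jclip/Jittor on the four
# validation subsets (animal, caltech, food, dogs): for each checkpoint and each
# prompt-template version, build zero-shot text weights from the class names and
# report per-subset and overall accuracy in a table.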
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import functools
from typing import List

import jittor as jt
import numpy as np
from tqdm import tqdm
from PIL import Image

import jclip as clip
from tabulate import tabulate

from template import templatesv1, templatesv2, templatesv3
from path_config import DATASET_ROOT_DIR
from utils.logger import setup_logger
from utils.color_print import yellow_print

jt.flags.use_cuda = 1


@functools.lru_cache()
def get_valid_dataset():
    """Read valid.lst and return the validation image paths and their integer labels."""
    dataset_path = os.path.join(DATASET_ROOT_DIR, 'valid.lst')
    with open(dataset_path, 'r') as f:
        img_label_pairs = f.read().splitlines()
    image_list = [l.split(' ')[0] for l in img_label_pairs]
    id_list = [int(l.split(' ')[1]) for l in img_label_pairs]
    return image_list, id_list


def split_valid_dataset(image_list, id_list):
    """Split the validation set into its four subsets by label range:
    animal (< 52), caltech (52-142), food (143-243), dogs (>= 244)."""
    valid_dataset = {
        'animal': {'image_list': [], 'id_list': []},
        'caltech': {'image_list': [], 'id_list': []},
        'dogs': {'image_list': [], 'id_list': []},
        'food': {'image_list': [], 'id_list': []},
    }
    for image, label in zip(image_list, id_list):
        if label < 52:
            valid_dataset['animal']['image_list'].append(image)
            valid_dataset['animal']['id_list'].append(label)
        elif label < 143:
            valid_dataset['caltech']['image_list'].append(image)
            valid_dataset['caltech']['id_list'].append(label)
        elif label < 244:
            valid_dataset['food']['image_list'].append(image)
            valid_dataset['food']['id_list'].append(label)
        else:
            valid_dataset['dogs']['image_list'].append(image)
            valid_dataset['dogs']['id_list'].append(label)

    return valid_dataset


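# Jittor dataset that loads validation images from disk, applies the CLIP
# preprocessing transform, and yields (image, label) batches.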
class TestSet(jt.dataset.Dataset):

    def __init__(self,
                 image_list,
                 id_list,
                 transform,
                 batch_size=256,
                 num_workers=8):
        super(TestSet, self).__init__()
        self.image_list = image_list
        self.id_list = id_list
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform

        self.set_attrs(
            batch_size=self.batch_size,
            total_len=len(self.image_list),
            num_workers=self.num_workers,
            buffer_size=1024 * 1024 * 1024
        )

    def __getitem__(self, idx):
        image_path = self.image_list[idx]
        label = self.id_list[idx]
        image = Image.open(f'{DATASET_ROOT_DIR}/{image_path}').convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        img = np.asarray(image)
        return img, label


@functools.lru_cache()
def get_classnames() -> List[str]:
    """Map class ids (numbers) to class-name text."""
    with open(os.path.join(DATASET_ROOT_DIR, 'classes.txt'), 'r') as file:
        classes = file.read().splitlines()
    map_dict = {}
    id_list = [int(line.split(' ')[1]) for line in classes]
    classnames = [line.split(' ')[0] for line in classes]
    for idx, name in zip(id_list, classnames):
        # Strip the per-dataset prefix from the raw class name, lower-case it,
        # and replace underscores with spaces.
        if idx < 52:
            map_dict[idx] = ' '.join(name[7:].lower().split('_'))
        elif idx < 143:
            map_dict[idx] = ' '.join(name[12:].lower().split('_'))
        elif idx < 244:
            map_dict[idx] = ' '.join(name[9:].lower().split('_'))
        else:
            map_dict[idx] = ' '.join(name[8:].lower().split('_'))
    return list(map_dict.values())


def zeroshot_classifier(model: clip.model.CLIP,
                        classnames: List[str] = get_classnames(),
                        weights_version: int = 1) -> jt.Var:
    """
    Build a zero-shot classifier from the CLIP text encoder.

    Args:
        - model (clip.model.CLIP): the loaded CLIP model instance.
        - classnames (list): list of all class names.
        - weights_version (int): which prompt-template set to use (1, 2, or 3).

    Returns:
        - jt.Var: zero-shot weight tensor of shape (embedding_size, num_classes).

    Notes:
        - Assumes the model is already on the GPU and handles tokenization and text embedding.
        - The output contains the normalized mean embedding of each class over all templates.
    """
    if weights_version == 1:
        templates = templatesv1
    elif weights_version == 2:
        templates = templatesv2
    elif weights_version == 3:
        templates = templatesv3
    else:
        raise ValueError("weights_version must be 1, 2, or 3")

    model.eval()
    with jt.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames, desc='Extracting class embeddings'):
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts)  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = jt.stack(zeroshot_weights, dim=1)
    return zeroshot_weights


def evaluate(model, dataloader, zeroshot_weights, name):
    """Classify one subset with the zero-shot weights and return the number of correct predictions."""
    model.eval()
    correct = 0
    total_count = 0
    with jt.no_grad():
        print(f"\nTesting on {name}")
        bar = tqdm(dataloader)
        for i, batch in enumerate(bar):
            images, targets = batch
            total_count += len(images)
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
            preds = jt.argmax(logits, dim=1)[0]
            correct += jt.equal(preds, targets).sum().item()
            bar.set_description(f'{name} : {correct}/{total_count} Acc: {correct / total_count:.4f}')
        bar.close()
    return correct


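# Evaluate one checkpoint with one prompt-template version: build the zero-shot
# weights, run each of the four validation subsets, and return the rounded
# per-subset accuracies plus the overall accuracy.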
def main(best_model_path, weights_version):
    model, transform = clip.load("/jittor-competiiton/pretrained/ViT-B-32.pkl")
    if best_model_path is not None:
        model.load_state_dict(jt.load(best_model_path))
        yellow_print('Loaded weights from {}'.format(best_model_path))
    else:
        yellow_print('Weights not loaded! Using the pretrained weights.')
    zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version)

    image_list, id_list = get_valid_dataset()

    valid_dataset = split_valid_dataset(image_list, id_list)

    # Per-subset dataloader settings, in the order animal, caltech, food, dogs.
    batch_sizes = [256, 256, 512, 512]
    num_workers_list = [4, 4, 8, 8]

    animal = TestSet(valid_dataset['animal']['image_list'],
                     valid_dataset['animal']['id_list'],
                     transform=transform,
                     batch_size=batch_sizes[0],
                     num_workers=num_workers_list[0])

    caltech = TestSet(valid_dataset['caltech']['image_list'],
                      valid_dataset['caltech']['id_list'],
                      transform=transform,
                      batch_size=batch_sizes[1],
                      num_workers=num_workers_list[1])

    food = TestSet(valid_dataset['food']['image_list'],
                   valid_dataset['food']['id_list'],
                   transform=transform,
                   batch_size=batch_sizes[2],
                   num_workers=num_workers_list[2])

    dogs = TestSet(valid_dataset['dogs']['image_list'],
                   valid_dataset['dogs']['id_list'],
                   transform=transform,
                   batch_size=batch_sizes[3],
                   num_workers=num_workers_list[3])

    animal_total = len(valid_dataset['animal']['image_list'])
    caltech_total = len(valid_dataset['caltech']['image_list'])
    food_total = len(valid_dataset['food']['image_list'])
    dogs_total = len(valid_dataset['dogs']['image_list'])

    total = animal_total + caltech_total + food_total + dogs_total
    animal_correct = evaluate(model, animal, zeroshot_weights, 'animal')
    caltech_correct = evaluate(model, caltech, zeroshot_weights, 'caltech')
    food_correct = evaluate(model, food, zeroshot_weights, 'food')
    dogs_correct = evaluate(model, dogs, zeroshot_weights, 'dogs')

    correct_total = animal_correct + caltech_correct + food_correct + dogs_correct

    metrics = [animal_correct / animal_total,
               caltech_correct / caltech_total,
               food_correct / food_total,
               dogs_correct / dogs_total,
               correct_total / total]

    print('Average Acc: ', metrics[-1])

    return [round(acc, 4) for acc in metrics]


if __name__ == '__main__':
    # Directory that holds the checkpoints to be tested
    model_dir = '/jittor-competiiton/ckptFE/07-30/version_1'

    logger = setup_logger(model_dir, type_='test')

    # Checkpoint names to test, e.g. min_loss, 20, 30
    model_name = ['min_loss', 20, 50, 70, 90, 100]

    # Prompt-template versions to test
    # 1. basic: a photo of
    # 2. custom:
    # 3. from imagenet
    test_weights_version = [1, 2, 3]

    table_header = ['Epoch', 'Prompt', 'Animal', 'Caltech', 'Food', 'Dogs', 'Total']
    table_data = []
    prompt_names = {
        1: 'basic',
        2: 'custom',
        3: 'imagenet'
    }

    for epoch in model_name:
        if isinstance(epoch, str):
            model_path = os.path.join(model_dir, 'min_loss.pth')
        elif isinstance(epoch, int):
            model_path = os.path.join(model_dir, f'epoch_{epoch}.pth')
        print(f'Testing with {model_path}')
        logger.info(f'Testing with {model_path}')
        for weights_version in test_weights_version:
            metrics = main(model_path, weights_version)
            metrics.insert(0, prompt_names[weights_version])
            metrics.insert(0, epoch)
            table_data.append(metrics)
            print(tabulate(table_data, headers=table_header, tablefmt='fancy_grid'))
            logger.info(tabulate(table_data, headers=table_header, tablefmt='fancy_grid'))