Browse Source

ADD file via upload

master
BIT2024 1 year ago
parent
commit
7cb92c3e4c
1 changed files with 85 additions and 0 deletions
  1. +85
    -0
      predict.py

+ 85
- 0
predict.py View File

@@ -0,0 +1,85 @@
import os

# 指定使用1号显卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import jittor as jt
from PIL import Image
import numpy as np
import jclip as clip
from tqdm import tqdm
from test_clip import zeroshot_classifier
from path_config import TESTSET_ROOT_DIR

jt.flags.use_cuda = 1

class TestSet(jt.dataset.Dataset):
def __init__(self,
image_list,
transform,
batch_size=256,
num_workers=8):
super(TestSet, self).__init__()
self.image_list = image_list
self.batch_size = batch_size
self.num_workers = num_workers
self.transform = transform
self.set_attrs(
batch_size=self.batch_size,
total_len=len(self.image_list),
num_workers=self.num_workers
)


def __getitem__(self, idx):
image_path = self.image_list[idx]
label = image_path.split('/')[-1]
image = Image.open(f'{TESTSET_ROOT_DIR}/{image_path}').convert('RGB')
if self.transform is not None:
image = self.transform(image)
img = np.asarray(image)
return img, label
def load_model(model_path):
model, preprocess = clip.load("/jittor-competiiton/pretrained/ViT-B-32.pkl")
model.eval()
model.load_state_dict(jt.load(model_path))
return model, preprocess

def predict(model, preprocess, weights_version=1):
# image_paths = [os.path.join(TESTSET_ROOT_DIR, img_path) for img_path in ]
image_paths = os.listdir(TESTSET_ROOT_DIR)
dataloader = TestSet(image_paths, preprocess )
zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version)
result = []
img_paths = []
with jt.no_grad():
for batch in tqdm(dataloader, desc='Preprocessing'):
images, paths = batch
image_features = model.encode_image(images)
image_features = image_features / image_features.norm(dim=1, keepdim=True)
logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
_, predict_indices = jt.topk(logits, k=5, dim=1)
img_paths += paths
result += predict_indices.tolist()
return result, img_paths

def main(model_path, weights_version=2):
model, preprocess = load_model(model_path)
result, img_paths = predict(model, preprocess, weights_version)
out_txt(result, img_paths)

def out_txt(result, img_paths, file_name='result.txt'):
with open(file_name, 'w') as f:
for img, preds in zip(img_paths, result):
f.write(f'{img} '+ ' '.join([str(i) for i in preds]) + '\n')
print(f'Result file generated!\nFile Path: {os.path.abspath(file_name)}')
if __name__ == "__main__":
# 完整的模型路径
model_path = '/jittor-competiiton/ckptFE/07-30/version_1/min_loss.pth'
main(model_path, weights_version=2)
# 提交系统测试0.6788

Loading…
Cancel
Save