"""Top-5 zero-shot prediction over the test set with a fine-tuned CLIP model.

Loads a CLIP checkpoint, encodes every image found in ``TESTSET_ROOT_DIR``,
scores the image features against the weights returned by
``zeroshot_classifier`` and writes the five best class indices per image to
a text file.
"""
import os

# Restrict the process to GPU 0 (the first card). Kept ahead of the jittor
# import so the variable is already in place when jittor initialises CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import jittor as jt
from PIL import Image
import numpy as np
import jclip as clip
from tqdm import tqdm
from test_clip import zeroshot_classifier
from path_config import TESTSET_ROOT_DIR

jt.flags.use_cuda = 1


class TestSet(jt.dataset.Dataset):
    """Dataset yielding ``(image_array, file_name)`` pairs for the test images.

    Args:
        image_list: Image paths relative to ``TESTSET_ROOT_DIR``.
        transform: Callable applied to each RGB PIL image, or ``None`` to
            skip preprocessing.
        batch_size: Number of samples per batch.
        num_workers: Number of loader workers handed to jittor.
    """

    def __init__(self,
                 image_list,
                 transform,
                 batch_size=256,
                 num_workers=8):
        super().__init__()
        self.image_list = image_list
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform

        # jittor datasets take their loader configuration via set_attrs.
        self.set_attrs(
            batch_size=self.batch_size,
            total_len=len(self.image_list),
            num_workers=self.num_workers,
        )

    def __getitem__(self, idx):
        """Return the preprocessed image at ``idx`` and its file name."""
        image_path = self.image_list[idx]
        # The bare file name is what ends up in the result file.
        label = os.path.basename(image_path)
        # Image.open is lazy and keeps the file open; convert() loads the
        # pixels into a new image, so the handle can be closed right after.
        with Image.open(os.path.join(TESTSET_ROOT_DIR, image_path)) as handle:
            image = handle.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        img = np.asarray(image)
        return img, label


def load_model(model_path,
               pretrained_path="/jittor-competiiton/pretrained/ViT-B-32.pkl"):
    """Build the CLIP model and load fine-tuned weights into it.

    Args:
        model_path: Path of the fine-tuned state dict to load.
        pretrained_path: Path of the base CLIP checkpoint passed to
            ``clip.load``.

    Returns:
        ``(model, preprocess)`` as returned by ``clip.load``, with the model
        in eval mode and the fine-tuned weights applied.
    """
    model, preprocess = clip.load(pretrained_path)
    model.eval()
    model.load_state_dict(jt.load(model_path))
    return model, preprocess


def predict(model, preprocess, weights_version=1):
    """Predict the top-5 class indices for every file in ``TESTSET_ROOT_DIR``.

    Args:
        model: CLIP model exposing ``encode_image``.
        preprocess: Image transform forwarded to ``TestSet``.
        weights_version: Forwarded to ``zeroshot_classifier``.

    Returns:
        ``(result, img_paths)``: ``result[i]`` is the list of five class
        indices predicted for the image named ``img_paths[i]``.
    """
    # os.listdir returns entries in arbitrary order; sort so that the
    # generated result file is reproducible between runs.
    image_paths = sorted(os.listdir(TESTSET_ROOT_DIR))
    dataloader = TestSet(image_paths, preprocess)
    # NOTE(review): presumably a (feature_dim, num_classes) matrix, since it
    # is right-multiplied onto the image features below -- confirm in
    # test_clip.
    zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version)
    result = []
    img_paths = []

    with jt.no_grad():
        for images, paths in tqdm(dataloader, desc='Preprocessing'):
            image_features = model.encode_image(images)
            # L2-normalise each feature row before taking the dot product.
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
            _, predict_indices = jt.topk(logits, k=5, dim=1)
            img_paths.extend(paths)
            result.extend(predict_indices.tolist())
    return result, img_paths


def main(model_path, weights_version=2):
    """Load the model, run prediction on the test set and write the results.

    Args:
        model_path: Path of the fine-tuned state dict.
        weights_version: Forwarded to ``predict``.
    """
    model, preprocess = load_model(model_path)
    result, img_paths = predict(model, preprocess, weights_version)
    out_txt(result, img_paths)


def out_txt(result, img_paths, file_name='result.txt'):
    """Write one ``<image> <idx> <idx> ...`` line per prediction.

    Args:
        result: Per-image sequences of predicted class indices.
        img_paths: Image names, aligned with ``result``.
        file_name: Destination text file; overwritten if it exists.
    """
    # Explicit encoding so that non-ASCII file names are written the same way
    # on every platform.
    with open(file_name, 'w', encoding='utf-8') as f:
        for img, preds in zip(img_paths, result):
            labels = ' '.join(str(i) for i in preds)
            f.write(f'{img} {labels}\n')
    print(f'Result file generated!\nFile Path: {os.path.abspath(file_name)}')


if __name__ == "__main__":
    # Full path of the fine-tuned checkpoint.
    model_path = '/jittor-competiiton/ckptFE/07-30/version_1/min_loss.pth'
    main(model_path, weights_version=2)
# Submission-system test score: 0.6788