| @@ -0,0 +1,85 @@ | |||
| import os | |||
| # 指定使用1号显卡 | |||
| os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |||
| import jittor as jt | |||
| from PIL import Image | |||
| import numpy as np | |||
| import jclip as clip | |||
| from tqdm import tqdm | |||
| from test_clip import zeroshot_classifier | |||
| from path_config import TESTSET_ROOT_DIR | |||
| jt.flags.use_cuda = 1 | |||
| class TestSet(jt.dataset.Dataset): | |||
| def __init__(self, | |||
| image_list, | |||
| transform, | |||
| batch_size=256, | |||
| num_workers=8): | |||
| super(TestSet, self).__init__() | |||
| self.image_list = image_list | |||
| self.batch_size = batch_size | |||
| self.num_workers = num_workers | |||
| self.transform = transform | |||
| self.set_attrs( | |||
| batch_size=self.batch_size, | |||
| total_len=len(self.image_list), | |||
| num_workers=self.num_workers | |||
| ) | |||
| def __getitem__(self, idx): | |||
| image_path = self.image_list[idx] | |||
| label = image_path.split('/')[-1] | |||
| image = Image.open(f'{TESTSET_ROOT_DIR}/{image_path}').convert('RGB') | |||
| if self.transform is not None: | |||
| image = self.transform(image) | |||
| img = np.asarray(image) | |||
| return img, label | |||
| def load_model(model_path): | |||
| model, preprocess = clip.load("/jittor-competiiton/pretrained/ViT-B-32.pkl") | |||
| model.eval() | |||
| model.load_state_dict(jt.load(model_path)) | |||
| return model, preprocess | |||
| def predict(model, preprocess, weights_version=1): | |||
| # image_paths = [os.path.join(TESTSET_ROOT_DIR, img_path) for img_path in ] | |||
| image_paths = os.listdir(TESTSET_ROOT_DIR) | |||
| dataloader = TestSet(image_paths, preprocess ) | |||
| zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version) | |||
| result = [] | |||
| img_paths = [] | |||
| with jt.no_grad(): | |||
| for batch in tqdm(dataloader, desc='Preprocessing'): | |||
| images, paths = batch | |||
| image_features = model.encode_image(images) | |||
| image_features = image_features / image_features.norm(dim=1, keepdim=True) | |||
| logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1) | |||
| _, predict_indices = jt.topk(logits, k=5, dim=1) | |||
| img_paths += paths | |||
| result += predict_indices.tolist() | |||
| return result, img_paths | |||
| def main(model_path, weights_version=2): | |||
| model, preprocess = load_model(model_path) | |||
| result, img_paths = predict(model, preprocess, weights_version) | |||
| out_txt(result, img_paths) | |||
| def out_txt(result, img_paths, file_name='result.txt'): | |||
| with open(file_name, 'w') as f: | |||
| for img, preds in zip(img_paths, result): | |||
| f.write(f'{img} '+ ' '.join([str(i) for i in preds]) + '\n') | |||
| print(f'Result file generated!\nFile Path: {os.path.abspath(file_name)}') | |||
| if __name__ == "__main__": | |||
| # 完整的模型路径 | |||
| model_path = '/jittor-competiiton/ckptFE/07-30/version_1/min_loss.pth' | |||
| main(model_path, weights_version=2) | |||
| # 提交系统测试0.6788 | |||