| @@ -0,0 +1,85 @@ | |||||
| import os | |||||
| # 指定使用1号显卡 | |||||
| os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |||||
| import jittor as jt | |||||
| from PIL import Image | |||||
| import numpy as np | |||||
| import jclip as clip | |||||
| from tqdm import tqdm | |||||
| from test_clip import zeroshot_classifier | |||||
| from path_config import TESTSET_ROOT_DIR | |||||
| jt.flags.use_cuda = 1 | |||||
| class TestSet(jt.dataset.Dataset): | |||||
| def __init__(self, | |||||
| image_list, | |||||
| transform, | |||||
| batch_size=256, | |||||
| num_workers=8): | |||||
| super(TestSet, self).__init__() | |||||
| self.image_list = image_list | |||||
| self.batch_size = batch_size | |||||
| self.num_workers = num_workers | |||||
| self.transform = transform | |||||
| self.set_attrs( | |||||
| batch_size=self.batch_size, | |||||
| total_len=len(self.image_list), | |||||
| num_workers=self.num_workers | |||||
| ) | |||||
| def __getitem__(self, idx): | |||||
| image_path = self.image_list[idx] | |||||
| label = image_path.split('/')[-1] | |||||
| image = Image.open(f'{TESTSET_ROOT_DIR}/{image_path}').convert('RGB') | |||||
| if self.transform is not None: | |||||
| image = self.transform(image) | |||||
| img = np.asarray(image) | |||||
| return img, label | |||||
| def load_model(model_path): | |||||
| model, preprocess = clip.load("/jittor-competiiton/pretrained/ViT-B-32.pkl") | |||||
| model.eval() | |||||
| model.load_state_dict(jt.load(model_path)) | |||||
| return model, preprocess | |||||
| def predict(model, preprocess, weights_version=1): | |||||
| # image_paths = [os.path.join(TESTSET_ROOT_DIR, img_path) for img_path in ] | |||||
| image_paths = os.listdir(TESTSET_ROOT_DIR) | |||||
| dataloader = TestSet(image_paths, preprocess ) | |||||
| zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version) | |||||
| result = [] | |||||
| img_paths = [] | |||||
| with jt.no_grad(): | |||||
| for batch in tqdm(dataloader, desc='Preprocessing'): | |||||
| images, paths = batch | |||||
| image_features = model.encode_image(images) | |||||
| image_features = image_features / image_features.norm(dim=1, keepdim=True) | |||||
| logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1) | |||||
| _, predict_indices = jt.topk(logits, k=5, dim=1) | |||||
| img_paths += paths | |||||
| result += predict_indices.tolist() | |||||
| return result, img_paths | |||||
| def main(model_path, weights_version=2): | |||||
| model, preprocess = load_model(model_path) | |||||
| result, img_paths = predict(model, preprocess, weights_version) | |||||
| out_txt(result, img_paths) | |||||
| def out_txt(result, img_paths, file_name='result.txt'): | |||||
| with open(file_name, 'w') as f: | |||||
| for img, preds in zip(img_paths, result): | |||||
| f.write(f'{img} '+ ' '.join([str(i) for i in preds]) + '\n') | |||||
| print(f'Result file generated!\nFile Path: {os.path.abspath(file_name)}') | |||||
| if __name__ == "__main__": | |||||
| # 完整的模型路径 | |||||
| model_path = '/jittor-competiiton/ckptFE/07-30/version_1/min_loss.pth' | |||||
| main(model_path, weights_version=2) | |||||
| # 提交系统测试0.6788 | |||||