|
- import os
-
- # Restrict CUDA to GPU index 0; set before importing jittor so it applies when CUDA initializes.
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
-
- import jittor as jt
- from PIL import Image
- import numpy as np
- import jclip as clip
- from tqdm import tqdm
- from test_clip import zeroshot_classifier
- from path_config import TESTSET_ROOT_DIR
-
- jt.flags.use_cuda = 1
-
class TestSet(jt.dataset.Dataset):
    """Jittor dataset over the test images found under ``TESTSET_ROOT_DIR``.

    Each item is a ``(image_array, file_name)`` pair; the file name doubles
    as the "label" so predictions can be matched back to their image.

    Args:
        image_list: Image paths relative to ``TESTSET_ROOT_DIR``.
        transform: Callable applied to the RGB ``PIL.Image`` before it is
            converted to an array, or ``None`` to skip preprocessing.
        batch_size: Batch size handed to the Jittor loader.
        num_workers: Number of loader workers handed to the Jittor loader.
    """

    def __init__(self,
                 image_list,
                 transform,
                 batch_size=256,
                 num_workers=8):
        super().__init__()
        self.image_list = image_list
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform

        # Jittor datasets act as their own loader; batching is configured here.
        self.set_attrs(
            batch_size=self.batch_size,
            total_len=len(self.image_list),
            num_workers=self.num_workers,
        )

    def __getitem__(self, idx):
        """Return ``(image_array, file_name)`` for the image at ``idx``."""
        image_path = self.image_list[idx]
        label = os.path.basename(image_path)
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it deterministically once the RGB copy exists.
        with Image.open(os.path.join(TESTSET_ROOT_DIR, image_path)) as raw:
            image = raw.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        img = np.asarray(image)
        return img, label
-
def load_model(model_path,
               pretrained_path="/jittor-competiiton/pretrained/ViT-B-32.pkl"):
    """Build the CLIP model and overwrite its weights with a checkpoint.

    Args:
        model_path: Path to the checkpoint whose state dict is loaded into
            the model.
        pretrained_path: Path to the base CLIP weights passed to
            ``clip.load``. Defaults to the location this script has always
            used, so existing callers are unaffected.

    Returns:
        A ``(model, preprocess)`` tuple as returned by ``clip.load``, with the
        model switched to eval mode and the checkpoint weights applied.
    """
    model, preprocess = clip.load(pretrained_path)
    model.eval()
    model.load_state_dict(jt.load(model_path))
    return model, preprocess
-
def predict(model, preprocess, weights_version=1):
    """Classify every image in ``TESTSET_ROOT_DIR`` and keep the top-5 classes.

    Args:
        model: CLIP model exposing ``encode_image``.
        preprocess: Image transform handed to ``TestSet``.
        weights_version: Forwarded to ``zeroshot_classifier`` to select the
            text-side classifier weights.

    Returns:
        A ``(result, img_paths)`` tuple of parallel lists: ``result[i]`` holds
        the five highest-scoring class indices for the image named
        ``img_paths[i]``.
    """
    # os.listdir returns entries in arbitrary order; sort so the result file
    # is reproducible across runs and machines.
    image_paths = sorted(os.listdir(TESTSET_ROOT_DIR))
    dataloader = TestSet(image_paths, preprocess)
    zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version)
    result = []
    img_paths = []

    with jt.no_grad():
        for images, paths in tqdm(dataloader, desc='Preprocessing'):
            image_features = model.encode_image(images)
            # L2-normalise the image embeddings before scoring them.
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            # NOTE(review): assumes zeroshot_weights is (embed_dim, num_classes)
            # so the product yields per-class scores -- confirm in test_clip.
            logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
            _, predict_indices = jt.topk(logits, k=5, dim=1)
            img_paths += paths
            result += predict_indices.tolist()
    return result, img_paths
-
def main(model_path, weights_version=2):
    """Run the whole pipeline: load the checkpoint, predict, write the results.

    Args:
        model_path: Checkpoint path forwarded to ``load_model``.
        weights_version: Classifier-weights selector forwarded to ``predict``.
    """
    clip_model, transform = load_model(model_path)
    predictions, names = predict(clip_model, transform, weights_version)
    out_txt(predictions, names)
-
def out_txt(result, img_paths, file_name='result.txt'):
    """Write one ``<image> <idx> <idx> ...`` line per prediction.

    Args:
        result: Iterable of per-image prediction sequences (class indices).
        img_paths: Iterable of image names, parallel to ``result``.
        file_name: Destination text file; overwritten if it exists.
    """
    lines = (
        f"{img} {' '.join(str(i) for i in preds)}\n"
        for img, preds in zip(img_paths, result)
    )
    # Explicit UTF-8 so non-ASCII image names are written identically on
    # every platform instead of depending on the locale's default encoding.
    with open(file_name, 'w', encoding='utf-8') as f:
        f.writelines(lines)
    print(f'Result file generated!\nFile Path: {os.path.abspath(file_name)}')
-
if __name__ == "__main__":
    # Full path to the checkpoint to evaluate.
    # NOTE: this configuration scored 0.6788 on the submission system's test
    # (presumably accuracy -- confirm against the competition metric).
    model_path = '/jittor-competiiton/ckptFE/07-30/version_1/min_loss.pth'
    main(model_path=model_path, weights_version=2)
-
|