
test.py 3.0 kB

import os
# Restrict the process to GPU 0
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import jittor as jt
from PIL import Image
import numpy as np
import jclip as clip
from tqdm import tqdm
from test_clip import zeroshot_classifier
from path_config import TESTSET_ROOT_DIR

jt.flags.use_cuda = 1


class TestSet(jt.dataset.Dataset):
    def __init__(self,
                 image_list,
                 transform,
                 batch_size=256,
                 num_workers=8):
        super(TestSet, self).__init__()
        self.image_list = image_list
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform
        self.set_attrs(
            batch_size=self.batch_size,
            total_len=len(self.image_list),
            num_workers=self.num_workers
        )

    def __getitem__(self, idx):
        image_path = self.image_list[idx]
        label = image_path.split('/')[-1]
        image = Image.open(f'{TESTSET_ROOT_DIR}/{image_path}').convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        img = np.asarray(image)
        return img, label


def load_model(model_path):
    model, preprocess = clip.load("/jittor-competiiton/pretrained/ViT-B-32.pkl")
    model.eval()
    model.load_state_dict(jt.load(model_path))
    return model, preprocess


def predict(model, preprocess, weights_version=1):
    image_paths = os.listdir(TESTSET_ROOT_DIR)
    dataloader = TestSet(image_paths, preprocess)
    zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version)
    result = []
    img_paths = []
    with jt.no_grad():
        for batch in tqdm(dataloader, desc='Preprocessing'):
            images, paths = batch
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
            _, predict_indices = jt.topk(logits, k=5, dim=1)
            img_paths += paths
            result += predict_indices.tolist()
    return result, img_paths


def main(model_path, weights_version=1):
    model, preprocess = load_model(model_path)
    result, img_paths = predict(model, preprocess, weights_version)
    out_txt(result, img_paths)


def out_txt(result, img_paths, file_name='result.txt'):
    with open(file_name, 'w') as f:
        for img, preds in zip(img_paths, result):
            f.write(f'{img} ' + ' '.join([str(i) for i in preds]) + '\n')
    print(f'Result file generated!\nFile Path: {os.path.abspath(file_name)}')


if __name__ == "__main__":
    # Full path to the fine-tuned checkpoint
    model_path = '/jittor-competiiton/ckptFE/08-16/version_0/epoch_90.pth'
    main(model_path, weights_version=1)
    # Submission system score: 0.7103
    # Local test: 0.7373 with epoch_90.pth (basic)
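The zero-shot weights come from zeroshot_classifier in test_clip.py, which is not shown in this file. The sketch below only illustrates what such a helper typically produces; the function body, class names, and prompt template are assumptions, not the project's actual implementation. Each class name is encoded with CLIP's text encoder, normalized, and stacked into a (feature_dim, num_classes) matrix whose columns line up with image_features @ zeroshot_weights in predict().

import jittor as jt
import jclip as clip

def zeroshot_classifier_sketch(model, class_names):
    # class_names is a placeholder list such as ['dog', 'cat', ...]
    weights = []
    with jt.no_grad():
        for name in class_names:
            # One simple prompt per class; the real helper may average
            # several prompt templates per class.
            tokens = clip.tokenize([f'a photo of a {name}'])
            text_features = model.encode_text(tokens)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            weights.append(text_features.squeeze(0))
    # Columns are per-class text embeddings, so image_features @ weights
    # yields one cosine-similarity score per class.
    return jt.stack(weights, dim=1)

The factor of 100 in predict() roughly plays the role of CLIP's learned logit scale (temperature) applied before the softmax.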

First, all image-encoder layers of the OpenAI-pretrained ViT-B/32 CLIP model are frozen; the model is then trained with the AdanBelief optimizer, a fusion of the Adan and AdaBelief optimizers that incorporates the "Belief" mechanism into Adan to improve the trained model's generalization.
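A minimal sketch of the freezing step described above, assuming the jclip model exposes its image encoder as model.visual (as in the OpenAI CLIP layout). The AdanBelief optimizer is this project's own Adan/AdaBelief fusion and is not reproduced here; a stock Jittor Adam with an illustrative learning rate stands in as a placeholder.

import jittor as jt
import jclip as clip

model, preprocess = clip.load("/jittor-competiiton/pretrained/ViT-B-32.pkl")

# Freeze every parameter of the image encoder (assumed to live under
# model.visual) so gradients only flow through the remaining layers.
for p in model.visual.parameters():
    p.stop_grad()

# Hand only the still-trainable parameters to the optimizer; the project
# uses its custom AdanBelief optimizer here, Adam is just a placeholder.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = jt.optim.Adam(trainable, lr=1e-5)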
