You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

predict.py 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import os
  2. # 指定使用1号显卡
  3. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  4. import jittor as jt
  5. from PIL import Image
  6. import numpy as np
  7. import jclip as clip
  8. from tqdm import tqdm
  9. from test_clip import zeroshot_classifier
  10. from path_config import TESTSET_ROOT_DIR
  11. jt.flags.use_cuda = 1
  12. class TestSet(jt.dataset.Dataset):
  13. def __init__(self,
  14. image_list,
  15. transform,
  16. batch_size=256,
  17. num_workers=8):
  18. super(TestSet, self).__init__()
  19. self.image_list = image_list
  20. self.batch_size = batch_size
  21. self.num_workers = num_workers
  22. self.transform = transform
  23. self.set_attrs(
  24. batch_size=self.batch_size,
  25. total_len=len(self.image_list),
  26. num_workers=self.num_workers
  27. )
  28. def __getitem__(self, idx):
  29. image_path = self.image_list[idx]
  30. label = image_path.split('/')[-1]
  31. image = Image.open(f'{TESTSET_ROOT_DIR}/{image_path}').convert('RGB')
  32. if self.transform is not None:
  33. image = self.transform(image)
  34. img = np.asarray(image)
  35. return img, label
  36. def load_model(model_path):
  37. model, preprocess = clip.load("/jittor-competiiton/pretrained/ViT-B-32.pkl")
  38. model.eval()
  39. model.load_state_dict(jt.load(model_path))
  40. return model, preprocess
  41. def predict(model, preprocess, weights_version=1):
  42. # image_paths = [os.path.join(TESTSET_ROOT_DIR, img_path) for img_path in ]
  43. image_paths = os.listdir(TESTSET_ROOT_DIR)
  44. dataloader = TestSet(image_paths, preprocess )
  45. zeroshot_weights = zeroshot_classifier(model, weights_version=weights_version)
  46. result = []
  47. img_paths = []
  48. with jt.no_grad():
  49. for batch in tqdm(dataloader, desc='Preprocessing'):
  50. images, paths = batch
  51. image_features = model.encode_image(images)
  52. image_features = image_features / image_features.norm(dim=1, keepdim=True)
  53. logits = (100 * image_features @ zeroshot_weights).softmax(dim=-1)
  54. _, predict_indices = jt.topk(logits, k=5, dim=1)
  55. img_paths += paths
  56. result += predict_indices.tolist()
  57. return result, img_paths
  58. def main(model_path, weights_version=2):
  59. model, preprocess = load_model(model_path)
  60. result, img_paths = predict(model, preprocess, weights_version)
  61. out_txt(result, img_paths)
  62. def out_txt(result, img_paths, file_name='result.txt'):
  63. with open(file_name, 'w') as f:
  64. for img, preds in zip(img_paths, result):
  65. f.write(f'{img} '+ ' '.join([str(i) for i in preds]) + '\n')
  66. print(f'Result file generated!\nFile Path: {os.path.abspath(file_name)}')
  67. if __name__ == "__main__":
  68. # 完整的模型路径
  69. model_path = '/jittor-competiiton/ckptFE/07-30/version_1/min_loss.pth'
  70. main(model_path, weights_version=2)
  71. # 提交系统测试0.6788

冻结ViT-B/32版本CLIP模型的全部图像层，使用Adan优化器训练100个epoch，每隔5个epoch保存一次模型；训练完成后，运行test_clip.py，用测试集数据和自定义提示词对保存的模型进行测试，选出测试精度最高的模型及对应的提示词；然后运行predict.py，选择“min_loss.pth”模型提交官方系统测试，top-1精度为0.6788。

Contributors (1)