You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

models.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import jclip as clip
  2. def load_clip(freeze_version=0):
  3. """freeze part of CLIP model.
  4. Args:
  5. freeze_version (int, optional): _description_. Defaults to 0.
  6. Returns:
  7. _type_: _description_
  8. """
  9. assert freeze_version in [0, 1, 2, 3, 4, 5, 6]
  10. model, preprocess = clip.load("pretrained/ViT-B-32.pkl")
  11. model.train()
  12. def _freeze_version_1():
  13. """冻结全部文本层"""
  14. model.token_embedding.weight.requires_grad = False
  15. model.positional_embedding.requires_grad = False
  16. model.text_projection.requires_grad = False
  17. model.ln_final.weight.requires_grad = False
  18. model.ln_final.bias.requires_grad = False
  19. for param in model.transformer.parameters():
  20. param.requires_grad = False
  21. def _freeze_version_2():
  22. """冻结全部图像层"""
  23. for param in model.visual.parameters():
  24. param.requires_grad = False
  25. def _freeze_version_3():
  26. """冻结浅层图像层"""
  27. _freeze_version_2()
  28. unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_post']
  29. for name, param in model.visual.named_parameters():
  30. if any(layer in name for layer in unfreeze_layers):
  31. param.requires_grad = True
  32. else:
  33. param.requires_grad = False
  34. def _freeze_version_4():
  35. """冻结高层图像层"""
  36. _freeze_version_2()
  37. unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2']
  38. for name, param in model.visual.named_parameters():
  39. if any(layer in name for layer in unfreeze_layers):
  40. param.requires_grad = True
  41. else:
  42. param.requires_grad = False
  43. model.visual.class_embedding.requires_grad = True
  44. model.visual.positional_embedding.requires_grad = True
  45. model.visual.proj.requires_grad = True
  46. model.visual.conv1.weight.requires_grad = True
  47. model.visual.ln_pre.weight.requires_grad = True
  48. model.visual.ln_pre.bias.requires_grad = True
  49. def _freeze_version_5():
  50. """冻结浅层文本层"""
  51. _freeze_version_1()
  52. unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_final']
  53. for name, param in model.transformer.named_parameters():
  54. if any(layer in name for layer in unfreeze_layers):
  55. param.requires_grad = True
  56. else:
  57. param.requires_grad = False
  58. def _freeze_version_6():
  59. """冻结高层文本层"""
  60. _freeze_version_1()
  61. model.token_embedding.weight.requires_grad = True
  62. model.positional_embedding.requires_grad = True
  63. model.text_projection.requires_grad = True
  64. model.ln_final.weight.requires_grad = True
  65. model.ln_final.bias.requires_grad = True
  66. unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2']
  67. for name, param in model.transformer.named_parameters():
  68. if any(layer in name for layer in unfreeze_layers):
  69. param.requires_grad = True
  70. else:
  71. param.requires_grad = False
  72. freeze_functions = {
  73. 0: lambda: print('NO Freeze'),
  74. 1: _freeze_version_1,
  75. 2: _freeze_version_2,
  76. 3: _freeze_version_3,
  77. 4: _freeze_version_4,
  78. 5: _freeze_version_5,
  79. 6: _freeze_version_6
  80. }
  81. freeze_functions[freeze_version]()
  82. return model, preprocess
  83. if __name__ == '__main__':
  84. model, preprocess = load_clip(freeze_version=0)
  85. for name, param in model.named_parameters():
  86. print(name.ljust(60, '-'), param.requires_grad)

冻结ViT-B/32版本的CLIP模型中的全部图像层,用Adan优化器训练模型,训练100个epoch,每隔5个epoch对模型进行保存;完成CLIP模型训练后,运行test_clip.py用测试集中的数据和自定义的提示词对保存的模型进行测试,选取测试精度最好的模型和对应的提示词,运行predict.py文件,选择“min_loss.pth”模型,提交官方系统测试,top1的精度是0.6788。

Contributors (1)