| @@ -0,0 +1,107 @@ | |||||
| import jclip as clip | |||||
| def load_clip(freeze_version=0): | |||||
| """freeze part of CLIP model. | |||||
| Args: | |||||
| freeze_version (int, optional): _description_. Defaults to 0. | |||||
| Returns: | |||||
| _type_: _description_ | |||||
| """ | |||||
| assert freeze_version in [0, 1, 2, 3, 4, 5, 6] | |||||
| model, preprocess = clip.load("pretrained/ViT-B-32.pkl") | |||||
| model.train() | |||||
| def _freeze_version_1(): | |||||
| """冻结全部文本层""" | |||||
| model.token_embedding.weight.requires_grad = False | |||||
| model.positional_embedding.requires_grad = False | |||||
| model.text_projection.requires_grad = False | |||||
| model.ln_final.weight.requires_grad = False | |||||
| model.ln_final.bias.requires_grad = False | |||||
| for param in model.transformer.parameters(): | |||||
| param.requires_grad = False | |||||
| def _freeze_version_2(): | |||||
| """冻结全部图像层""" | |||||
| for param in model.visual.parameters(): | |||||
| param.requires_grad = False | |||||
| def _freeze_version_3(): | |||||
| """冻结浅层图像层""" | |||||
| _freeze_version_2() | |||||
| unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_post'] | |||||
| for name, param in model.visual.named_parameters(): | |||||
| if any(layer in name for layer in unfreeze_layers): | |||||
| param.requires_grad = True | |||||
| else: | |||||
| param.requires_grad = False | |||||
| def _freeze_version_4(): | |||||
| """冻结高层图像层""" | |||||
| _freeze_version_2() | |||||
| unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2'] | |||||
| for name, param in model.visual.named_parameters(): | |||||
| if any(layer in name for layer in unfreeze_layers): | |||||
| param.requires_grad = True | |||||
| else: | |||||
| param.requires_grad = False | |||||
| model.visual.class_embedding.requires_grad = True | |||||
| model.visual.positional_embedding.requires_grad = True | |||||
| model.visual.proj.requires_grad = True | |||||
| model.visual.conv1.weight.requires_grad = True | |||||
| model.visual.ln_pre.weight.requires_grad = True | |||||
| model.visual.ln_pre.bias.requires_grad = True | |||||
| def _freeze_version_5(): | |||||
| """冻结浅层文本层""" | |||||
| _freeze_version_1() | |||||
| unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_final'] | |||||
| for name, param in model.transformer.named_parameters(): | |||||
| if any(layer in name for layer in unfreeze_layers): | |||||
| param.requires_grad = True | |||||
| else: | |||||
| param.requires_grad = False | |||||
| def _freeze_version_6(): | |||||
| """冻结高层文本层""" | |||||
| _freeze_version_1() | |||||
| model.token_embedding.weight.requires_grad = True | |||||
| model.positional_embedding.requires_grad = True | |||||
| model.text_projection.requires_grad = True | |||||
| model.ln_final.weight.requires_grad = True | |||||
| model.ln_final.bias.requires_grad = True | |||||
| unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2'] | |||||
| for name, param in model.transformer.named_parameters(): | |||||
| if any(layer in name for layer in unfreeze_layers): | |||||
| param.requires_grad = True | |||||
| else: | |||||
| param.requires_grad = False | |||||
| freeze_functions = { | |||||
| 0: lambda: print('NO Freeze'), | |||||
| 1: _freeze_version_1, | |||||
| 2: _freeze_version_2, | |||||
| 3: _freeze_version_3, | |||||
| 4: _freeze_version_4, | |||||
| 5: _freeze_version_5, | |||||
| 6: _freeze_version_6 | |||||
| } | |||||
| freeze_functions[freeze_version]() | |||||
| return model, preprocess | |||||
| if __name__ == '__main__': | |||||
| model, preprocess = load_clip(freeze_version=0) | |||||
| for name, param in model.named_parameters(): | |||||
| print(name.ljust(60, '-'), param.requires_grad) | |||||
| # for name, param in model.named_parameters(): | |||||
| # print(name.ljust(60, '-'), param.requires_grad) | |||||