import jclip as clip


def load_clip(freeze_version=0):
    """Load CLIP and freeze part of the model.

    Args:
        freeze_version (int, optional): freezing strategy to apply.
            0: no freezing;
            1: freeze all text layers;
            2: freeze all image layers;
            3: freeze shallow image layers (keep the last visual blocks trainable);
            4: freeze deep image layers (keep the first visual blocks and embeddings trainable);
            5: freeze shallow text layers (keep the last text blocks trainable);
            6: freeze deep text layers (keep the first text blocks and embeddings trainable).
            Defaults to 0.

    Returns:
        tuple: (model, preprocess), the CLIP model with the chosen layers frozen
            and its image preprocessing transform.
    """
    assert freeze_version in [0, 1, 2, 3, 4, 5, 6]
    model, preprocess = clip.load("pretrained/ViT-B-32.pkl")
    model.train()

    def _freeze_version_1():
        """Freeze all text layers."""
        model.token_embedding.weight.requires_grad = False
        model.positional_embedding.requires_grad = False
        model.text_projection.requires_grad = False
        model.ln_final.weight.requires_grad = False
        model.ln_final.bias.requires_grad = False
        for param in model.transformer.parameters():
            param.requires_grad = False

    def _freeze_version_2():
        """Freeze all image layers."""
        for param in model.visual.parameters():
            param.requires_grad = False

    def _freeze_version_3():
        """Freeze shallow image layers, keeping the last visual blocks trainable."""
        _freeze_version_2()
        unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_post']
        for name, param in model.visual.named_parameters():
            if any(layer in name for layer in unfreeze_layers):
                param.requires_grad = True
            else:
                param.requires_grad = False

    def _freeze_version_4():
        """Freeze deep image layers, keeping the first visual blocks and embeddings trainable."""
        _freeze_version_2()
        unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2']
        for name, param in model.visual.named_parameters():
            if any(layer in name for layer in unfreeze_layers):
                param.requires_grad = True
            else:
                param.requires_grad = False
        model.visual.class_embedding.requires_grad = True
        model.visual.positional_embedding.requires_grad = True
        model.visual.proj.requires_grad = True
        model.visual.conv1.weight.requires_grad = True
        model.visual.ln_pre.weight.requires_grad = True
        model.visual.ln_pre.bias.requires_grad = True

    def _freeze_version_5():
        """Freeze shallow text layers, keeping the last text blocks trainable."""
        _freeze_version_1()
        unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_final']
        for name, param in model.transformer.named_parameters():
            if any(layer in name for layer in unfreeze_layers):
                param.requires_grad = True
            else:
                param.requires_grad = False
        # ln_final lives on the top-level model, not inside model.transformer,
        # so it never matches in the loop above and must be unfrozen explicitly.
        model.ln_final.weight.requires_grad = True
        model.ln_final.bias.requires_grad = True

    def _freeze_version_6():
        """Freeze deep text layers, keeping the first text blocks and embeddings trainable."""
        _freeze_version_1()
        model.token_embedding.weight.requires_grad = True
        model.positional_embedding.requires_grad = True
        model.text_projection.requires_grad = True
        model.ln_final.weight.requires_grad = True
        model.ln_final.bias.requires_grad = True
        unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2']
        for name, param in model.transformer.named_parameters():
            if any(layer in name for layer in unfreeze_layers):
                param.requires_grad = True
            else:
                param.requires_grad = False

    freeze_functions = {
        0: lambda: print('NO Freeze'),
        1: _freeze_version_1,
        2: _freeze_version_2,
        3: _freeze_version_3,
        4: _freeze_version_4,
        5: _freeze_version_5,
        6: _freeze_version_6,
    }
    freeze_functions[freeze_version]()

    return model, preprocess


if __name__ == '__main__':
    model, preprocess = load_clip(freeze_version=0)
    for name, param in model.named_parameters():
        print(name.ljust(60, '-'), param.requires_grad)