diff --git a/models.py b/models.py new file mode 100644 index 0000000..8d12f56 --- /dev/null +++ b/models.py @@ -0,0 +1,101 @@ +import jclip as clip + + +def load_clip(freeze_version=0): + """freeze part of CLIP model. + + Args: + freeze_version (int, optional): _description_. Defaults to 0. + + Returns: + _type_: _description_ + """ + assert freeze_version in [0, 1, 2, 3, 4, 5, 6] + model, preprocess = clip.load("pretrained/ViT-B-32.pkl") + model.train() + + def _freeze_version_1(): + """冻结全部文本层""" + model.token_embedding.weight.requires_grad = False + model.positional_embedding.requires_grad = False + model.text_projection.requires_grad = False + model.ln_final.weight.requires_grad = False + model.ln_final.bias.requires_grad = False + for param in model.transformer.parameters(): + param.requires_grad = False + + def _freeze_version_2(): + """冻结全部图像层""" + for param in model.visual.parameters(): + param.requires_grad = False + + def _freeze_version_3(): + """冻结浅层图像层""" + _freeze_version_2() + unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_post'] + for name, param in model.visual.named_parameters(): + if any(layer in name for layer in unfreeze_layers): + param.requires_grad = True + else: + param.requires_grad = False + + def _freeze_version_4(): + """冻结高层图像层""" + _freeze_version_2() + unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2'] + for name, param in model.visual.named_parameters(): + if any(layer in name for layer in unfreeze_layers): + param.requires_grad = True + else: + param.requires_grad = False + + model.visual.class_embedding.requires_grad = True + model.visual.positional_embedding.requires_grad = True + model.visual.proj.requires_grad = True + model.visual.conv1.weight.requires_grad = True + model.visual.ln_pre.weight.requires_grad = True + model.visual.ln_pre.bias.requires_grad = True + + def _freeze_version_5(): + """冻结浅层文本层""" + _freeze_version_1() + unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_final'] + for name, param in model.transformer.named_parameters(): + if any(layer in name for layer in unfreeze_layers): + param.requires_grad = True + else: + param.requires_grad = False + + def _freeze_version_6(): + """冻结高层文本层""" + _freeze_version_1() + model.token_embedding.weight.requires_grad = True + model.positional_embedding.requires_grad = True + model.text_projection.requires_grad = True + model.ln_final.weight.requires_grad = True + model.ln_final.bias.requires_grad = True + unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2'] + for name, param in model.transformer.named_parameters(): + if any(layer in name for layer in unfreeze_layers): + param.requires_grad = True + else: + param.requires_grad = False + + freeze_functions = { + 0: lambda: print('NO Freeze'), + 1: _freeze_version_1, + 2: _freeze_version_2, + 3: _freeze_version_3, + 4: _freeze_version_4, + 5: _freeze_version_5, + 6: _freeze_version_6 + } + freeze_functions[freeze_version]() + return model, preprocess + + +if __name__ == '__main__': + model, preprocess = load_clip(freeze_version=0) + + for name, param in model.named_parameters(): + print(name.ljust(60, '-'), param.requires_grad)