Browse Source

ADD file via upload

master
BIT2024 1 year ago
parent
commit
72f6f2577c
1 changed files with 101 additions and 0 deletions
  1. +101
    -0
      models.py

+ 101
- 0
models.py View File

@@ -0,0 +1,101 @@
import jclip as clip


def load_clip(freeze_version=0):
"""freeze part of CLIP model.

Args:
freeze_version (int, optional): _description_. Defaults to 0.

Returns:
_type_: _description_
"""
assert freeze_version in [0, 1, 2, 3, 4, 5, 6]
model, preprocess = clip.load("pretrained/ViT-B-32.pkl")
model.train()
def _freeze_version_1():
"""冻结全部文本层"""
model.token_embedding.weight.requires_grad = False
model.positional_embedding.requires_grad = False
model.text_projection.requires_grad = False
model.ln_final.weight.requires_grad = False
model.ln_final.bias.requires_grad = False
for param in model.transformer.parameters():
param.requires_grad = False
def _freeze_version_2():
"""冻结全部图像层"""
for param in model.visual.parameters():
param.requires_grad = False
def _freeze_version_3():
"""冻结浅层图像层"""
_freeze_version_2()
unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_post']
for name, param in model.visual.named_parameters():
if any(layer in name for layer in unfreeze_layers):
param.requires_grad = True
else:
param.requires_grad = False
def _freeze_version_4():
"""冻结高层图像层"""
_freeze_version_2()
unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2']
for name, param in model.visual.named_parameters():
if any(layer in name for layer in unfreeze_layers):
param.requires_grad = True
else:
param.requires_grad = False
model.visual.class_embedding.requires_grad = True
model.visual.positional_embedding.requires_grad = True
model.visual.proj.requires_grad = True
model.visual.conv1.weight.requires_grad = True
model.visual.ln_pre.weight.requires_grad = True
model.visual.ln_pre.bias.requires_grad = True
def _freeze_version_5():
"""冻结浅层文本层"""
_freeze_version_1()
unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_final']
for name, param in model.transformer.named_parameters():
if any(layer in name for layer in unfreeze_layers):
param.requires_grad = True
else:
param.requires_grad = False
def _freeze_version_6():
"""冻结高层文本层"""
_freeze_version_1()
model.token_embedding.weight.requires_grad = True
model.positional_embedding.requires_grad = True
model.text_projection.requires_grad = True
model.ln_final.weight.requires_grad = True
model.ln_final.bias.requires_grad = True
unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2']
for name, param in model.transformer.named_parameters():
if any(layer in name for layer in unfreeze_layers):
param.requires_grad = True
else:
param.requires_grad = False
freeze_functions = {
0: lambda: print('NO Freeze'),
1: _freeze_version_1,
2: _freeze_version_2,
3: _freeze_version_3,
4: _freeze_version_4,
5: _freeze_version_5,
6: _freeze_version_6
}
freeze_functions[freeze_version]()
return model, preprocess


if __name__ == '__main__':
model, preprocess = load_clip(freeze_version=0)
for name, param in model.named_parameters():
print(name.ljust(60, '-'), param.requires_grad)

Loading…
Cancel
Save