|
|
@@ -0,0 +1,101 @@ |
|
|
|
import jclip as clip |
|
|
|
|
|
|
|
|
|
|
|
def load_clip(freeze_version=0): |
|
|
|
"""freeze part of CLIP model. |
|
|
|
|
|
|
|
Args: |
|
|
|
freeze_version (int, optional): _description_. Defaults to 0. |
|
|
|
|
|
|
|
Returns: |
|
|
|
_type_: _description_ |
|
|
|
""" |
|
|
|
assert freeze_version in [0, 1, 2, 3, 4, 5, 6] |
|
|
|
model, preprocess = clip.load("pretrained/ViT-B-32.pkl") |
|
|
|
model.train() |
|
|
|
|
|
|
|
def _freeze_version_1(): |
|
|
|
"""冻结全部文本层""" |
|
|
|
model.token_embedding.weight.requires_grad = False |
|
|
|
model.positional_embedding.requires_grad = False |
|
|
|
model.text_projection.requires_grad = False |
|
|
|
model.ln_final.weight.requires_grad = False |
|
|
|
model.ln_final.bias.requires_grad = False |
|
|
|
for param in model.transformer.parameters(): |
|
|
|
param.requires_grad = False |
|
|
|
|
|
|
|
def _freeze_version_2(): |
|
|
|
"""冻结全部图像层""" |
|
|
|
for param in model.visual.parameters(): |
|
|
|
param.requires_grad = False |
|
|
|
|
|
|
|
def _freeze_version_3(): |
|
|
|
"""冻结浅层图像层""" |
|
|
|
_freeze_version_2() |
|
|
|
unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_post'] |
|
|
|
for name, param in model.visual.named_parameters(): |
|
|
|
if any(layer in name for layer in unfreeze_layers): |
|
|
|
param.requires_grad = True |
|
|
|
else: |
|
|
|
param.requires_grad = False |
|
|
|
|
|
|
|
def _freeze_version_4(): |
|
|
|
"""冻结高层图像层""" |
|
|
|
_freeze_version_2() |
|
|
|
unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2'] |
|
|
|
for name, param in model.visual.named_parameters(): |
|
|
|
if any(layer in name for layer in unfreeze_layers): |
|
|
|
param.requires_grad = True |
|
|
|
else: |
|
|
|
param.requires_grad = False |
|
|
|
|
|
|
|
model.visual.class_embedding.requires_grad = True |
|
|
|
model.visual.positional_embedding.requires_grad = True |
|
|
|
model.visual.proj.requires_grad = True |
|
|
|
model.visual.conv1.weight.requires_grad = True |
|
|
|
model.visual.ln_pre.weight.requires_grad = True |
|
|
|
model.visual.ln_pre.bias.requires_grad = True |
|
|
|
|
|
|
|
def _freeze_version_5(): |
|
|
|
"""冻结浅层文本层""" |
|
|
|
_freeze_version_1() |
|
|
|
unfreeze_layers = ['resblocks.9', 'resblocks.10', 'resblocks.11', 'ln_final'] |
|
|
|
for name, param in model.transformer.named_parameters(): |
|
|
|
if any(layer in name for layer in unfreeze_layers): |
|
|
|
param.requires_grad = True |
|
|
|
else: |
|
|
|
param.requires_grad = False |
|
|
|
|
|
|
|
def _freeze_version_6(): |
|
|
|
"""冻结高层文本层""" |
|
|
|
_freeze_version_1() |
|
|
|
model.token_embedding.weight.requires_grad = True |
|
|
|
model.positional_embedding.requires_grad = True |
|
|
|
model.text_projection.requires_grad = True |
|
|
|
model.ln_final.weight.requires_grad = True |
|
|
|
model.ln_final.bias.requires_grad = True |
|
|
|
unfreeze_layers = ['resblocks.0', 'resblocks.1', 'resblocks.2'] |
|
|
|
for name, param in model.transformer.named_parameters(): |
|
|
|
if any(layer in name for layer in unfreeze_layers): |
|
|
|
param.requires_grad = True |
|
|
|
else: |
|
|
|
param.requires_grad = False |
|
|
|
|
|
|
|
freeze_functions = { |
|
|
|
0: lambda: print('NO Freeze'), |
|
|
|
1: _freeze_version_1, |
|
|
|
2: _freeze_version_2, |
|
|
|
3: _freeze_version_3, |
|
|
|
4: _freeze_version_4, |
|
|
|
5: _freeze_version_5, |
|
|
|
6: _freeze_version_6 |
|
|
|
} |
|
|
|
freeze_functions[freeze_version]() |
|
|
|
return model, preprocess |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
model, preprocess = load_clip(freeze_version=0) |
|
|
|
|
|
|
|
for name, param in model.named_parameters(): |
|
|
|
print(name.ljust(60, '-'), param.requires_grad) |