@@ -349,11 +349,13 @@ class CLIP(nn.Module):
                  text_num_hidden_layers: int,
                  text_type_vocab_size: int,
                  tokenizer: FullTokenizer,
+                 # vision_head_width, added this param for ViT-H
+                 vision_head_width: int = 64,
                  ):
         super().__init__()

         if isinstance(vision_layers, (tuple, list)):
-            vision_heads = vision_width * 32 // 64
+            vision_heads = vision_width * 32 // vision_head_width
             self.visual = ModifiedResNet(
                 layers=vision_layers,
                 output_dim=embed_dim,
@@ -361,7 +363,7 @@ class CLIP(nn.Module):
                 input_resolution=image_resolution,
                 width=vision_width)
         else:
-            vision_heads = vision_width // 64
+            vision_heads = vision_width // vision_head_width
             self.visual = VisualTransformer(
                 input_resolution=image_resolution,
                 patch_size=vision_patch_size,
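For reference, a minimal sketch (illustrative numbers, not part of the patch) of why the divisor has to become configurable: the standard ViT-H/14 configuration uses a transformer width of 1280 with 16 attention heads, i.e. an 80-dimensional head, so dividing the width by a hard-coded 64 would produce the wrong head count.

```python
# Illustrative check, assuming the standard ViT-H/14 config (width 1280, 16 heads).
vision_width = 1280       # ViT-H/14 transformer width
vision_head_width = 80    # per-head dimension used by ViT-H

heads_hardcoded = vision_width // 64                   # 20 heads -- wrong for ViT-H
heads_from_param = vision_width // vision_head_width   # 16 heads -- matches ViT-H/14

print(heads_hardcoded, heads_from_param)  # 20 16
```

Keeping the default at 64 leaves the existing ResNet and ViT-B/ViT-L configurations unchanged; only ViT-H callers need to pass vision_head_width explicitly.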