
model.py

from typing import Tuple, Union

import numpy as np
import jittor as jt
from jittor import nn, init

from .mha import MultiheadAttention


def normal_(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        init.gauss_(module.weight, mean, std)
    if hasattr(module, 'bias') and isinstance(
            module.bias, jt.Var) and module.bias is not None:
        init.constant_(module.bias, bias)


class LayerNorm(nn.LayerNorm):

    def execute(self, x):
        ret = super().execute(x)
        return ret


class QuickGELU(nn.Module):

    def execute(self, x):
        return x * jt.sigmoid(1.702 * x)


class MLP(nn.Module):

    def __init__(self, d_model):
        super().__init__()
        self.c_fc = nn.Linear(d_model, d_model * 4)
        self.gelu = QuickGELU()
        self.c_proj = nn.Linear(d_model * 4, d_model)

    def execute(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class ResidualAttentionBlock(nn.Module):

    def __init__(self, d_model, n_head, attn_mask):
        super().__init__()
        self.attn = MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = MLP(d_model)
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x):
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False,
                         attn_mask=self.attn_mask)[0]

    def execute(self, x):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):

    def __init__(self, width, layers, heads, attn_mask=None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(width, heads, attn_mask)
            for _ in range(layers)
        ])

    def execute(self, x):
        return self.resblocks(x)


class VisionTransformer(nn.Module):

    def __init__(self, input_resolution: int, patch_size: int, width: int,
                 layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=width,
                               kernel_size=patch_size,
                               stride=patch_size,
                               bias=False)

        scale = width**-0.5
        self.class_embedding = scale * jt.randn((width))
        self.positional_embedding = scale * jt.randn(
            ((input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        self.proj = scale * jt.randn((width, output_dim))

    def execute(self, x):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = jt.concat([
            self.class_embedding.to(x.dtype) + jt.zeros(
                (x.shape[0], 1, x.shape[-1]), dtype=x.dtype), x
        ],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):

    def __init__(
            self,
            embed_dim: int,
            # vision
            image_resolution: int,
            vision_layers: Union[Tuple[int, int, int, int], int],
            vision_width: int,
            vision_patch_size: int,
            # text
            context_length: int,
            vocab_size: int,
            transformer_width: int,
            transformer_heads: int,
            transformer_layers: int):
        super().__init__()

        self.context_length = context_length

        vision_heads = vision_width // 64
        self.visual = VisionTransformer(input_resolution=image_resolution,
                                        patch_size=vision_patch_size,
                                        width=vision_width,
                                        layers=vision_layers,
                                        heads=vision_heads,
                                        output_dim=embed_dim)

        self.transformer = Transformer(width=transformer_width,
                                       layers=transformer_layers,
                                       heads=transformer_heads,
                                       attn_mask=self.build_attention_mask())

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = jt.empty(
            (self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = jt.empty((transformer_width, embed_dim))
        self.logit_scale = jt.ones([]) * np.log(1 / 0.07)

        self.initialize_parameters()

    def initialize_parameters(self):
        normal_(self.token_embedding.weight, std=0.02)
        normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers)**-0.5)
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width)**-0.5
        for block in self.transformer.resblocks:
            normal_(block.attn.in_proj_weight, std=attn_std)
            normal_(block.attn.out_proj.weight, std=proj_std)
            normal_(block.mlp.c_fc.weight, std=fc_std)
            normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            normal_(self.text_projection, std=self.transformer.width**-0.5)

    def build_attention_mask(self):
        mask = jt.empty((self.context_length, self.context_length))
        mask.fill_(float("-inf"))
        mask = jt.triu(mask, 1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image)

    def encode_text(self, text):
        x = self.token_embedding(text)

        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[jt.arange(x.shape[0]),
              text.argmax(dim=-1)[0]] @ self.text_projection

        return x

    def execute(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1,
                                                               keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([
            k for k in state_dict.keys()
            if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
        ])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round(
            (state_dict["visual.positional_embedding"].shape[0] - 1)**0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [
            len(
                set(
                    k.split(".")[2] for k in state_dict
                    if k.startswith(f"visual.layer{b}")))
            for b in [1, 2, 3, 4]
        ]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round(
            (state_dict["visual.attnpool.positional_embedding"].shape[0] -
             1)**0.5)
        vision_patch_size = None
        assert output_width**2 + 1 == state_dict[
            "visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        set(
            k.split(".")[2] for k in state_dict
            if k.startswith("transformer.resblocks")))

    model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
                 vision_patch_size, context_length, vocab_size,
                 transformer_width, transformer_heads, transformer_layers)

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    model.load_parameters(state_dict)
    return model.eval()
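
For reference, a minimal sketch of how build_model and the CLIP forward pass might be driven. The checkpoint name, the import path, and the random stand-ins for a preprocessed image and tokenized prompts are illustrative assumptions, not part of this repository.

import jittor as jt
from jittor import nn
from model import build_model  # adjust to the actual package path; model.py uses a relative import for mha

# hypothetical converted Jittor checkpoint containing the keys build_model expects
state_dict = jt.load("ViT-B-32.pkl")
model = build_model(state_dict)

# random stand-ins for one preprocessed 224x224 image and three tokenized prompts
image = jt.randn((1, 3, 224, 224))
text = jt.randint(0, model.vocab_size, (3, model.context_length))

with jt.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = nn.softmax(logits_per_image, dim=-1)  # similarity distribution over the prompts

print(probs.shape)  # (1, 3)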

Freeze all of the image layers in the ViT-B/32 version of the CLIP model and train it with the Adan optimizer for 100 epochs, saving a checkpoint every 5 epochs. After CLIP training finishes, run test_clip.py to evaluate the saved models on the test-set data with custom prompts, and pick the model and prompt pair with the best test accuracy. Then run predict.py with the "min_loss.pth" model selected and submit the results to the official evaluation system; the top-1 accuracy is 0.6788.
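
A minimal sketch of that fine-tuning loop, assuming model is the CLIP module from model.py and train_loader yields batches of preprocessed images paired with their tokenized prompts. The learning rate, the symmetric contrastive loss, and the checkpoint names are illustrative assumptions; Adan is not bundled with Jittor, so jt.optim.Adam stands in where the Adan optimizer would be plugged in.

import jittor as jt
from jittor import nn

# freeze every parameter of the image tower (model.visual)
for p in model.visual.parameters():
    p.stop_grad()

# only the still-trainable (text-side) parameters go to the optimizer;
# jt.optim.Adam is a stand-in, the run described above used Adan
trainable = [p for p in model.parameters() if not p.is_stop_grad()]
optimizer = jt.optim.Adam(trainable, lr=1e-5)

for epoch in range(100):
    for images, texts in train_loader:
        logits_per_image, logits_per_text = model(images, texts)
        # symmetric contrastive loss: matched image/text pairs lie on the diagonal
        labels = jt.arange(images.shape[0])
        loss = (nn.cross_entropy_loss(logits_per_image, labels) +
                nn.cross_entropy_loss(logits_per_text, labels)) / 2
        optimizer.step(loss)
    if (epoch + 1) % 5 == 0:
        model.save(f"clip_epoch_{epoch + 1}.pkl")  # hypothetical checkpoint name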
