
model.py 9.9 kB

from typing import Tuple, Union

import numpy as np
import jittor as jt
from jittor import nn, init

from .mha import MultiheadAttention


def normal_(module, mean=0, std=1, bias=0):
    # Gaussian init for the weight, constant init for the bias (when present).
    if hasattr(module, 'weight') and module.weight is not None:
        init.gauss_(module.weight, mean, std)
    if hasattr(module, 'bias') and isinstance(
            module.bias, jt.Var) and module.bias is not None:
        init.constant_(module.bias, bias)


class LayerNorm(nn.LayerNorm):

    def execute(self, x):
        ret = super().execute(x)
        return ret


class QuickGELU(nn.Module):
    # Fast GELU approximation used by CLIP: x * sigmoid(1.702 * x).

    def execute(self, x):
        return x * jt.sigmoid(1.702 * x)


class MLP(nn.Module):

    def __init__(self, d_model):
        super().__init__()
        self.c_fc = nn.Linear(d_model, d_model * 4)
        self.gelu = QuickGELU()
        self.c_proj = nn.Linear(d_model * 4, d_model)

    def execute(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class ResidualAttentionBlock(nn.Module):
    # Pre-norm transformer block: attention and MLP, each with a residual connection.

    def __init__(self, d_model, n_head, attn_mask):
        super().__init__()
        self.attn = MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = MLP(d_model)
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x):
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False,
                         attn_mask=self.attn_mask)[0]

    def execute(self, x):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):

    def __init__(self, width, layers, heads, attn_mask=None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(width, heads, attn_mask)
            for _ in range(layers)
        ])

    def execute(self, x):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    # ViT image encoder: patch embedding, class token, transformer, output projection.

    def __init__(self, input_resolution: int, patch_size: int, width: int,
                 layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=width,
                               kernel_size=patch_size,
                               stride=patch_size,
                               bias=False)

        scale = width**-0.5
        self.class_embedding = scale * jt.randn((width))
        self.positional_embedding = scale * jt.randn(
            ((input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        self.proj = scale * jt.randn((width, output_dim))

    def execute(self, x):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = jt.concat([
            self.class_embedding.to(x.dtype) + jt.zeros(
                (x.shape[0], 1, x.shape[-1]), dtype=x.dtype), x
        ],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    # Contrastive image-text model: a ViT image encoder and a transformer text
    # encoder projected into a shared embedding space.

    def __init__(
            self,
            embed_dim: int,
            # vision
            image_resolution: int,
            vision_layers: Union[Tuple[int, int, int, int], int],
            vision_width: int,
            vision_patch_size: int,
            # text
            context_length: int,
            vocab_size: int,
            transformer_width: int,
            transformer_heads: int,
            transformer_layers: int):
        super().__init__()

        self.context_length = context_length

        vision_heads = vision_width // 64
        self.visual = VisionTransformer(input_resolution=image_resolution,
                                        patch_size=vision_patch_size,
                                        width=vision_width,
                                        layers=vision_layers,
                                        heads=vision_heads,
                                        output_dim=embed_dim)

        self.transformer = Transformer(width=transformer_width,
                                       layers=transformer_layers,
                                       heads=transformer_heads,
                                       attn_mask=self.build_attention_mask())

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = jt.empty(
            (self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = jt.empty((transformer_width, embed_dim))
        self.logit_scale = jt.ones([]) * np.log(1 / 0.07)

        self.initialize_parameters()

    def initialize_parameters(self):
        normal_(self.token_embedding.weight, std=0.02)
        normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers)**-0.5)
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width)**-0.5
        for block in self.transformer.resblocks:
            normal_(block.attn.in_proj_weight, std=attn_std)
            normal_(block.attn.out_proj.weight, std=proj_std)
            normal_(block.mlp.c_fc.weight, std=fc_std)
            normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            normal_(self.text_projection, std=self.transformer.width**-0.5)

    def build_attention_mask(self):
        # Causal mask for the text transformer: -inf above the diagonal, 0 elsewhere.
        mask = jt.empty((self.context_length, self.context_length))
        mask.fill_(float("-inf"))
        mask = jt.triu(mask, 1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image)

    def encode_text(self, text):
        x = self.token_embedding(text)

        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[jt.arange(x.shape[0]),
              text.argmax(dim=-1)[0]] @ self.text_projection

        return x

    def execute(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1,
                                                              keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def build_model(state_dict: dict):
    # Infer the architecture hyperparameters from a checkpoint and load its weights.
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([
            k for k in state_dict.keys()
            if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
        ])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round(
            (state_dict["visual.positional_embedding"].shape[0] - 1)**0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [
            len(
                set(
                    k.split(".")[2] for k in state_dict
                    if k.startswith(f"visual.layer{b}")))
            for b in [1, 2, 3, 4]
        ]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round(
            (state_dict["visual.attnpool.positional_embedding"].shape[0] -
             1)**0.5)
        vision_patch_size = None
        assert output_width**2 + 1 == state_dict[
            "visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        set(
            k.split(".")[2] for k in state_dict
            if k.startswith("transformer.resblocks")))

    model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
                 vision_patch_size, context_length, vocab_size,
                 transformer_width, transformer_heads, transformer_layers)

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    model.load_parameters(state_dict)
    return model.eval()
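
For orientation, here is a minimal usage sketch of build_model and the CLIP forward pass defined above. The import path, checkpoint filename, and dummy inputs are illustrative assumptions, not part of the repository code.

import jittor as jt
from model import build_model  # model.py above; adjust to the actual package path

# Load a converted checkpoint (filename is illustrative) and build the model.
state_dict = jt.load("ViT-B-32.pkl")
model = build_model(state_dict)

# Dummy inputs: a batch of 2 images at the model's input resolution and
# 2 sequences of token ids padded to the model's context length.
images = jt.randn((2, 3, model.visual.input_resolution,
                   model.visual.input_resolution))
tokens = jt.zeros((2, model.context_length), dtype="int32")

# Forward pass returns scaled cosine-similarity logits in both directions.
logits_per_image, logits_per_text = model(images, tokens)  # each is [2, 2]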

First, all image layers of OpenAI's officially pretrained ViT-B/32 CLIP model are frozen; the model is then trained with the AdanBelief optimizer, a fusion of the Adan and AdaBelief optimizers that incorporates the "Belief" mechanism into Adan to improve the trained model's generalization.
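
A minimal sketch of that freezing step in Jittor follows. The checkpoint filename is illustrative, and since the AdanBelief optimizer is not defined in model.py, a stock jt.optim.Adam instance stands in where the custom optimizer would be plugged in.

import jittor as jt
from model import build_model  # model.py above; adjust to the actual package path

# Build the model from a converted checkpoint (filename is illustrative).
state_dict = jt.load("ViT-B-32.pkl")
model = build_model(state_dict)

# Freeze every parameter of the image tower so only the text side is updated.
for p in model.visual.parameters():
    p.stop_grad()

# Hand only the still-trainable parameters to the optimizer.
trainable_params = [p for p in model.parameters() if not p.is_stop_grad()]

# The repository trains with a custom AdanBelief optimizer (Adan fused with
# AdaBelief); it is not part of model.py, so plain Adam stands in here.
optimizer = jt.optim.Adam(trainable_params, lr=1e-5)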
