
clip.py

import hashlib
import os
import urllib.request
import warnings
from typing import Union, List

import numpy as np
import jittor as jt
from tqdm import tqdm
from PIL import Image
from jittor.transform import CenterCrop, ImageNormalize, Compose, _setup_size, to_pil_image, resize

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

__all__ = ["available_models", "load", "tokenize"]

_tokenizer = _Tokenizer()

# Names and download URLs of the released OpenAI CLIP checkpoints.
# The expected SHA256 checksum of each file is embedded in its URL.
_MODELS = {
    "RN50":
    "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101":
    "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4":
    "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16":
    "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64":
    "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32":
    "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16":
    "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14":
    "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px":
    "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}


def _download(url: str, root: str):
    """Download a checkpoint into `root` and verify its SHA256 checksum."""
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(
            f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target,
                               "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target,
                                                     "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")),
                  ncols=80,
                  unit='iB',
                  unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target,
                           "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def to_tensor(data):
    return jt.Var(data)


class ImageToTensor(object):
    """Convert an image (or array) to a jittor Var, adding a channel axis to grayscale input."""

    def __call__(self, input):
        input = np.asarray(input)
        if len(input.shape) < 3:
            input = np.expand_dims(input, -1)
        return to_tensor(input)


class Resize:
    """Resize the shorter side of an image to `size`, preserving the aspect ratio."""

    def __init__(self, size, mode=Image.BILINEAR):
        if isinstance(size, int):
            self.size = size
        else:
            self.size = _setup_size(
                size,
                error_msg="If size is a sequence, it should have 2 values")
        self.mode = mode

    def __call__(self, img: Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        if isinstance(self.size, int):
            w, h = img.size
            short, long = (w, h) if w <= h else (h, w)
            if short == self.size:
                return img
            new_short, new_long = self.size, int(self.size * long / short)
            new_w, new_h = (new_short, new_long) if w <= h else (new_long,
                                                                 new_short)
            size = (new_h, new_w)
        else:
            size = self.size
        return resize(img, size, self.mode)


def _transform(n_px):
    """CLIP preprocessing: resize, center crop, RGB conversion, normalize, to tensor."""
    return Compose([
        Resize(n_px, mode=Image.BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ImageNormalize((0.48145466, 0.4578275, 0.40821073),
                       (0.26862954, 0.26130258, 0.27577711)),
        ImageToTensor()
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name, download_root=None):
    """Load a named CLIP model (downloading it if needed) or a local checkpoint path.

    Returns the model and the preprocessing transform matching its input resolution.
    """
    if name in _MODELS:
        model_path = _download(
            _MODELS[name], download_root
            or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}")

    state_dict = jt.load(model_path)
    model = build_model(state_dict)
    return model, _transform(model.visual.input_resolution)


def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False):
    """Tokenize the input string(s) into a (batch, context_length) int64 tensor."""
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
                  for text in texts]

    result = jt.zeros((len(all_tokens), context_length), dtype=jt.int64)
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        result[i, :len(tokens)] = jt.Var(tokens)
    return result
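For reference, a minimal usage sketch of the module's public helpers (available_models, load, tokenize) follows. It assumes the package is importable as clip, that the model returned by build_model exposes encode_image and encode_text as in the original OpenAI CLIP, and that "example.jpg" is a placeholder image path; the exact tensor layout expected by the visual backbone may need adjusting.

import jittor as jt
from PIL import Image
from clip import available_models, load, tokenize

print(available_models())                       # names of the downloadable checkpoints

model, preprocess = load("ViT-B/32")            # downloads to ~/.cache/clip on first use
image = preprocess(Image.open("example.jpg"))   # "example.jpg" is a placeholder path
image = image.unsqueeze(0)                      # add a batch dimension
text = tokenize(["a diagram", "a dog", "a cat"])

with jt.no_grad():
    image_features = model.encode_image(image)  # assumed API, as in the original OpenAI CLIP
    text_features = model.encode_text(text)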

All image layers of the OpenAI-pretrained ViT-B/32 CLIP model are first frozen, and the model is then trained with the AdanBelief optimizer. This optimizer is a fusion of the Adan and AdaBelief optimizers: it folds the "Belief" idea into Adan to improve the generalization of the trained model.
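A rough sketch of that fine-tuning setup is given below. The freezing step follows the description above; jittor's built-in Adam is used only as a stand-in for the custom AdanBelief optimizer (not shown here), and the placeholder batch, forward signature, and symmetric contrastive loss are assumptions borrowed from the original OpenAI CLIP recipe.

import jittor as jt
from clip import load, tokenize

model, preprocess = load("ViT-B/32")

# Freeze every parameter of the image (visual) tower, as described above.
for p in model.visual.parameters():
    p.stop_grad()

# The remaining parameters (text tower, projections, logit scale) stay trainable.
trainable = [p for p in model.parameters() if not p.is_stop_grad()]

# Stand-in optimizer; the project's AdanBelief optimizer would be plugged in here.
optimizer = jt.optim.Adam(trainable, lr=1e-5)

# Placeholder batch (illustrative only): random images and tokenized captions.
images = jt.randn(4, 3, 224, 224)
texts = tokenize(["a photo of a cat"] * 4)

# One illustrative training step with the symmetric image/text contrastive loss.
logits_per_image, logits_per_text = model(images, texts)  # assumed forward, as in OpenAI CLIP
labels = jt.arange(images.shape[0])
loss = (jt.nn.cross_entropy_loss(logits_per_image, labels) +
        jt.nn.cross_entropy_loss(logits_per_text, labels)) / 2
optimizer.step(loss)                                       # jittor: backward + update in one call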
