
clip.py

import hashlib
import os
import urllib.request
import warnings
from typing import List, Union

import jittor as jt
import numpy as np
from jittor.transform import (CenterCrop, Compose, ImageNormalize,
                              _setup_size, resize, to_pil_image)
from PIL import Image
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

__all__ = ["available_models", "load", "tokenize"]

_tokenizer = _Tokenizer()

_MODELS = {
    "RN50":
    "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101":
    "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4":
    "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16":
    "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64":
    "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32":
    "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16":
    "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14":
    "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px":
    "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}


def _download(url: str, root: str):
    """Download a checkpoint from `url` into `root`, verifying its SHA256 checksum."""
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    # The expected checksum is embedded as the second-to-last URL path component.
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        warnings.warn(
            f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
        )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")),
                  ncols=80,
                  unit='iB',
                  unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def to_tensor(data):
    return jt.Var(data)


class ImageToTensor(object):
    """Convert an image (PIL or ndarray) to a Jittor Var, adding a channel axis to 2-D input."""

    def __call__(self, input):
        input = np.asarray(input)
        if len(input.shape) < 3:
            input = np.expand_dims(input, -1)
        return to_tensor(input)


class Resize:
    """Resize an image; an int `size` scales the shorter side to `size`, keeping the aspect ratio."""

    def __init__(self, size, mode=Image.BILINEAR):
        if isinstance(size, int):
            self.size = size
        else:
            self.size = _setup_size(
                size,
                error_msg="If size is a sequence, it should have 2 values")
        self.mode = mode

    def __call__(self, img: Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        if isinstance(self.size, int):
            w, h = img.size
            short, long = (w, h) if w <= h else (h, w)
            if short == self.size:
                return img
            new_short, new_long = self.size, int(self.size * long / short)
            new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
            size = (new_h, new_w)
        else:
            # A 2-tuple size is used as-is (height, width).
            size = self.size
        return resize(img, size, self.mode)


def _transform(n_px):
    """Build the CLIP preprocessing pipeline for a model with input resolution `n_px`."""
    return Compose([
        Resize(n_px, mode=Image.BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ImageNormalize((0.48145466, 0.4578275, 0.40821073),
                       (0.26862954, 0.26130258, 0.27577711)),
        ImageToTensor(),
    ])


def available_models() -> List[str]:
    """Return the names of the available CLIP models."""
    return list(_MODELS.keys())


def load(name, download_root=None):
    """Load a CLIP model by name or local checkpoint path, returning (model, preprocess)."""
    if name in _MODELS:
        model_path = _download(
            _MODELS[name],
            download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}")

    state_dict = jt.load(model_path)
    model = build_model(state_dict)
    return model, _transform(model.visual.input_resolution)


def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False):
    """Tokenize text(s) into a (len(texts), context_length) Var of token ids."""
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
                  for text in texts]

    result = jt.zeros((len(all_tokens), context_length), dtype=jt.int64)
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = jt.Var(tokens)

    return result
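For orientation, here is a minimal usage sketch of the public API above (available_models, load, tokenize). The import path jclip, the image file name, and the assumption that model(image, text) follows the original CLIP interface (returning per-image and per-text logits) are illustrative placeholders, not guaranteed by this file.

import jittor as jt
from PIL import Image

import jclip as clip  # assumed import path for this package

jt.flags.use_cuda = 1

# Pick a model; load() downloads and caches the checkpoint on first use.
model, preprocess = clip.load("ViT-B/32")

# preprocess() yields a normalized CHW float Var; add a batch dimension.
image = preprocess(Image.open("CLIP.png")).unsqueeze(0)  # "CLIP.png" is a placeholder image
text = clip.tokenize(["a diagram", "a dog", "a cat"])

with jt.no_grad():
    # Assumes the model's execute() mirrors the original CLIP forward pass.
    logits_per_image, logits_per_text = model(image, text)
    probs = jt.nn.softmax(logits_per_image, dim=-1).numpy()

print("Label probabilities:", probs)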

Freeze all image layers of the ViT-B/32 CLIP model and train it with the Adan optimizer for 100 epochs, saving a checkpoint every 5 epochs. After CLIP training finishes, run test_clip.py with the test-set data and custom prompts to evaluate the saved checkpoints, and pick the checkpoint and prompt with the best test accuracy. Then run predict.py with the "min_loss.pth" model and submit to the official evaluation system; the resulting top-1 accuracy is 0.6788.
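A minimal sketch of that recipe follows, under stated assumptions: the jclip import path, a train_loader yielding preprocessed image batches with already-tokenized prompts, and an available Adan implementation (e.g. jt.optim.Adan in recent Jittor releases) are placeholders, and the learning rate, weight decay, and checkpoint names are illustrative rather than taken from this repository.

import jittor as jt
import jittor.nn as nn
import jclip as clip  # assumed import path

jt.flags.use_cuda = 1

model, preprocess = clip.load("ViT-B/32")

# Freeze every parameter of the image tower (the ViT), as described above.
for p in model.visual.parameters():
    p.stop_grad()

# Only the still-trainable (text-side) parameters go to the optimizer.
trainable = [p for p in model.parameters() if not p.is_stop_grad()]
optimizer = jt.optim.Adan(trainable, lr=1e-5, weight_decay=0.02)  # assumes an Adan implementation is available

for epoch in range(100):
    for images, texts in train_loader:  # texts are already clip.tokenize()-d prompts
        logits_per_image, logits_per_text = model(images, texts)
        labels = jt.arange(images.shape[0])
        # Symmetric cross-entropy over the image-text similarity matrix, as in the original CLIP objective.
        loss = (nn.cross_entropy_loss(logits_per_image, labels) +
                nn.cross_entropy_loss(logits_per_text, labels)) / 2
        optimizer.step(loss)

    # Save a checkpoint every 5 epochs, per the recipe above.
    if (epoch + 1) % 5 == 0:
        model.save(f"clip_epoch_{epoch + 1}.pkl")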
