Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10572842 (target branch: master)
@@ -389,6 +389,7 @@ class Preprocessors(object):
    # multi-modal preprocessor
    ofa_tasks_preprocessor = 'ofa-tasks-preprocessor'
    clip_preprocessor = 'clip-preprocessor'
    mplug_tasks_preprocessor = 'mplug-tasks-preprocessor'

    # science preprocessor

@@ -428,6 +429,8 @@ class Metrics(object):
    image_inpainting_metric = 'image-inpainting-metric'
    # metric for ocr
    NED = 'ned'
    # metric for cross-modal retrieval
    inbatch_recall = 'inbatch_recall'
    # metric for referring-video-object-segmentation task
    referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric'

@@ -474,6 +477,9 @@ class Hooks(object):
    # Compression
    SparsityHook = 'SparsityHook'
    # CLIP logit_scale clamp
    ClipClampLogitScaleHook = 'ClipClampLogitScaleHook'


class LR_Schedulers(object):
    """learning rate scheduler is defined here

@@ -24,6 +24,7 @@ class MetricKeys(object):
    ROUGE_1 = 'rouge-1'
    ROUGE_L = 'rouge-l'
    NED = 'ned'  # ocr metric
    BatchAcc = 'inbatch_t2i_recall_at_1'

task_default_metrics = {
@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict

import numpy as np
import torch

from modelscope.metainfo import Metrics
from modelscope.outputs import OutputKeys
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys


@METRICS.register_module(
    group_key=default_group, module_name=Metrics.inbatch_recall)
class InbatchRecallMetric(Metric):
"""The metric computation class for in-batch retrieval classes. | |||||
This metric class calculates in-batch image recall@1 for each input batch. | |||||
""" | |||||
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inbatch_t2i_hitcnts = []
        self.batch_sizes = []

    def add(self, outputs: Dict, inputs: Dict):
        image_features = outputs[OutputKeys.IMG_EMBEDDING]
        text_features = outputs[OutputKeys.TEXT_EMBEDDING]
        assert type(image_features) == torch.Tensor and type(
            text_features) == torch.Tensor
        with torch.no_grad():
            logits_per_image = image_features @ text_features.t()
            logits_per_text = logits_per_image.t()
            batch_size = logits_per_image.shape[0]

            ground_truth = torch.arange(batch_size).long()
            ground_truth = ground_truth.to(image_features.device)

            inbatch_t2i_hitcnt = (logits_per_text.argmax(-1) == ground_truth
                                  ).sum().float().item()

            self.inbatch_t2i_hitcnts.append(inbatch_t2i_hitcnt)
            self.batch_sizes.append(batch_size)

    def evaluate(self):
        assert len(self.inbatch_t2i_hitcnts) == len(
            self.batch_sizes) and len(self.batch_sizes) > 0
        return {
            MetricKeys.BatchAcc:
            sum(self.inbatch_t2i_hitcnts) / sum(self.batch_sizes)
        }
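For intuition, a standalone sketch (plain PyTorch, hypothetical names, not part of the patch) of the recall@1 computation this metric accumulates: each text query counts as a hit when its own paired image is the argmax of the in-batch similarity row.

```python
import torch
import torch.nn.functional as F


def inbatch_t2i_recall_at_1(image_features: torch.Tensor,
                            text_features: torch.Tensor) -> float:
    """Fraction of texts whose best-matching in-batch image is their paired one."""
    # L2-normalize so the dot product is a cosine similarity.
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    logits_per_text = text_features @ image_features.t()  # [B, B]
    ground_truth = torch.arange(text_features.shape[0])
    return (logits_per_text.argmax(-1) == ground_truth).float().mean().item()


# toy batch of 8 paired embeddings
img = torch.randn(8, 512)
txt = img + 0.05 * torch.randn(8, 512)  # each text is close to its paired image
print(inbatch_t2i_recall_at_1(img, txt))  # ~1.0 for this easy toy case
```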
@@ -15,15 +15,13 @@
import os
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Tuple, Union
from typing import Any, Dict, Tuple, Union

import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from modelscope.metainfo import Models
from modelscope.models import TorchModel

@@ -506,21 +504,6 @@ def convert_weights(model: nn.Module):
    model.apply(_convert_weights_to_fp16)


def _convert_to_rgb(image):
    return image.convert('RGB')


def image_transform(image_size=224):
    transform = Compose([
        _convert_to_rgb,
        Resize((image_size, image_size)),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])
    return transform


@MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip)
class CLIPForMultiModalEmbedding(TorchModel):

@@ -540,72 +523,40 @@ class CLIPForMultiModalEmbedding(TorchModel):
        with open(vision_model_config_file,
                  'r') as fv, open(text_model_config_file, 'r') as ft:
            model_info = json.load(fv)
            self.model_info = json.load(fv)
            for k, v in json.load(ft).items():
                model_info[k] = v

        # image preprocess
        self.img_preprocess = image_transform(model_info['image_resolution'])
                self.model_info[k] = v

        # text tokenizer
        vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}'
        self.tokenizer = FullTokenizer(vocab_file=vocab_file)

        # initialize the model
        self.clip_model = CLIP(**model_info, tokenizer=self.tokenizer)
        self.clip_model = CLIP(**self.model_info, tokenizer=self.tokenizer)
        convert_weights(self.clip_model)

        # restore the pretrained weight
        checkpoint = torch.load(
            f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}', 'cpu')
        sd = checkpoint['state_dict']
        sd = checkpoint[
            'state_dict'] if 'state_dict' in checkpoint else checkpoint
        if next(iter(sd.items()))[0].startswith('module'):
            sd = {k[len('module.'):]: v for k, v in sd.items()}
        # support the finetuned model
        if next(iter(sd.items()))[0].startswith('clip_model'):
            sd = {k[len('clip_model.'):]: v for k, v in sd.items()}
        self.clip_model.load_state_dict(sd)
        self.clip_model.eval()

        # place the model
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if self.device == 'cuda':
        self.device = 'cuda:{}'.format(int(os.environ.get(
            'LOCAL_RANK', 0))) if torch.cuda.is_available() else 'cpu'
        if torch.cuda.is_available():
            self.clip_model.to(self.device)
            logger.info('Use GPU for inference')
            logger.info('Use GPU {} for finetuning & inference'.format(
                int(os.environ.get('LOCAL_RANK', 0))))
        else:
            self.clip_model.float()
            logger.info('Use CPU for inference')

    def tokenize(self,
                 texts: Union[str, List[str]],
                 context_length: int = 52) -> torch.LongTensor:
        """
        Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all baseline models use 24 as the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        all_tokens = []
        for text in texts:
            all_tokens.append(
                [self.tokenizer.vocab['[CLS]']]
                + self.tokenizer.convert_tokens_to_ids(
                    self.tokenizer.tokenize(text))[:context_length - 2]
                + [self.tokenizer.vocab['[SEP]']])

        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            assert len(tokens) <= context_length
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
            logger.info('Use CPU for finetuning & inference')

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        from modelscope.outputs import OutputKeys

@@ -613,75 +564,36 @@ class CLIPForMultiModalEmbedding(TorchModel):
            OutputKeys.IMG_EMBEDDING: None,
            OutputKeys.TEXT_EMBEDDING: None
        }
        if 'img' in input and input['img'] is not None:
            image_input = input['img']

            # single image input
            if isinstance(image_input, Image.Image):
                image_tensor = self.img_preprocess(image_input).unsqueeze(0)
            # multi images input
            elif isinstance(image_input, list):
                if all([isinstance(elem, Image.Image)
                        for elem in image_input]):
                    image_tensor = torch.stack(
                        [self.img_preprocess(elem) for elem in image_input],
                        dim=0)
                else:
                    unsupported_elem_type = [
                        type(elem) for elem in image_input
                        if not isinstance(elem, Image.Image)
                    ][0]
                    raise TypeError(
                        f'img should be PIL.Image or List[PIL.Image], \
                            but got a List containing one {unsupported_elem_type}'
                    )
            # others
            else:
                raise TypeError(
                    f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}'
                )

            image_tensor = image_tensor.to(self.device)

            with torch.no_grad():
        mode = input.get('mode', ModeKeys.INFERENCE)

        # encode the image
        if 'img' in input and isinstance(input['img'], torch.Tensor):
            image_tensor = input['img'].to(self.device)
            if image_tensor.dim() == 5 and image_tensor.shape[1] == 1:
                image_tensor = image_tensor.squeeze(1)

            with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
                image_features = self.clip_model.encode_image(image_tensor)
                image_features /= image_features.norm(
                    dim=-1, keepdim=True)  # l2-normalize
                output[OutputKeys.IMG_EMBEDDING] = image_features

        if 'text' in input and input['text'] is not None:
            text_input = input['text']

            # single text input
            if isinstance(text_input, str):
                text_tensor = self.tokenize(text_input)
            # multi texts input
            elif isinstance(text_input, list):
                if all([isinstance(elem, str) for elem in text_input]):
                    text_tensor = self.tokenize(text_input)
                else:
                    unsupported_elem_type = [
                        type(elem) for elem in text_input
                        if not isinstance(elem, str)
                    ][0]
                    raise TypeError(
                        f'text should be str or List[str], but got a List containing one {unsupported_elem_type}'
                    )
            # others
            else:
                raise TypeError(
                    f'text should be str or List[str], but got {type(text_input)}'
                )

            text_tensor = text_tensor.to(self.device)

            with torch.no_grad():
        if 'text' in input and isinstance(input['text'], torch.Tensor):
            text_tensor = input['text'].to(self.device)
            if text_tensor.dim() == 3 and text_tensor.shape[1] == 1:
                text_tensor = text_tensor.squeeze(1)

            with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
                text_features = self.clip_model.encode_text(text_tensor)
                text_features /= text_features.norm(
                    dim=-1, keepdim=True)  # l2-normalize
                output[OutputKeys.TEXT_EMBEDDING] = text_features

        if mode == ModeKeys.TRAIN:
            output['logit_scale'] = (self.clip_model.logit_scale
                                     * 1.0).exp().mean()

        return output

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
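The refactored forward now expects already-preprocessed `img`/`text` tensors plus a `mode` key, enables gradients only when `mode == ModeKeys.TRAIN`, and additionally returns `logit_scale` for the training loss. A toy sketch of that grad-gating and L2-normalization pattern (stand-in encoders, hypothetical names, not the ModelScope API):

```python
import torch

# Toy stand-ins for the real encoders (self.clip_model.encode_image / encode_text).
image_encoder = torch.nn.Linear(32, 8)
text_encoder = torch.nn.Linear(16, 8)


def encode(mode: str, img: torch.Tensor, txt: torch.Tensor):
    train = (mode == 'train')  # mirrors mode == ModeKeys.TRAIN
    with torch.autograd.set_grad_enabled(train):
        img_emb = image_encoder(img)
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)  # l2-normalize
        txt_emb = text_encoder(txt)
        txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
    return img_emb, txt_emb


img_emb, _ = encode('train', torch.randn(4, 32), torch.randn(4, 16))
print(img_emb.requires_grad)  # True: gradients flow in training mode
img_emb, _ = encode('inference', torch.randn(4, 32), torch.randn(4, 16))
print(img_emb.requires_grad)  # False: inference stays grad-free
```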
@@ -1,10 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict
from typing import Any, Dict, Optional, Union

from modelscope.metainfo import Pipelines
from modelscope.models.multi_modal.clip.model import CLIPForMultiModalEmbedding
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.multi_modal import CLIPPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

@@ -17,7 +19,10 @@ logger = get_logger()
    Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding)
class MultiModalEmbeddingPipeline(Pipeline):

    def __init__(self, model: str, device: str = 'gpu'):
    def __init__(self,
                 model: Union[Model, str],
                 preprocessor: Optional[Preprocessor] = None,
                 **kwargs):
        """
        use `model` and `preprocessor` to create a multi-modal embedding pipeline for prediction
        Args:

@@ -29,14 +34,17 @@ class MultiModalEmbeddingPipeline(Pipeline):
            pipe_model = model
        else:
            raise NotImplementedError('model must be a single str')
        pipe_model.eval()
        if preprocessor is None:
            if isinstance(pipe_model, CLIPForMultiModalEmbedding):
                preprocessor = CLIPPreprocessor(pipe_model.model_dir)
            else:
                raise NotImplementedError

        super().__init__(model=pipe_model)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input
        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return self.model(input)
        return self.model(self.preprocess(input))

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
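For reference, a usage sketch that mirrors the updated unit tests later in this review (the tests call `.forward(...)`, which runs the attached CLIPPreprocessor internally); the image path and query texts are illustrative only:

```python
from PIL import Image

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# CLIPPreprocessor is created automatically when no preprocessor is passed.
clip_pipe = pipeline(
    Tasks.multi_modal_embedding,
    model='damo/multi-modal_clip-vit-base-patch16_zh')

# 'img' / 'text' are the default input keys expected by CLIPPreprocessor.
result = clip_pipe.forward({
    'img': Image.open('example.jpg'),        # hypothetical local image
    'text': ['一张猫的照片', '一张狗的照片'],  # illustrative queries
})
print(result[OutputKeys.IMG_EMBEDDING].shape,
      result[OutputKeys.TEXT_EMBEDDING].shape)
```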
@@ -3,8 +3,11 @@ import os.path as osp
from io import BytesIO
from typing import Any, Dict, List, Tuple, Union

import json
import torch
from PIL import Image
from timm.data import create_transform
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Preprocessors

@@ -107,6 +110,180 @@ class OfaPreprocessor(Preprocessor):
            eos_idx=self.tokenizer.eos_token_id)


def _convert_to_rgb(image):
    return image.convert('RGB')


@PREPROCESSORS.register_module(
    Fields.multi_modal, module_name=Preprocessors.clip_preprocessor)
class CLIPPreprocessor(Preprocessor):

    def __init__(self,
                 model_dir: str,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
"""preprocess the data | |||||
Args: | |||||
model_dir (str): model path | |||||
mode: preprocessor mode (model mode) | |||||
""" | |||||
        super().__init__(*args, **kwargs)
        model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
            model_dir)
        self.mode = mode
        # text tokenizer
        from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer
        if 'tokenizer' in kwargs and isinstance(kwargs['tokenizer'],
                                                FullTokenizer):
            self.tokenizer = kwargs['tokenizer']
        else:
            vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}'
            self.tokenizer = FullTokenizer(vocab_file=vocab_file)
        # image preprocessor
        if 'resolution' in kwargs and isinstance(kwargs['resolution'], int):
            self.image_resolution = kwargs['resolution']
        else:
            self.image_resolution = json.load(
                open('{}/vision_model_config.json'.format(
                    model_dir)))['image_resolution']
        self.img_preprocess = self._build_image_transform()
        # key mapping
        # specify the input keys; training and inference data may use different key names
        self.input_keys = {'img': 'img', 'text': 'text'}
    def _build_image_transform(self):
        if self.mode == ModeKeys.TRAIN:
            transform = create_transform(
                input_size=self.image_resolution,
                scale=(0.9, 1.0),
                is_training=True,
                color_jitter=None,
                auto_augment='original',
                interpolation='bicubic',
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            )
            transform = Compose(transform.transforms[:-3] + [_convert_to_rgb]
                                + transform.transforms[-3:])
        else:
            transform = Compose([
                Resize((self.image_resolution, self.image_resolution),
                       interpolation=Image.BICUBIC),
                _convert_to_rgb,
                ToTensor(),
                Normalize((0.48145466, 0.4578275, 0.40821073),
                          (0.26862954, 0.26130258, 0.27577711)),
            ])
        return transform
    def tokenize(self,
                 texts: Union[str, List[str]],
                 context_length: int = 52) -> torch.LongTensor:
        """
        Returns the tokenized representation of the given input string(s).

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; the default of 52 matches the released Chinese CLIP models

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        all_tokens = []
        for text in texts:
            all_tokens.append(
                [self.tokenizer.vocab['[CLS]']]
                + self.tokenizer.convert_tokens_to_ids(
                    self.tokenizer.tokenize(text))[:context_length - 2]
                + [self.tokenizer.vocab['[SEP]']])

        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            assert len(tokens) <= context_length
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
    def set_input_img_key(self, new_key: str):
        self.input_keys['img'] = new_key

    def set_input_text_key(self, new_key: str):
        self.input_keys['text'] = new_key

    def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args,
                 **kwargs) -> Dict[str, Any]:
        output = {}

        # preprocess the image input
        input_img_key = self.input_keys['img']
        if input_img_key in input and input[input_img_key] is not None:
            image_input = input[input_img_key]

            # single image input
            if isinstance(image_input, Image.Image):
                image_tensor = self.img_preprocess(image_input).unsqueeze(0)
            # multi images input
            elif isinstance(image_input, list):
                if all([isinstance(elem, Image.Image)
                        for elem in image_input]):
                    image_tensor = torch.stack(
                        [self.img_preprocess(elem)
                         for elem in image_input],  # noqa
                        dim=0)  # noqa
                else:
                    unsupported_elem_type = [
                        type(elem) for elem in image_input
                        if not isinstance(elem, Image.Image)
                    ][0]
                    raise TypeError(
                        f'img should be PIL.Image or List[PIL.Image], \
                            but got a List containing one {unsupported_elem_type}'
                    )
            # others
            else:
                raise TypeError(
                    f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}'
                )

            output['img'] = image_tensor

        # preprocess the text input
        input_text_key = self.input_keys['text']
        if input_text_key in input and input[input_text_key] is not None:
            text_input = input[input_text_key]

            # single text input
            if isinstance(text_input, str):
                text_tensor = self.tokenize(text_input)
            # multi texts input
            elif isinstance(text_input, list):
                if all([isinstance(elem, str) for elem in text_input]):
                    text_tensor = self.tokenize(text_input)
                else:
                    unsupported_elem_type = [
                        type(elem) for elem in text_input
                        if not isinstance(elem, str)
                    ][0]
                    raise TypeError(
                        f'text should be str or List[str], but got a List containing one {unsupported_elem_type}'
                    )
            # others
            else:
                raise TypeError(
                    f'text should be str or List[str], but got {type(text_input)}'
                )

            output['text'] = text_tensor

        return output


@PREPROCESSORS.register_module(
    Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor)
class MPlugPreprocessor(Preprocessor):
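To make the tokenization scheme above concrete ([CLS] + truncated token ids + [SEP], zero-padded to `context_length`), here is a standalone sketch with a toy vocabulary rather than the real FullTokenizer; everything below is illustrative:

```python
import torch

# Toy vocabulary; the real preprocessor uses the BERT-style FullTokenizer shipped with the model.
vocab = {'[PAD]': 0, '[CLS]': 101, '[SEP]': 102, '猫': 200, '狗': 201, '照片': 202}


def toy_tokenize(texts, context_length=52):
    rows = []
    for text in texts:
        ids = [vocab[tok] for tok in text.split() if tok in vocab][:context_length - 2]
        rows.append([vocab['[CLS]']] + ids + [vocab['[SEP]']])
    result = torch.zeros(len(rows), context_length, dtype=torch.long)  # zero == [PAD]
    for i, tokens in enumerate(rows):
        result[i, :len(tokens)] = torch.tensor(tokens)
    return result


batch = toy_tokenize(['猫 照片', '狗 照片'])
print(batch.shape)   # torch.Size([2, 52])
print(batch[0, :5])  # tensor([101, 200, 202, 102, 0])
```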
@@ -0,0 +1,18 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch

from modelscope.metainfo import Hooks
from modelscope.trainers.multi_modal.clip.clip_trainer import CLIPTrainer
from .builder import HOOKS
from .hook import Hook


@HOOKS.register_module(module_name=Hooks.ClipClampLogitScaleHook)
class ClipClampLogitScaleHook(Hook):
"""ClipClampLogitScaleHook hook which performs clamp on CLIP logit scale parameter after update""" | |||||
def after_train_iter(self, trainer: CLIPTrainer): | |||||
"""Called after every training iter to evaluate the results.""" | |||||
        unwrapped_model = getattr(trainer.model, 'module', trainer.model)
        logit_scale = unwrapped_model.clip_model.logit_scale
        logit_scale.data = torch.clamp(logit_scale.data, 0, 4.6052)
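The upper bound 4.6052 is ln(100), so the effective temperature exp(logit_scale) stays within [1, 100], matching the original CLIP recipe. A quick standalone check of the effect (values illustrative):

```python
import math

import torch

# The raw parameter is a log-scale: the model multiplies logits by exp(logit_scale).
logit_scale = torch.tensor(5.3)  # e.g. drifted above the bound during an update
clamped = torch.clamp(logit_scale, 0, 4.6052)
print(math.exp(4.6052))  # ~100.0, the usual CLIP upper bound on the scale
print(clamped.exp())     # effective scale capped at ~100
```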
@@ -1,169 +1,206 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from typing import Dict, Optional
from typing import Callable, Dict, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch import distributed as dist
from torch import nn
from torch.utils.data import Dataset

from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.trainers.base import BaseTrainer
from modelscope.models.base import Model, TorchModel
from modelscope.models.multi_modal.clip.model import convert_models_to_fp32
from modelscope.msdatasets.ms_dataset import MsDataset
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.multi_modal import CLIPPreprocessor
from modelscope.trainers import EpochBasedTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.optimizer.builder import build_optimizer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModeKeys
from modelscope.utils.logger import get_logger
from .clip_trainer_utils import ImageWithCaptionDataset, get_optimizer
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys,
                                       ModeKeys)
from .clip_trainer_utils import get_loss, get_optimizer_params, get_schedule

logger = get_logger()
def exclude(n):
    return 'bn' in n or 'ln' in n or 'bias' in n or 'logit_scale' in n


def include(n):
    return not exclude(n)


@TRAINERS.register_module(module_name=Trainers.clip_multi_modal_embedding)
class CLIPTrainer(BaseTrainer):

    def __init__(self, cfg_file: str, model: str, device_id: int, *args,
                 **kwargs):
        super().__init__(cfg_file)

        self.cfg = Config.from_file(cfg_file)
        self.model = Model.from_pretrained(model)
        self.device_id = device_id
        self.total_epoch = self.cfg.train.epoch
        self.train_batch_size = self.cfg.train.batch_size
        self.val_batch_size = self.cfg.evaluation.batch_size
        self.ckpt_dir = self.cfg.train.ckpt_dir

        self.train_dataset = ImageWithCaptionDataset(
            json_file='{}/{}'.format(self.cfg.dataset.root_dir,
                                     self.cfg.dataset.train_set),
            img_dir=self.cfg.dataset.root_dir,
            phase=ModeKeys.TRAIN)
        self.val_dataset = ImageWithCaptionDataset(
            json_file='{}/{}'.format(self.cfg.dataset.root_dir,
                                     self.cfg.dataset.val_set),
            img_dir=self.cfg.dataset.root_dir,
            phase=ModeKeys.EVAL)

    def train(self, *args, **kwargs):
        assert dist.is_initialized()

        self.model.clip_model.train()
        self.model.clip_model.to(self.device_id)
        ddp_model = torch.nn.parallel.DistributedDataParallel(
            self.model.clip_model, device_ids=[
                self.device_id,
            ])

        optimizer = get_optimizer(ddp_model)

        for epoch in range(self.total_epoch):
            train_sampler = DistributedSampler(
                dataset=self.train_dataset, shuffle=True)
            train_sampler.set_epoch(epoch)
            train_params = {
                'pin_memory': True,
                'collate_fn': None,
                'batch_size': self.train_batch_size,
                'shuffle': False,
                'drop_last': True,
                'sampler': train_sampler,
                'num_workers': 8
class CLIPTrainer(EpochBasedTrainer):

    def __init__(
            self,
            model: Optional[Union[TorchModel, nn.Module, str]] = None,
            cfg_file: Optional[str] = None,
            arg_parse_fn: Optional[Callable] = None,
            data_collator: Optional[Union[Callable, Dict[str,
                                                         Callable]]] = None,
            train_dataset: Optional[Union[MsDataset, Dataset]] = None,
            eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
            preprocessor: Optional[Union[Preprocessor,
                                         Dict[str, Preprocessor]]] = None,
            optimizers: Tuple[torch.optim.Optimizer,
                              torch.optim.lr_scheduler._LRScheduler] = (None,
                                                                        None),
            model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
            seed: int = 42,
            **kwargs):

        model = Model.from_pretrained(model, revision=model_revision)
# for training & eval, we convert the model from FP16 back to FP32 | |||||
# to compatible with modelscope amp training | |||||
convert_models_to_fp32(model) | |||||
        cfg = Config.from_file(cfg_file)
        if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0:
            work_dir = cfg.train.work_dir
        else:
            work_dir = kwargs['work_dir']

        # fetch the model name of CLIP model (base, large or large-336)
        model_name = cfg.pretrained_model.model_name

        # world size
        world_size = int(os.environ.get('WORLD_SIZE', 1))

        # train step, optimizer and lr_scheduler
        epoch_steps = math.ceil(
            len(train_dataset) /  # noqa
            (cfg.train.dataloader.batch_size_per_gpu * world_size))  # noqa
        cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs

        if optimizers[0] is None:
            named_parameters = list(model.named_parameters())
            gain_or_bias_params = [
                p for n, p in named_parameters
                if exclude(n) and p.requires_grad
            ]
            rest_params = [
                p for n, p in named_parameters
                if include(n) and p.requires_grad
            ]
            optimizer_hparams = get_optimizer_params(
                model_name, cfg)  # lr, wd, beta1, beta2, eps
            optimizer_args = {
                'params': [
                    {
                        'params': gain_or_bias_params,
                        'weight_decay': 0.
                    },
                    {
                        'params': rest_params,
                        'weight_decay': optimizer_hparams['weight_decay']
                    },
                ],
                'lr':
                optimizer_hparams['lr'],
                'betas':
                (optimizer_hparams['beta1'], optimizer_hparams['beta2']),
                'eps':
                optimizer_hparams['eps'],
            }
            optimizer = build_optimizer(
                model, cfg=cfg.train.optimizer, default_args=optimizer_args)
        else:
            optimizer = optimizers[0]

        if optimizers[1] is None:
            lr_scheduler = get_schedule(optimizer, cfg.train.lr_scheduler)
        else:
            lr_scheduler = optimizers[1]

        optimizers = (optimizer, lr_scheduler)

        # loss module
        loss_img = nn.CrossEntropyLoss()
        loss_txt = nn.CrossEntropyLoss()
        self.loss_img = loss_img.cuda(int(os.environ.get('LOCAL_RANK', 0)))
        self.loss_txt = loss_txt.cuda(int(os.environ.get('LOCAL_RANK', 0)))
        self.loss_cfg = cfg.train.loss_cfg

        # launcher and use_fp16
        if 'launcher' not in kwargs and cfg.train.get('launcher', None):
            kwargs['launcher'] = cfg.train.launcher
        if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False):
            kwargs['use_fp16'] = cfg.train.use_fp16

        # preprocessor
        if preprocessor is None:
            preprocessor = {
                ConfigKeys.train:
                CLIPPreprocessor(
                    model_dir=work_dir,
                    mode=ModeKeys.TRAIN,
                    tokenizer=model.tokenizer,
                    resolution=model.model_info['image_resolution']),
                ConfigKeys.val:
                CLIPPreprocessor(
                    model_dir=work_dir,
                    mode=ModeKeys.EVAL,
                    tokenizer=model.tokenizer,
                    resolution=model.model_info['image_resolution']),
            }
            train_loader = DataLoader(self.train_dataset, **train_params)

            for batch_idx, (img_tensor, text_str_list,
                            img_id_list) in enumerate(train_loader):
                text_info_list = [
                    self.model.tokenize_text(tmp) for tmp in text_str_list
                ]
                text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list],
                                            dim=0)
                text_masks_tensor = torch.cat(
                    [tmp[1] for tmp in text_info_list], dim=0)
                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
                img_id_list = img_id_list.to(self.device_id, non_blocking=True)
                text_ids_tensor = text_ids_tensor.to(
                    self.device_id, non_blocking=True)
                text_masks_tensor = text_masks_tensor.to(
                    self.device_id, non_blocking=True)

                loss = ddp_model((img_tensor, text_ids_tensor,
                                  text_masks_tensor, img_id_list),
                                 ModeKeys.TRAIN)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if batch_idx % 10 == 0:
                    logger.info(
                        'epoch: {}, train batch {}/{}, loss={:.5f}, logit_scale={:.5f}'
                        .format(epoch, batch_idx, len(train_loader),
                                loss.item(),
                                ddp_model.module.logit_scale.exp().item()))
            if dist.get_rank() == 0:
                os.makedirs(self.ckpt_dir, exist_ok=True)
                torch.save(ddp_model.module.state_dict(),
                           '{}/epoch{}.pth'.format(self.ckpt_dir, epoch))

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        if checkpoint_path is not None:
            checkpoint_params = torch.load(checkpoint_path, 'cpu')
            self.model.clip_model.load_state_dict(checkpoint_params)
        self.model.clip_model.eval()
        self.model.clip_model.to(self.device_id)

        val_params = {
            'collate_fn': None,
            'batch_size': self.val_batch_size,
            'shuffle': False,
            'drop_last': False,
            'num_workers': 8
        }
        val_loader = DataLoader(self.val_dataset, **val_params)

        tp_cnt_per_batch = []
        processed_cnt = 0
        with torch.no_grad():
            for batch_idx, (img_tensor, text_str_list,
                            img_id_list) in enumerate(val_loader):
                text_info_list = [
                    self.model.tokenize_text(tmp) for tmp in text_str_list
                ]
                text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list],
                                            dim=0)
                text_masks_tensor = torch.cat(
                    [tmp[1] for tmp in text_info_list], dim=0)
                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
                img_id_list = img_id_list.to(self.device_id, non_blocking=True)
                text_ids_tensor = text_ids_tensor.to(
                    self.device_id, non_blocking=True)
                text_masks_tensor = text_masks_tensor.to(
                    self.device_id, non_blocking=True)

                img_feat = self.model.clip_model(img_tensor, input_type='img')
                text_feat = self.model.clip_model(
                    (text_ids_tensor, text_masks_tensor), input_type='text')

                sim_mat = text_feat @ img_feat.t()
                text_cnt, img_cnt = sim_mat.shape
                top1_scores, match_ids = torch.max(sim_mat, dim=1)
                match_ids = match_ids.int()
                gt_ids = torch.tensor(range(0, text_cnt)).to(
                    self.device_id, non_blocking=True).int()
                error_cnt = torch.nonzero(match_ids - gt_ids)
                processed_cnt += text_cnt
                tp_cnt_per_batch.append(text_cnt - 1.0 * error_cnt.numel())
                logger.info('current acc: {:.3f}'.format(
                    sum(tp_cnt_per_batch) / processed_cnt))
        # dataset related
        self.dataset_cfg = cfg.dataset
        if hasattr(self.dataset_cfg, 'column_map'):
            # cases where dataset key names are not "img" and "text"
            img_key_name = getattr(self.dataset_cfg.column_map, 'img', 'img')
            preprocessor[ConfigKeys.train].set_input_img_key(img_key_name)
            preprocessor[ConfigKeys.val].set_input_img_key(img_key_name)
            text_key_name = getattr(self.dataset_cfg.column_map, 'text',
                                    'text')
            preprocessor[ConfigKeys.train].set_input_text_key(text_key_name)
            preprocessor[ConfigKeys.val].set_input_text_key(text_key_name)

        self.global_batch_size = cfg.train.dataloader.batch_size_per_gpu * world_size

        super().__init__(
            model=model,
            cfg_file=cfg_file,
            arg_parse_fn=arg_parse_fn,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            preprocessor=preprocessor,
            optimizers=optimizers,
            seed=seed,
            **kwargs,
        )

    def train_step(self, model, inputs):
        model.train()
        inputs['mode'] = ModeKeys.TRAIN
        model_outputs = model.forward(
            inputs
        )  # {OutputKeys.IMG_EMBEDDING: Tensor(batch_size, dim), OutputKeys.TEXT_EMBEDDING: Tensor(batch_size, dim)}
        loss = get_loss(model_outputs, self.loss_img, self.loss_txt,
                        self.loss_cfg)
        train_outputs = {'loss': loss}

        # add model output info to log
        if 'log_vars' not in train_outputs:
            default_keys_pattern = ['loss']
            match_keys = set([])
            for key_p in default_keys_pattern:
                match_keys.update(
                    [key for key in train_outputs.keys() if key_p in key])
            log_vars = {}
            for key in match_keys:
                value = train_outputs.get(key, None)
                if value is not None:
                    if dist.is_available() and dist.is_initialized():
                        value = value.data.clone()
                        dist.all_reduce(value.div_(dist.get_world_size()))
                    log_vars.update({key: value.item()})
            unwrapped_model = getattr(model, 'module', model)
            log_vars[
                'logit_scale'] = unwrapped_model.clip_model.logit_scale.data.clone(
                ).item()  # noqa
            log_vars['global_batch_size'] = int(self.global_batch_size)
            self.log_buffer.update(log_vars)
        else:
            self.log_buffer.update(train_outputs['log_vars'])

        self.train_outputs = train_outputs
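A side note on the optimizer setup above: parameters matched by `exclude` (batch/layer norms, biases, logit_scale) go into a zero-weight-decay group, everything else keeps the configured decay. A minimal plain-PyTorch sketch of that grouping, with a toy module and illustrative hyperparameters:

```python
import torch
from torch import nn
from torch.optim import AdamW


# Same exclusion rule as the trainer: no weight decay for norm/bias/logit_scale parameters.
def exclude(n):
    return 'bn' in n or 'ln' in n or 'bias' in n or 'logit_scale' in n


class ToyCLIP(nn.Module):  # hypothetical stand-in for the real CLIP model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 32)
        self.ln_final = nn.LayerNorm(32)
        self.logit_scale = nn.Parameter(torch.ones([]) * 4.6052)


model = ToyCLIP()
named = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
gain_or_bias = [p for n, p in named if exclude(n)]
rest = [p for n, p in named if not exclude(n)]

optimizer = AdamW(
    [{'params': gain_or_bias, 'weight_decay': 0.0},
     {'params': rest, 'weight_decay': 0.01}],
    lr=5e-5, betas=(0.9, 0.98), eps=1e-6)
print(len(gain_or_bias), len(rest))  # 4 1 -> only fc.weight is decayed
```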
@@ -1,94 +1,125 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import math
import os
import random
from functools import partial
from inspect import unwrap

import json
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from modelscope.utils.constant import ModeKeys

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(
        224, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
                           p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711))
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711))
])


class ImageWithCaptionDataset(Dataset):

    def __init__(self, json_file, img_dir, phase):
        self.annotations = json.load(open(json_file))
        self.img_dir = img_dir
        if phase == ModeKeys.TRAIN:
            self.transform = train_transform
        elif phase == ModeKeys.EVAL:
            self.transform = val_transform

        self.img_name2img_id = {}
        for anno_dict in self.annotations:
            img_name = anno_dict['image']
            if img_name not in self.img_name2img_id:
                self.img_name2img_id[img_name] = len(self.img_name2img_id)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        anno_dict = self.annotations[index]

        img_path = os.path.join(self.img_dir, anno_dict['image'])
        img_pil = Image.open(img_path).convert('RGB')
        img_th = self.transform(img_pil)
        img_id = self.img_name2img_id[anno_dict['image']]
        text_str = random.choice(anno_dict['caption'])

        return img_th, text_str, img_id


def get_params_groups(ddp_model, weight_decay):
    decay = []
    no_decay = []
    for name, param in ddp_model.named_parameters():
        if not param.requires_grad:
            continue
        if len(param.shape) == 1 or name.endswith('.bias'):
            no_decay.append(param)
        else:
            decay.append(param)
    params_groups = [{
        'params': no_decay,
        'weight_decay': 0.
    }, {
        'params': decay,
        'weight_decay': weight_decay
    }]
    return params_groups


def get_optimizer(ddp_model):
    from torch.optim import AdamW
    lr_init = 1e-5
    betas = [0.9, 0.999]
    weight_decay = 0.02
    params_groups = get_params_groups(ddp_model, weight_decay=weight_decay)
    return AdamW(
        params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay)
import torch.distributed as dist
from torch.optim.lr_scheduler import LambdaLR

from modelscope.outputs import OutputKeys


def get_optimizer_params(model_name, cfg):
    # get default params
    # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
    # base model
    if model_name in ['damo/multi-modal_clip-vit-base-patch16_zh']:
        params = {
            'lr': 5.0e-4,
            'beta1': 0.9,
            'beta2': 0.98,
            'eps': 1.0e-6,
            'weight_decay': 0.0
        }
    # large models
    elif model_name in [
            'damo/multi-modal_clip-vit-large-patch14_zh',
            'damo/multi-modal_clip-vit-large-patch14_336_zh'
    ]:
        params = {
            'lr': 4.0e-4,
            'beta1': 0.9,
            'beta2': 0.98,
            'eps': 1.0e-6,
            'weight_decay': 0.0
        }
    else:
        params = {
            'lr': 5.0e-4,
            'beta1': 0.9,
            'beta2': 0.999,
            'eps': 1.0e-8,
            'weight_decay': 0.0
        }
    # override with config params
    for key in ['lr', 'beta1', 'beta2', 'eps', 'weight_decay']:
        if hasattr(cfg.train, 'optimizer_hparams'):
            params[key] = getattr(cfg.train.optimizer_hparams, key,
                                  params[key])
    return params
def get_loss(model_outputs, loss_img, loss_txt, loss_cfg):
    image_features = model_outputs[OutputKeys.IMG_EMBEDDING]
    text_features = model_outputs[OutputKeys.TEXT_EMBEDDING]
    logit_scale = model_outputs['logit_scale']
    logit_scale = logit_scale.mean()
    if loss_cfg.aggregate and int(os.environ.get('WORLD_SIZE', 1)) > 1:
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # We gather tensors from all gpus to get more negatives to contrast with.
        gathered_image_features = [
            torch.zeros_like(image_features) for _ in range(world_size)
        ]
        gathered_text_features = [
            torch.zeros_like(text_features) for _ in range(world_size)
        ]
        dist.all_gather(gathered_image_features, image_features)
        dist.all_gather(gathered_text_features, text_features)

        all_image_features = torch.cat([image_features]
                                       + gathered_image_features[:rank]
                                       + gathered_image_features[rank + 1:])
        all_text_features = torch.cat([text_features]
                                      + gathered_text_features[:rank]
                                      + gathered_text_features[rank + 1:])

        # this is needed to send gradients back everywhere.
        logits_per_image = logit_scale * all_image_features @ all_text_features.t(
        )
        logits_per_text = logits_per_image.t()
    else:
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

    ground_truth = torch.arange(len(logits_per_image)).long()
    ground_truth = ground_truth.cuda(
        int(os.environ.get('LOCAL_RANK', 0)), non_blocking=True)

    total_loss = (loss_img(logits_per_image, ground_truth)
                  + loss_txt(logits_per_text, ground_truth)) / 2

    return total_loss
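A CPU-only sketch of the single-process branch of this loss (the standard symmetric CLIP objective: cross-entropy over the scaled image-text similarity matrix in both directions). The helper name and toy tensors are illustrative, not the ModelScope API:

```python
import torch
import torch.nn.functional as F


def clip_contrastive_loss(image_features, text_features, logit_scale):
    # features are assumed L2-normalized; logit_scale is exp(learned log-temperature)
    logits_per_image = logit_scale * image_features @ text_features.t()  # [B, B]
    logits_per_text = logits_per_image.t()
    ground_truth = torch.arange(image_features.shape[0])
    return (F.cross_entropy(logits_per_image, ground_truth)
            + F.cross_entropy(logits_per_text, ground_truth)) / 2


img = F.normalize(torch.randn(8, 512), dim=-1)
txt = F.normalize(img + 0.1 * torch.randn(8, 512), dim=-1)
print(clip_contrastive_loss(img, txt, logit_scale=100.0).item())
```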
def lr_lambda(num_warmup_steps, num_training_steps, num_cycles, current_step):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps))
    return max(
        0.0,
        0.5 *  # noqa
        (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))  # noqa


def get_schedule(optimizer,
                 scheduler,
                 num_cycles: float = 0.5,
                 last_epoch: int = -1):
    num_warmup_steps = int(scheduler.warmup_proportion
                           * scheduler.num_train_steps)
    num_training_steps = scheduler.num_train_steps
    return LambdaLR(
        optimizer,
        partial(lr_lambda, num_warmup_steps, num_training_steps, num_cycles),
        last_epoch)
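To see the warmup-then-cosine shape this LambdaLR produces, here is a small standalone sketch with a toy optimizer and toy step counts (not the trainer's real configuration):

```python
import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def lr_lambda(num_warmup_steps, num_training_steps, num_cycles, current_step):
    # linear warmup, then a half-cosine decay toward zero (num_cycles=0.5)
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))


opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = LambdaLR(opt, partial(lr_lambda, 10, 100, 0.5))  # 10 warmup steps, 100 total
factors = []
for _ in range(100):
    factors.append(opt.param_groups[0]['lr'] / 1e-3)
    opt.step()
    sched.step()
print([round(f, 2) for f in factors[::20]])  # ramps up to ~1.0, then decays toward 0
```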
@@ -24,7 +24,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
    def test_run(self):
        pipeline_multi_modal_embedding = pipeline(
            Tasks.multi_modal_embedding, model=self.model_id)
        text_embedding = pipeline_multi_modal_embedding(
        text_embedding = pipeline_multi_modal_embedding.forward(
            self.test_input)[OutputKeys.TEXT_EMBEDDING]
        print('l1-norm: {}'.format(
            torch.norm(text_embedding, p=1, dim=-1).item()))

@@ -36,7 +36,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
        model = Model.from_pretrained(self.model_id)
        pipeline_multi_modal_embedding = pipeline(
            task=Tasks.multi_modal_embedding, model=model)
        text_embedding = pipeline_multi_modal_embedding(
        text_embedding = pipeline_multi_modal_embedding.forward(
            self.test_input)[OutputKeys.TEXT_EMBEDDING]
        print('l1-norm: {}'.format(
            torch.norm(text_embedding, p=1, dim=-1).item()))

@@ -47,7 +47,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
    def test_run_with_default_model(self):
        pipeline_multi_modal_embedding = pipeline(
            task=Tasks.multi_modal_embedding)
        text_embedding = pipeline_multi_modal_embedding(
        text_embedding = pipeline_multi_modal_embedding.forward(
            self.test_input)[OutputKeys.TEXT_EMBEDDING]
        print('l1-norm: {}'.format(
            torch.norm(text_embedding, p=1, dim=-1).item()))
@@ -0,0 +1,83 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import unittest

import json

from modelscope.metainfo import Metrics, Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


class TestClipTrainer(unittest.TestCase):

    def setUp(self) -> None:
        self.finetune_cfg = \
            {'framework': 'pytorch',
             'task': 'multi-modal-embedding',
             'pipeline': {'type': 'multi-modal-embedding'},
             'pretrained_model': {'model_name': 'damo/multi-modal_clip-vit-base-patch16_zh'},
             'dataset': {'column_map': {'img': 'image', 'text': 'query'}},
             'train': {'work_dir': './workspace/ckpts/clip',
                       # 'launcher': 'pytorch',
                       'max_epochs': 1,
                       'use_fp16': True,
                       'dataloader': {'batch_size_per_gpu': 8,
                                      'workers_per_gpu': 0,
                                      'shuffle': True,
                                      'drop_last': True},
                       'lr_scheduler': {'name': 'cosine',
                                        'warmup_proportion': 0.01},
                       'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
                       'optimizer': {'type': 'AdamW'},
                       'optimizer_hparams': {'lr': 5e-05, 'weight_decay': 0.01},
                       'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
                                          'cumulative_iters': 1,
                                          'loss_keys': 'loss'},
                       'loss_cfg': {'aggregate': True},
                       'hooks': [{'type': 'BestCkptSaverHook',
                                  'metric_key': 'inbatch_t2i_recall_at_1',
                                  'interval': 100},
                                 {'type': 'TextLoggerHook', 'interval': 1},
                                 {'type': 'IterTimerHook'},
                                 {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1},
                                 {'type': 'ClipClampLogitScaleHook'}]},
             'evaluation': {'dataloader': {'batch_size_per_gpu': 8,
                                           'workers_per_gpu': 0,
                                           'shuffle': True,
                                           'drop_last': True},
                            'metrics': [{'type': 'inbatch_recall'}]},
             'preprocessor': []}
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_std(self):
        WORKSPACE = './workspace/ckpts/clip'
        os.makedirs(WORKSPACE, exist_ok=True)
        config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
        with open(config_file, 'w') as writer:
            json.dump(self.finetune_cfg, writer)

        pretrained_model = 'damo/multi-modal_clip-vit-base-patch16_zh'
        args = dict(
            model=pretrained_model,
            work_dir=WORKSPACE,
            train_dataset=MsDataset.load(
                'muge', namespace='modelscope', split='train[:200]'),
            eval_dataset=MsDataset.load(
                'muge', namespace='modelscope', split='validation[:100]'),
            metrics=[Metrics.inbatch_recall],
            cfg_file=config_file)
        trainer = build_trainer(
            name=Trainers.clip_multi_modal_embedding, default_args=args)
        trainer.train()
        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
                      os.listdir(os.path.join(WORKSPACE, 'output')))
        shutil.rmtree(WORKSPACE)


if __name__ == '__main__':
    unittest.main()